From 9426b603082905d0af8a07bdac866bc1d9c37cba Mon Sep 17 00:00:00 2001
From: Devin Carr
Date: Fri, 5 May 2023 17:42:41 -0700
Subject: [PATCH] TUN-7227: Migrate to devincarr/quic-go

The lucas-clemente/quic-go package moved namespaces and our branch went stale. This new fork provides support for the new quic-go repo and applies the max datagram frame size change.

Until max datagram frame size support is upstreamed into quic-go, this fork can be used to unblock Go 1.20 support, since the old lucas-clemente/quic-go will not get Go 1.20 support.

---
 connection/quic.go | 5 +- connection/quic_test.go | 29 +- datagramsession/session.go | 2 +- go.mod | 30 +- go.sum | 149 +- quic/conversion.go | 8 +- quic/datagram.go | 2 +- quic/datagram_test.go | 8 +- quic/datagramv2.go | 2 +- quic/metrics.go | 2 +- quic/param_windows.go | 6 +- quic/safe_stream.go | 2 +- quic/safe_stream_test.go | 6 +- quic/tracing.go | 28 +- supervisor/supervisor.go | 2 +- supervisor/tunnel.go | 3 +- supervisor/tunnel_test.go | 2 +- vendor/github.com/cheekybits/genny/.gitignore | 26 - .../github.com/cheekybits/genny/.travis.yml | 6 - vendor/github.com/cheekybits/genny/LICENSE | 22 - vendor/github.com/cheekybits/genny/README.md | 245 --- vendor/github.com/cheekybits/genny/doc.go | 2 - .../cheekybits/genny/generic/doc.go | 2 - .../cheekybits/genny/generic/generic.go | 13 - vendor/github.com/cheekybits/genny/main.go | 154 -- .../cheekybits/genny/out/lazy_file.go | 38 - .../cheekybits/genny/parse/builtins.go | 41 - .../github.com/cheekybits/genny/parse/doc.go | 14 - .../cheekybits/genny/parse/errors.go | 47 - .../cheekybits/genny/parse/parse.go | 298 --- .../cheekybits/genny/parse/typesets.go | 89 - vendor/github.com/golang/mock/AUTHORS | 12 + vendor/github.com/golang/mock/CONTRIBUTORS | 37 + vendor/github.com/golang/mock/LICENSE | 202 ++ .../github.com/golang/mock/mockgen/mockgen.go | 701 +++++++ .../golang/mock/mockgen/model/model.go | 495 +++++ .../github.com/golang/mock/mockgen/parse.go | 644 ++++++ .../github.com/golang/mock/mockgen/reflect.go | 256 +++ .../golang/mock/mockgen/version.1.11.go | 26 + .../golang/mock/mockgen/version.1.12.go | 35 + vendor/github.com/google/pprof/AUTHORS | 7 + vendor/github.com/google/pprof/CONTRIBUTORS | 16 + vendor/github.com/google/pprof/LICENSE | 202 ++ .../github.com/google/pprof/profile/encode.go | 567 +++++ .../github.com/google/pprof/profile/filter.go | 270 +++ .../github.com/google/pprof/profile/index.go | 64 + .../pprof/profile/legacy_java_profile.go | 315 +++ .../google/pprof/profile/legacy_profile.go | 1225 +++++++++++ .../github.com/google/pprof/profile/merge.go | 481 +++++ .../google/pprof/profile/profile.go | 805 ++++++++ .../github.com/google/pprof/profile/proto.go | 370 ++++ .../github.com/google/pprof/profile/prune.go | 178 ++ .../lucas-clemente/quic-go/README.md | 61 - .../lucas-clemente/quic-go/client.go | 338 --- .../lucas-clemente/quic-go/closed_conn.go | 112 - .../lucas-clemente/quic-go/datagram_queue.go | 94 - .../quic-go/internal/ackhandler/ackhandler.go | 21 - .../quic-go/internal/ackhandler/frame.go | 9 - .../quic-go/internal/ackhandler/gen.go | 3 - .../quic-go/internal/ackhandler/mockgen.go | 3 - .../internal/ackhandler/packet_linkedlist.go | 217 -- .../quic-go/internal/handshake/mockgen.go | 3 - .../quic-go/internal/handshake/retry.go | 62 - .../internal/protocol/connection_id.go | 81 - .../quic-go/internal/qtls/go116.go | 100 - .../quic-go/internal/qtls/go118.go | 100 - .../quic-go/internal/qtls/go120.go | 6 - .../quic-go/internal/qtls/go_oldversion.go | 7 - 
.../quic-go/internal/utils/atomic_bool.go | 22 - .../quic-go/internal/utils/gen.go | 5 - .../quic-go/internal/utils/minmax.go | 170 -- .../internal/utils/new_connection_id.go | 12 - .../utils/newconnectionid_linkedlist.go | 217 -- .../quic-go/internal/utils/packet_interval.go | 9 - .../utils/packetinterval_linkedlist.go | 217 -- .../internal/utils/streamframe_interval.go | 9 - .../quic-go/internal/wire/extended_header.go | 249 --- .../quic-go/internal/wire/frame_parser.go | 143 -- .../internal/wire/handshake_done_frame.go | 28 - .../quic-go/internal/wire/ping_frame.go | 27 - .../lucas-clemente/quic-go/logging/mockgen.go | 4 - .../lucas-clemente/quic-go/mockgen.go | 27 - .../lucas-clemente/quic-go/mockgen_private.sh | 49 - .../lucas-clemente/quic-go/multiplexer.go | 107 - .../quic-go/packet_handler_map.go | 489 ----- .../lucas-clemente/quic-go/packet_packer.go | 894 -------- .../lucas-clemente/quic-go/server.go | 670 ------ .../quic-go/streams_map_generic_helper.go | 18 - .../quic-go/streams_map_incoming_bidi.go | 192 -- .../quic-go/streams_map_incoming_generic.go | 190 -- .../quic-go/streams_map_outgoing_bidi.go | 226 -- .../quic-go/streams_map_outgoing_generic.go | 224 -- .../lucas-clemente/quic-go/tools.go | 9 - .../marten-seemann/qtls-go1-16/README.md | 6 - .../marten-seemann/qtls-go1-16/auth.go | 289 --- .../qtls-go1-16/cipher_suites.go | 532 ----- .../marten-seemann/qtls-go1-16/common.go | 1576 -------------- .../marten-seemann/qtls-go1-16/common_js.go | 12 - .../marten-seemann/qtls-go1-16/common_nojs.go | 20 - .../marten-seemann/qtls-go1-16/conn.go | 1536 -------------- .../qtls-go1-16/handshake_client.go | 1105 ---------- .../qtls-go1-16/handshake_client_tls13.go | 740 ------- .../qtls-go1-16/handshake_messages.go | 1832 ----------------- .../qtls-go1-16/handshake_server.go | 878 -------- .../qtls-go1-16/handshake_server_tls13.go | 912 -------- .../qtls-go1-16/key_agreement.go | 338 --- .../qtls-go1-16/key_schedule.go | 199 -- .../marten-seemann/qtls-go1-16/ticket.go | 259 --- .../marten-seemann/qtls-go1-16/tls.go | 393 ---- .../marten-seemann/qtls-go1-17/README.md | 6 - .../marten-seemann/qtls-go1-17/auth.go | 289 --- .../qtls-go1-17/cipher_suites.go | 691 ------- .../marten-seemann/qtls-go1-17/common.go | 1485 ------------- .../marten-seemann/qtls-go1-17/conn.go | 1601 -------------- .../qtls-go1-17/handshake_client.go | 1111 ---------- .../qtls-go1-17/handshake_client_tls13.go | 732 ------- .../qtls-go1-17/handshake_messages.go | 1832 ----------------- .../qtls-go1-17/handshake_server.go | 905 -------- .../qtls-go1-17/handshake_server_tls13.go | 896 -------- .../qtls-go1-17/key_agreement.go | 357 ---- .../marten-seemann/qtls-go1-18/README.md | 6 - .../marten-seemann/qtls-go1-18/alert.go | 102 - .../marten-seemann/qtls-go1-18/prf.go | 283 --- .../marten-seemann/qtls-go1-18/unsafe.go | 96 - .../marten-seemann/qtls-go1-19/LICENSE | 27 - .../marten-seemann/qtls-go1-19/README.md | 6 - .../marten-seemann/qtls-go1-19/alert.go | 102 - .../marten-seemann/qtls-go1-19/cpu.go | 22 - .../marten-seemann/qtls-go1-19/cpu_other.go | 12 - .../qtls-go1-19/key_schedule.go | 199 -- .../marten-seemann/qtls-go1-19/prf.go | 283 --- .../marten-seemann/qtls-go1-19/ticket.go | 274 --- .../marten-seemann/qtls-go1-19/tls.go | 362 ---- .../marten-seemann/qtls-go1-19/unsafe.go | 96 - vendor/github.com/nxadm/tail/.gitignore | 3 - vendor/github.com/nxadm/tail/CHANGES.md | 56 - vendor/github.com/nxadm/tail/Dockerfile | 19 - vendor/github.com/nxadm/tail/LICENSE | 21 - vendor/github.com/nxadm/tail/README.md | 44 - 
.../github.com/nxadm/tail/ratelimiter/Licence | 7 - .../nxadm/tail/ratelimiter/leakybucket.go | 97 - .../nxadm/tail/ratelimiter/memory.go | 60 - .../nxadm/tail/ratelimiter/storage.go | 6 - vendor/github.com/nxadm/tail/tail.go | 455 ---- vendor/github.com/nxadm/tail/tail_posix.go | 17 - vendor/github.com/nxadm/tail/tail_windows.go | 19 - vendor/github.com/nxadm/tail/util/util.go | 49 - .../nxadm/tail/watch/filechanges.go | 37 - vendor/github.com/nxadm/tail/watch/inotify.go | 136 -- .../nxadm/tail/watch/inotify_tracker.go | 249 --- vendor/github.com/nxadm/tail/watch/polling.go | 119 -- vendor/github.com/nxadm/tail/watch/watch.go | 21 - .../github.com/nxadm/tail/winfile/winfile.go | 93 - .../github.com/onsi/ginkgo/config/config.go | 232 --- .../onsi/ginkgo/ginkgo/bootstrap_command.go | 201 -- .../onsi/ginkgo/ginkgo/build_command.go | 66 - .../ginkgo/ginkgo/convert/ginkgo_ast_nodes.go | 123 -- .../onsi/ginkgo/ginkgo/convert/import.go | 90 - .../ginkgo/ginkgo/convert/package_rewriter.go | 128 -- .../onsi/ginkgo/ginkgo/convert/test_finder.go | 56 - .../ginkgo/convert/testfile_rewriter.go | 162 -- .../ginkgo/convert/testing_t_rewriter.go | 130 -- .../onsi/ginkgo/ginkgo/convert_command.go | 51 - .../onsi/ginkgo/ginkgo/help_command.go | 33 - .../interrupthandler/interrupt_handler.go | 52 - vendor/github.com/onsi/ginkgo/ginkgo/main.go | 337 --- .../onsi/ginkgo/ginkgo/nodot/nodot.go | 196 -- .../onsi/ginkgo/ginkgo/nodot_command.go | 77 - .../onsi/ginkgo/ginkgo/notifications.go | 141 -- .../onsi/ginkgo/ginkgo/outline_command.go | 95 - .../onsi/ginkgo/ginkgo/run_command.go | 316 --- .../run_watch_and_build_command_flags.go | 169 -- .../onsi/ginkgo/ginkgo/suite_runner.go | 173 -- .../ginkgo/ginkgo/testrunner/build_args.go | 7 - .../ginkgo/testrunner/build_args_old.go | 7 - .../ginkgo/ginkgo/testrunner/log_writer.go | 52 - .../ginkgo/ginkgo/testrunner/run_result.go | 27 - .../ginkgo/ginkgo/testrunner/test_runner.go | 554 ----- .../ginkgo/ginkgo/testsuite/test_suite.go | 115 -- .../ginkgo/testsuite/vendor_check_go15.go | 16 - .../ginkgo/testsuite/vendor_check_go16.go | 15 - .../onsi/ginkgo/ginkgo/version_command.go | 25 - .../onsi/ginkgo/ginkgo/watch_command.go | 175 -- .../internal/codelocation/code_location.go | 48 - .../internal/containernode/container_node.go | 151 -- .../onsi/ginkgo/internal/failer/failer.go | 92 - .../ginkgo/internal/leafnodes/benchmarker.go | 103 - .../ginkgo/internal/leafnodes/interfaces.go | 19 - .../onsi/ginkgo/internal/leafnodes/it_node.go | 47 - .../ginkgo/internal/leafnodes/measure_node.go | 62 - .../onsi/ginkgo/internal/leafnodes/runner.go | 117 -- .../ginkgo/internal/leafnodes/setup_nodes.go | 48 - .../ginkgo/internal/leafnodes/suite_nodes.go | 55 - .../synchronized_after_suite_node.go | 90 - .../synchronized_before_suite_node.go | 181 -- .../onsi/ginkgo/internal/remote/aggregator.go | 249 --- .../internal/remote/forwarding_reporter.go | 147 -- .../internal/remote/output_interceptor.go | 13 - .../remote/output_interceptor_unix.go | 82 - .../internal/remote/output_interceptor_win.go | 36 - .../onsi/ginkgo/internal/remote/server.go | 224 -- .../onsi/ginkgo/internal/spec/spec.go | 247 --- .../onsi/ginkgo/internal/spec/specs.go | 144 -- .../internal/spec_iterator/index_computer.go | 55 - .../spec_iterator/parallel_spec_iterator.go | 59 - .../spec_iterator/serial_spec_iterator.go | 45 - .../sharded_parallel_spec_iterator.go | 47 - .../internal/spec_iterator/spec_iterator.go | 20 - .../ginkgo/internal/writer/fake_writer.go | 36 - .../onsi/ginkgo/internal/writer/writer.go | 89 - 
.../onsi/ginkgo/reporters/default_reporter.go | 87 - .../onsi/ginkgo/reporters/fake_reporter.go | 59 - .../onsi/ginkgo/reporters/junit_reporter.go | 178 -- .../onsi/ginkgo/reporters/reporter.go | 15 - .../reporters/stenographer/console_logging.go | 64 - .../stenographer/fake_stenographer.go | 142 -- .../reporters/stenographer/stenographer.go | 572 ----- .../support/go-colorable/README.md | 43 - .../support/go-colorable/colorable_others.go | 24 - .../support/go-colorable/noncolorable.go | 57 - .../stenographer/support/go-isatty/LICENSE | 9 - .../stenographer/support/go-isatty/README.md | 37 - .../stenographer/support/go-isatty/doc.go | 2 - .../support/go-isatty/isatty_appengine.go | 9 - .../support/go-isatty/isatty_bsd.go | 18 - .../support/go-isatty/isatty_linux.go | 18 - .../support/go-isatty/isatty_solaris.go | 16 - .../support/go-isatty/isatty_windows.go | 19 - .../ginkgo/reporters/teamcity_reporter.go | 106 - .../onsi/ginkgo/types/code_location.go | 15 - .../onsi/ginkgo/types/synchronization.go | 30 - vendor/github.com/onsi/ginkgo/types/types.go | 174 -- .../github.com/onsi/ginkgo/{ => v2}/LICENSE | 0 .../onsi/ginkgo/v2/config/deprecated.go | 69 + .../formatter/colorable_others.go} | 20 + .../formatter}/colorable_windows.go | 74 +- .../ginkgo/{ => v2}/formatter/formatter.go | 15 +- .../ginkgo/v2/ginkgo/build/build_command.go | 63 + .../onsi/ginkgo/v2/ginkgo/command/abort.go | 61 + .../onsi/ginkgo/v2/ginkgo/command/command.go | 50 + .../onsi/ginkgo/v2/ginkgo/command/program.go | 182 ++ .../ginkgo/generators/boostrap_templates.go | 48 + .../v2/ginkgo/generators/bootstrap_command.go | 113 + .../ginkgo/generators}/generate_command.go | 189 +- .../ginkgo/generators/generate_templates.go | 41 + .../v2/ginkgo/generators/generators_common.go | 63 + .../onsi/ginkgo/v2/ginkgo/internal/compile.go | 152 ++ .../ginkgo/internal/profiles_and_reports.go | 237 +++ .../onsi/ginkgo/v2/ginkgo/internal/run.go | 348 ++++ .../ginkgo/v2/ginkgo/internal/test_suite.go | 283 +++ .../onsi/ginkgo/v2/ginkgo/internal/utils.go | 86 + .../v2/ginkgo/internal/verify_version.go | 54 + .../ginkgo/v2/ginkgo/labels/labels_command.go | 123 ++ .../github.com/onsi/ginkgo/v2/ginkgo/main.go | 58 + .../ginkgo/{ => v2}/ginkgo/outline/ginkgo.go | 39 +- .../ginkgo/{ => v2}/ginkgo/outline/import.go | 2 +- .../ginkgo/{ => v2}/ginkgo/outline/outline.go | 12 +- .../v2/ginkgo/outline/outline_command.go | 98 + .../onsi/ginkgo/v2/ginkgo/run/run_command.go | 232 +++ .../ginkgo/unfocus}/unfocus_command.go | 56 +- .../ginkgo/{ => v2}/ginkgo/watch/delta.go | 0 .../{ => v2}/ginkgo/watch/delta_tracker.go | 8 +- .../{ => v2}/ginkgo/watch/dependencies.go | 0 .../{ => v2}/ginkgo/watch/package_hash.go | 12 +- .../{ => v2}/ginkgo/watch/package_hashes.go | 0 .../ginkgo/{ => v2}/ginkgo/watch/suite.go | 6 +- .../ginkgo/v2/ginkgo/watch/watch_command.go | 192 ++ .../interrupt_handler/interrupt_handler.go | 162 ++ .../sigquit_swallower_unix.go | 3 +- .../sigquit_swallower_windows.go | 3 +- .../parallel_support/client_server.go | 70 + .../internal/parallel_support/http_client.go | 156 ++ .../internal/parallel_support/http_server.go | 223 ++ .../internal/parallel_support/rpc_client.go | 123 ++ .../internal/parallel_support/rpc_server.go | 75 + .../parallel_support/server_handler.go | 209 ++ .../ginkgo/v2/reporters/default_reporter.go | 642 ++++++ .../v2/reporters/deprecated_reporter.go | 149 ++ .../onsi/ginkgo/v2/reporters/json_report.go | 60 + .../onsi/ginkgo/v2/reporters/junit_report.go | 358 ++++ .../onsi/ginkgo/v2/reporters/reporter.go | 21 + 
.../ginkgo/v2/reporters/teamcity_report.go | 101 + .../onsi/ginkgo/v2/types/code_location.go | 91 + .../github.com/onsi/ginkgo/v2/types/config.go | 740 +++++++ .../onsi/ginkgo/v2/types/deprecated_types.go | 141 ++ .../{ => v2}/types/deprecation_support.go | 52 +- .../onsi/ginkgo/v2/types/enum_support.go | 43 + .../github.com/onsi/ginkgo/v2/types/errors.go | 621 ++++++ .../onsi/ginkgo/v2/types/file_filter.go | 106 + .../github.com/onsi/ginkgo/v2/types/flags.go | 489 +++++ .../onsi/ginkgo/v2/types/label_filter.go | 347 ++++ .../onsi/ginkgo/v2/types/report_entry.go | 186 ++ .../github.com/onsi/ginkgo/v2/types/types.go | 696 +++++++ .../onsi/ginkgo/v2/types/version.go | 3 + .../qtls-go1-19}/LICENSE | 0 .../github.com/quic-go/qtls-go1-19/README.md | 6 + .../qtls-go1-19}/alert.go | 0 .../qtls-go1-19/auth.go | 0 .../qtls-go1-19}/cfkem.go | 0 .../qtls-go1-19/cipher_suites.go | 0 .../qtls-go1-19/common.go | 4 +- .../qtls-go1-19/conn.go | 56 +- .../qtls-go1-19}/cpu.go | 0 .../qtls-go1-19}/cpu_other.go | 0 .../qtls-go1-19/handshake_client.go | 122 +- .../qtls-go1-19/handshake_client_tls13.go | 75 +- .../qtls-go1-19/handshake_messages.go | 728 +++---- .../qtls-go1-19/handshake_server.go | 73 +- .../qtls-go1-19/handshake_server_tls13.go | 99 +- .../qtls-go1-19}/key_agreement.go | 0 .../qtls-go1-19}/key_schedule.go | 19 +- .../qtls-go1-19/notboring.go | 0 .../qtls-go1-19}/prf.go | 0 .../qtls-go1-19}/ticket.go | 17 +- .../qtls-go1-19}/tls.go | 0 .../qtls-go1-19}/unsafe.go | 0 .../qtls-go1-20}/LICENSE | 0 .../github.com/quic-go/qtls-go1-20/README.md | 6 + .../qtls-go1-20}/alert.go | 0 .../qtls-go1-20}/auth.go | 4 + .../github.com/quic-go/qtls-go1-20/cache.go | 95 + .../qtls-go1-20}/cfkem.go | 13 +- .../qtls-go1-20}/cipher_suites.go | 64 +- .../qtls-go1-20}/common.go | 52 +- .../qtls-go1-20}/conn.go | 117 +- .../qtls-go1-20}/cpu.go | 0 .../qtls-go1-20}/cpu_other.go | 0 .../qtls-go1-20}/handshake_client.go | 156 +- .../qtls-go1-20}/handshake_client_tls13.go | 115 +- .../qtls-go1-20}/handshake_messages.go | 748 +++---- .../qtls-go1-20}/handshake_server.go | 82 +- .../qtls-go1-20}/handshake_server_tls13.go | 121 +- .../qtls-go1-20}/key_agreement.go | 35 +- .../qtls-go1-20}/key_schedule.go | 121 +- .../quic-go/qtls-go1-20/notboring.go | 18 + .../qtls-go1-20}/prf.go | 2 +- .../qtls-go1-20}/ticket.go | 19 +- .../qtls-go1-20}/tls.go | 0 .../qtls-go1-20}/unsafe.go | 0 .../quic-go/.gitignore | 0 .../quic-go/.golangci.yml | 10 +- .../quic-go/Changelog.md | 4 +- .../quic-go/LICENSE | 0 vendor/github.com/quic-go/quic-go/README.md | 64 + vendor/github.com/quic-go/quic-go/SECURITY.md | 19 + .../quic-go/buffer_pool.go | 2 +- vendor/github.com/quic-go/quic-go/client.go | 252 +++ .../github.com/quic-go/quic-go/closed_conn.go | 64 + .../quic-go/codecov.yml | 1 + .../quic-go/config.go | 57 +- .../quic-go/conn_id_generator.go | 36 +- .../quic-go/conn_id_manager.go | 27 +- .../quic-go/connection.go | 710 ++++--- .../quic-go/quic-go/connection_timer.go | 51 + .../quic-go/crypto_stream.go | 12 +- .../quic-go/crypto_stream_manager.go | 4 +- .../quic-go/quic-go/datagram_queue.go | 123 ++ .../quic-go/errors.go | 9 +- .../quic-go/frame_sorter.go | 31 +- .../quic-go/framer.go | 39 +- .../quic-go/interface.go | 133 +- .../internal/ackhandler/ack_eliciting.go | 4 +- .../quic-go/internal/ackhandler/ackhandler.go | 23 + .../quic-go/internal/ackhandler/frame.go | 29 + .../quic-go/internal/ackhandler/interfaces.go | 24 +- .../quic-go/internal/ackhandler/mockgen.go | 6 + .../quic-go/internal/ackhandler/packet.go | 55 + 
.../ackhandler/packet_number_generator.go | 6 +- .../ackhandler/received_packet_handler.go | 25 +- .../ackhandler/received_packet_history.go | 47 +- .../ackhandler/received_packet_tracker.go | 38 +- .../quic-go/internal/ackhandler/send_mode.go | 0 .../ackhandler/sent_packet_handler.go | 69 +- .../ackhandler/sent_packet_history.go | 90 +- .../quic-go/internal/congestion/bandwidth.go | 2 +- .../quic-go/internal/congestion/clock.go | 0 .../quic-go/internal/congestion/cubic.go | 6 +- .../internal/congestion/cubic_sender.go | 10 +- .../internal/congestion/hybrid_slow_start.go | 8 +- .../quic-go/internal/congestion/interface.go | 2 +- .../quic-go/internal/congestion/pacer.go | 29 +- .../flowcontrol/base_flow_controller.go | 8 +- .../flowcontrol/connection_flow_controller.go | 8 +- .../quic-go/internal/flowcontrol/interface.go | 2 +- .../flowcontrol/stream_flow_controller.go | 8 +- .../quic-go/internal/handshake/aead.go | 8 +- .../internal/handshake/crypto_setup.go | 122 +- .../internal/handshake/header_protector.go | 4 +- .../quic-go/internal/handshake/hkdf.go | 0 .../internal/handshake/initial_aead.go | 8 +- .../quic-go/internal/handshake/interface.go | 14 +- .../quic-go/internal/handshake/mockgen.go | 6 + .../quic-go/internal/handshake/retry.go | 70 + .../internal/handshake/session_ticket.go | 13 +- .../handshake/tls_extension_handler.go | 6 +- .../internal/handshake/token_generator.go | 41 +- .../internal/handshake/token_protector.go | 0 .../internal/handshake/updatable_aead.go | 30 +- .../quic-go/internal/logutils/frame.go | 23 +- .../internal/protocol/connection_id.go | 116 ++ .../internal/protocol/encryption_level.go | 0 .../quic-go/internal/protocol/key_phase.go | 0 .../internal/protocol/packet_number.go | 0 .../quic-go/internal/protocol/params.go | 5 +- .../quic-go/internal/protocol/perspective.go | 0 .../quic-go/internal/protocol/protocol.go | 0 .../quic-go/internal/protocol/stream.go | 0 .../quic-go/internal/protocol/version.go | 19 +- .../quic-go/internal/qerr/error_codes.go | 4 +- .../quic-go/internal/qerr/errors.go | 19 +- .../quic-go/internal/qtls/go119.go} | 55 +- .../quic-go/internal/qtls/go120.go} | 53 +- .../quic-go/quic-go/internal/qtls/go121.go | 5 + .../quic-go/internal/qtls/go_oldversion.go | 5 + .../internal/utils/buffered_write_closer.go | 0 .../quic-go/internal/utils/byteorder.go | 4 + .../internal/utils/byteorder_big_endian.go | 14 + .../quic-go/internal/utils/ip.go | 0 .../internal/utils/linkedlist/README.md | 6 + .../internal/utils/linkedlist/linkedlist.go} | 149 +- .../quic-go/internal/utils/log.go | 2 +- .../quic-go/quic-go/internal/utils/minmax.go | 72 + .../quic-go/internal/utils/rand.go | 0 .../quic-go/internal/utils/rtt_stats.go | 8 +- .../quic-go/internal/utils/timer.go | 4 + .../quic-go/internal/wire/ack_frame.go | 46 +- .../quic-go/internal/wire/ack_frame_pool.go | 24 + .../quic-go/internal/wire/ack_range.go | 2 +- .../internal/wire/connection_close_frame.go | 29 +- .../quic-go/internal/wire/crypto_frame.go | 20 +- .../internal/wire/data_blocked_frame.go | 19 +- .../quic-go/internal/wire/datagram_frame.go | 27 +- .../quic-go/internal/wire/extended_header.go | 210 ++ .../quic-go/internal/wire/frame_parser.go | 185 ++ .../internal/wire/handshake_done_frame.go | 17 + .../quic-go/internal/wire/header.go | 158 +- .../quic-go/internal/wire/interface.go | 8 +- .../quic-go/internal/wire/log.go | 4 +- .../quic-go/internal/wire/max_data_frame.go | 19 +- .../internal/wire/max_stream_data_frame.go | 18 +- .../internal/wire/max_streams_frame.go | 27 +- 
.../internal/wire/new_connection_id_frame.go | 29 +- .../quic-go/internal/wire/new_token_frame.go | 17 +- .../internal/wire/path_challenge_frame.go | 13 +- .../internal/wire/path_response_frame.go | 13 +- .../quic-go/internal/wire/ping_frame.go | 17 + .../quic-go/internal/wire/pool.go | 2 +- .../internal/wire/reset_stream_frame.go | 22 +- .../wire/retire_connection_id_frame.go | 16 +- .../quic-go/internal/wire/short_header.go | 73 + .../internal/wire/stop_sending_frame.go | 20 +- .../wire/stream_data_blocked_frame.go | 18 +- .../quic-go/internal/wire/stream_frame.go | 41 +- .../internal/wire/streams_blocked_frame.go | 27 +- .../internal/wire/transport_parameters.go | 155 +- .../internal/wire/version_negotiation.go | 39 +- .../quic-go/logging/frame.go | 2 +- .../quic-go/logging/interface.go | 39 +- .../quic-go/quic-go/logging/mockgen.go | 4 + .../quic-go/logging/multiplex.go | 41 +- .../quic-go/quic-go/logging/null_tracer.go | 58 + .../quic-go/logging/packet_header.go | 5 +- .../quic-go/logging/types.go | 0 vendor/github.com/quic-go/quic-go/mockgen.go | 74 + .../quic-go/mtu_discoverer.go | 8 +- .../github.com/quic-go/quic-go/multiplexer.go | 75 + .../quic-go/quic-go/packet_handler_map.go | 271 +++ .../quic-go/quic-go/packet_packer.go | 968 +++++++++ .../quic-go/packet_unpacker.go | 120 +- .../quic-go/quicvarint/io.go | 0 .../quic-go/quicvarint/varint.go | 41 +- .../quic-go/receive_stream.go | 52 +- .../quic-go/retransmission_queue.go | 26 +- .../quic-go/send_conn.go | 23 +- .../quic-go/send_queue.go | 7 + .../quic-go/send_stream.go | 91 +- vendor/github.com/quic-go/quic-go/server.go | 820 ++++++++ .../quic-go/stream.go | 19 +- .../quic-go/streams_map.go | 39 +- .../quic-go/streams_map_incoming.go} | 61 +- .../quic-go/streams_map_outgoing.go} | 68 +- .../quic-go/sys_conn.go | 4 +- .../quic-go/sys_conn_df.go | 1 - .../quic-go/sys_conn_df_linux.go | 5 +- .../quic-go/sys_conn_df_windows.go | 4 +- .../quic-go/sys_conn_helper_darwin.go | 1 - .../quic-go/sys_conn_helper_freebsd.go | 1 - .../quic-go/sys_conn_helper_linux.go | 1 - .../quic-go/sys_conn_no_oob.go | 1 - .../quic-go/sys_conn_oob.go | 55 +- .../quic-go/sys_conn_windows.go | 1 - .../quic-go/token_store.go | 18 +- vendor/github.com/quic-go/quic-go/tools.go | 8 + .../github.com/quic-go/quic-go/transport.go | 414 ++++ .../quic-go/window_update_queue.go | 6 +- .../qtls-go1-18 => golang.org/x/exp}/LICENSE | 0 vendor/golang.org/x/exp/PATENTS | 22 + .../x/exp/constraints/constraints.go | 50 + vendor/golang.org/x/mod/modfile/print.go | 174 ++ vendor/golang.org/x/mod/modfile/read.go | 958 +++++++++ vendor/golang.org/x/mod/modfile/rule.go | 1559 ++++++++++++++ vendor/golang.org/x/mod/modfile/work.go | 234 +++ vendor/gopkg.in/tomb.v1/LICENSE | 29 - vendor/gopkg.in/tomb.v1/README.md | 4 - vendor/gopkg.in/tomb.v1/tomb.go | 176 -- vendor/modules.txt | 129 +- 506 files changed, 26520 insertions(+), 41963 deletions(-) delete mode 100644 vendor/github.com/cheekybits/genny/.gitignore delete mode 100644 vendor/github.com/cheekybits/genny/.travis.yml delete mode 100644 vendor/github.com/cheekybits/genny/LICENSE delete mode 100644 vendor/github.com/cheekybits/genny/README.md delete mode 100644 vendor/github.com/cheekybits/genny/doc.go delete mode 100644 vendor/github.com/cheekybits/genny/generic/doc.go delete mode 100644 vendor/github.com/cheekybits/genny/generic/generic.go delete mode 100644 vendor/github.com/cheekybits/genny/main.go delete mode 100644 vendor/github.com/cheekybits/genny/out/lazy_file.go delete mode 100644 
vendor/github.com/cheekybits/genny/parse/builtins.go delete mode 100644 vendor/github.com/cheekybits/genny/parse/doc.go delete mode 100644 vendor/github.com/cheekybits/genny/parse/errors.go delete mode 100644 vendor/github.com/cheekybits/genny/parse/parse.go delete mode 100644 vendor/github.com/cheekybits/genny/parse/typesets.go create mode 100644 vendor/github.com/golang/mock/AUTHORS create mode 100644 vendor/github.com/golang/mock/CONTRIBUTORS create mode 100644 vendor/github.com/golang/mock/LICENSE create mode 100644 vendor/github.com/golang/mock/mockgen/mockgen.go create mode 100644 vendor/github.com/golang/mock/mockgen/model/model.go create mode 100644 vendor/github.com/golang/mock/mockgen/parse.go create mode 100644 vendor/github.com/golang/mock/mockgen/reflect.go create mode 100644 vendor/github.com/golang/mock/mockgen/version.1.11.go create mode 100644 vendor/github.com/golang/mock/mockgen/version.1.12.go create mode 100644 vendor/github.com/google/pprof/AUTHORS create mode 100644 vendor/github.com/google/pprof/CONTRIBUTORS create mode 100644 vendor/github.com/google/pprof/LICENSE create mode 100644 vendor/github.com/google/pprof/profile/encode.go create mode 100644 vendor/github.com/google/pprof/profile/filter.go create mode 100644 vendor/github.com/google/pprof/profile/index.go create mode 100644 vendor/github.com/google/pprof/profile/legacy_java_profile.go create mode 100644 vendor/github.com/google/pprof/profile/legacy_profile.go create mode 100644 vendor/github.com/google/pprof/profile/merge.go create mode 100644 vendor/github.com/google/pprof/profile/profile.go create mode 100644 vendor/github.com/google/pprof/profile/proto.go create mode 100644 vendor/github.com/google/pprof/profile/prune.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/README.md delete mode 100644 vendor/github.com/lucas-clemente/quic-go/client.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/closed_conn.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/datagram_queue.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ackhandler.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/mockgen.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/handshake/retry.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go116.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go118.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go120.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/qtls/go_oldversion.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/atomic_bool.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/new_connection_id.go delete mode 100644 
vendor/github.com/lucas-clemente/quic-go/internal/utils/newconnectionid_linkedlist.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/handshake_done_frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/logging/mockgen.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/mockgen.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh delete mode 100644 vendor/github.com/lucas-clemente/quic-go/multiplexer.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/packet_packer.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/server.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go delete mode 100644 vendor/github.com/lucas-clemente/quic-go/tools.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/README.md delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/auth.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/cipher_suites.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/common.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/common_js.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/common_nojs.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/conn.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_client.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_client_tls13.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_messages.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_server.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/handshake_server_tls13.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/key_agreement.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/key_schedule.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/ticket.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-16/tls.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/README.md delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/auth.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/cipher_suites.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/common.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/conn.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/handshake_client.go delete mode 100644 
vendor/github.com/marten-seemann/qtls-go1-17/handshake_client_tls13.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/handshake_messages.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/handshake_server.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/handshake_server_tls13.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-17/key_agreement.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/README.md delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/alert.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/prf.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-18/unsafe.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/LICENSE delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/README.md delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/alert.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/cpu.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/cpu_other.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/key_schedule.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/prf.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/ticket.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/tls.go delete mode 100644 vendor/github.com/marten-seemann/qtls-go1-19/unsafe.go delete mode 100644 vendor/github.com/nxadm/tail/.gitignore delete mode 100644 vendor/github.com/nxadm/tail/CHANGES.md delete mode 100644 vendor/github.com/nxadm/tail/Dockerfile delete mode 100644 vendor/github.com/nxadm/tail/LICENSE delete mode 100644 vendor/github.com/nxadm/tail/README.md delete mode 100644 vendor/github.com/nxadm/tail/ratelimiter/Licence delete mode 100644 vendor/github.com/nxadm/tail/ratelimiter/leakybucket.go delete mode 100644 vendor/github.com/nxadm/tail/ratelimiter/memory.go delete mode 100644 vendor/github.com/nxadm/tail/ratelimiter/storage.go delete mode 100644 vendor/github.com/nxadm/tail/tail.go delete mode 100644 vendor/github.com/nxadm/tail/tail_posix.go delete mode 100644 vendor/github.com/nxadm/tail/tail_windows.go delete mode 100644 vendor/github.com/nxadm/tail/util/util.go delete mode 100644 vendor/github.com/nxadm/tail/watch/filechanges.go delete mode 100644 vendor/github.com/nxadm/tail/watch/inotify.go delete mode 100644 vendor/github.com/nxadm/tail/watch/inotify_tracker.go delete mode 100644 vendor/github.com/nxadm/tail/watch/polling.go delete mode 100644 vendor/github.com/nxadm/tail/watch/watch.go delete mode 100644 vendor/github.com/nxadm/tail/winfile/winfile.go delete mode 100644 vendor/github.com/onsi/ginkgo/config/config.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/bootstrap_command.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/build_command.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/ginkgo_ast_nodes.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/import.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/package_rewriter.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/test_finder.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/testfile_rewriter.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert/testing_t_rewriter.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/convert_command.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/help_command.go delete mode 100644 
vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/interrupt_handler.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/main.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/nodot/nodot.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/nodot_command.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/notifications.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/outline_command.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/run_command.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/run_watch_and_build_command_flags.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/suite_runner.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args_old.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/log_writer.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/run_result.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testrunner/test_runner.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testsuite/test_suite.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go15.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go16.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/version_command.go delete mode 100644 vendor/github.com/onsi/ginkgo/ginkgo/watch_command.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/codelocation/code_location.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/containernode/container_node.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/failer/failer.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/benchmarker.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/interfaces.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/it_node.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/measure_node.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/runner.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/setup_nodes.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/suite_nodes.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_after_suite_node.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_before_suite_node.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/aggregator.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/forwarding_reporter.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_unix.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_win.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/remote/server.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/spec/spec.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/spec/specs.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/index_computer.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/parallel_spec_iterator.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/serial_spec_iterator.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/sharded_parallel_spec_iterator.go 
delete mode 100644 vendor/github.com/onsi/ginkgo/internal/spec_iterator/spec_iterator.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/writer/fake_writer.go delete mode 100644 vendor/github.com/onsi/ginkgo/internal/writer/writer.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/default_reporter.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/fake_reporter.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/junit_reporter.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/reporter.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/console_logging.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/fake_stenographer.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/stenographer.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/README.md delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_others.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/noncolorable.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/LICENSE delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/README.md delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/doc.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_appengine.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_bsd.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_linux.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_solaris.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_windows.go delete mode 100644 vendor/github.com/onsi/ginkgo/reporters/teamcity_reporter.go delete mode 100644 vendor/github.com/onsi/ginkgo/types/code_location.go delete mode 100644 vendor/github.com/onsi/ginkgo/types/synchronization.go delete mode 100644 vendor/github.com/onsi/ginkgo/types/types.go rename vendor/github.com/onsi/ginkgo/{ => v2}/LICENSE (100%) create mode 100644 vendor/github.com/onsi/ginkgo/v2/config/deprecated.go rename vendor/github.com/onsi/ginkgo/{reporters/stenographer/support/go-colorable/LICENSE => v2/formatter/colorable_others.go} (76%) rename vendor/github.com/onsi/ginkgo/{reporters/stenographer/support/go-colorable => v2/formatter}/colorable_windows.go (90%) rename vendor/github.com/onsi/ginkgo/{ => v2}/formatter/formatter.go (91%) create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/command/abort.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/command/command.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/command/program.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/boostrap_templates.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/bootstrap_command.go rename vendor/github.com/onsi/ginkgo/{ginkgo => v2/ginkgo/generators}/generate_command.go (50%) create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_templates.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generators_common.go create mode 100644 
vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/profiles_and_reports.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/run.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/test_suite.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/utils.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/verify_version.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/labels/labels_command.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/main.go rename vendor/github.com/onsi/ginkgo/{ => v2}/ginkgo/outline/ginkgo.go (82%) rename vendor/github.com/onsi/ginkgo/{ => v2}/ginkgo/outline/import.go (97%) rename vendor/github.com/onsi/ginkgo/{ => v2}/ginkgo/outline/outline.go (85%) create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline_command.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go rename vendor/github.com/onsi/ginkgo/{ginkgo => v2/ginkgo/unfocus}/unfocus_command.go (67%) rename vendor/github.com/onsi/ginkgo/{ => v2}/ginkgo/watch/delta.go (100%) rename vendor/github.com/onsi/ginkgo/{ => v2}/ginkgo/watch/delta_tracker.go (85%) rename vendor/github.com/onsi/ginkgo/{ => v2}/ginkgo/watch/dependencies.go (100%) rename vendor/github.com/onsi/ginkgo/{ => v2}/ginkgo/watch/package_hash.go (92%) rename vendor/github.com/onsi/ginkgo/{ => v2}/ginkgo/watch/package_hashes.go (100%) rename vendor/github.com/onsi/ginkgo/{ => v2}/ginkgo/watch/suite.go (90%) create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go rename vendor/github.com/onsi/ginkgo/{ginkgo/interrupthandler => v2/internal/interrupt_handler}/sigquit_swallower_unix.go (64%) rename vendor/github.com/onsi/ginkgo/{ginkgo/interrupthandler => v2/internal/interrupt_handler}/sigquit_swallower_windows.go (54%) create mode 100644 vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/client_server.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_client.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_server.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_client.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_server.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/server_handler.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/reporters/default_reporter.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/reporters/deprecated_reporter.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/reporters/json_report.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/reporters/junit_report.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/reporters/reporter.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/reporters/teamcity_report.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/code_location.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/config.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/deprecated_types.go rename vendor/github.com/onsi/ginkgo/{ => v2}/types/deprecation_support.go (66%) create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/enum_support.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/errors.go create mode 100644 
vendor/github.com/onsi/ginkgo/v2/types/file_filter.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/flags.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/label_filter.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/report_entry.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/types.go create mode 100644 vendor/github.com/onsi/ginkgo/v2/types/version.go rename vendor/github.com/{marten-seemann/qtls-go1-16 => quic-go/qtls-go1-19}/LICENSE (100%) create mode 100644 vendor/github.com/quic-go/qtls-go1-19/README.md rename vendor/github.com/{marten-seemann/qtls-go1-16 => quic-go/qtls-go1-19}/alert.go (100%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/auth.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-19}/cfkem.go (100%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/cipher_suites.go (100%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/common.go (99%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/conn.go (97%) rename vendor/github.com/{marten-seemann/qtls-go1-17 => quic-go/qtls-go1-19}/cpu.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-17 => quic-go/qtls-go1-19}/cpu_other.go (100%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/handshake_client.go (92%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/handshake_client_tls13.go (92%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/handshake_messages.go (76%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/handshake_server.go (92%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/handshake_server_tls13.go (92%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-19}/key_agreement.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-17 => quic-go/qtls-go1-19}/key_schedule.go (86%) rename vendor/github.com/{marten-seemann => quic-go}/qtls-go1-19/notboring.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-16 => quic-go/qtls-go1-19}/prf.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-19}/ticket.go (96%) rename vendor/github.com/{marten-seemann/qtls-go1-17 => quic-go/qtls-go1-19}/tls.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-16 => quic-go/qtls-go1-19}/unsafe.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-17 => quic-go/qtls-go1-20}/LICENSE (100%) create mode 100644 vendor/github.com/quic-go/qtls-go1-20/README.md rename vendor/github.com/{marten-seemann/qtls-go1-17 => quic-go/qtls-go1-20}/alert.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/auth.go (98%) create mode 100644 vendor/github.com/quic-go/qtls-go1-20/cache.go rename vendor/github.com/{marten-seemann/qtls-go1-19 => quic-go/qtls-go1-20}/cfkem.go (94%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/cipher_suites.go (91%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/common.go (97%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/conn.go (95%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/cpu.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/cpu_other.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/handshake_client.go (90%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => 
quic-go/qtls-go1-20}/handshake_client_tls13.go (89%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/handshake_messages.go (75%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/handshake_server.go (92%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/handshake_server_tls13.go (90%) rename vendor/github.com/{marten-seemann/qtls-go1-19 => quic-go/qtls-go1-20}/key_agreement.go (94%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/key_schedule.go (63%) create mode 100644 vendor/github.com/quic-go/qtls-go1-20/notboring.go rename vendor/github.com/{marten-seemann/qtls-go1-17 => quic-go/qtls-go1-20}/prf.go (99%) rename vendor/github.com/{marten-seemann/qtls-go1-17 => quic-go/qtls-go1-20}/ticket.go (95%) rename vendor/github.com/{marten-seemann/qtls-go1-18 => quic-go/qtls-go1-20}/tls.go (100%) rename vendor/github.com/{marten-seemann/qtls-go1-17 => quic-go/qtls-go1-20}/unsafe.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/.gitignore (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/.golangci.yml (73%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/Changelog.md (97%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/LICENSE (100%) create mode 100644 vendor/github.com/quic-go/quic-go/README.md create mode 100644 vendor/github.com/quic-go/quic-go/SECURITY.md rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/buffer_pool.go (96%) create mode 100644 vendor/github.com/quic-go/quic-go/client.go create mode 100644 vendor/github.com/quic-go/quic-go/closed_conn.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/codecov.yml (95%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/config.go (73%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/conn_id_generator.go (77%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/conn_id_manager.go (90%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/connection.go (75%) create mode 100644 vendor/github.com/quic-go/quic-go/connection_timer.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/crypto_stream.go (88%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/crypto_stream_manager.go (93%) create mode 100644 vendor/github.com/quic-go/quic-go/datagram_queue.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/errors.go (89%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/frame_sorter.go (86%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/framer.go (77%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/interface.go (78%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/ackhandler/ack_eliciting.go (80%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go create mode 100644 vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/ackhandler/interfaces.go (71%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go create mode 100644 vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/ackhandler/packet_number_generator.go (92%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/ackhandler/received_packet_handler.go (84%) rename vendor/github.com/{lucas-clemente => 
quic-go}/quic-go/internal/ackhandler/received_packet_history.go (74%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/ackhandler/received_packet_tracker.go (86%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/ackhandler/send_mode.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/ackhandler/sent_packet_handler.go (91%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/ackhandler/sent_packet_history.go (60%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/congestion/bandwidth.go (91%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/congestion/clock.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/congestion/cubic.go (97%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/congestion/cubic_sender.go (95%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/congestion/hybrid_slow_start.go (91%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/congestion/interface.go (94%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/congestion/pacer.go (71%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/flowcontrol/base_flow_controller.go (93%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/flowcontrol/connection_flow_controller.go (94%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/flowcontrol/interface.go (95%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/flowcontrol/stream_flow_controller.go (94%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/aead.go (95%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/crypto_setup.go (88%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/header_protector.go (97%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/hkdf.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/initial_aead.go (88%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/interface.go (89%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/handshake/mockgen.go create mode 100644 vendor/github.com/quic-go/quic-go/internal/handshake/retry.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/session_ticket.go (77%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/tls_extension_handler.go (92%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/token_generator.go (76%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/token_protector.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/handshake/updatable_aead.go (92%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/logutils/frame.go (54%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/protocol/encryption_level.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/protocol/key_phase.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/protocol/packet_number.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/protocol/params.go (98%) rename 
vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/protocol/perspective.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/protocol/protocol.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/protocol/stream.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/protocol/version.go (81%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/qerr/error_codes.go (95%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/qerr/errors.go (84%) rename vendor/github.com/{lucas-clemente/quic-go/internal/qtls/go117.go => quic-go/quic-go/internal/qtls/go119.go} (58%) rename vendor/github.com/{lucas-clemente/quic-go/internal/qtls/go119.go => quic-go/quic-go/internal/qtls/go120.go} (59%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/qtls/go121.go create mode 100644 vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/utils/buffered_write_closer.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/utils/byteorder.go (85%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/utils/byteorder_big_endian.go (85%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/utils/ip.go (100%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/README.md rename vendor/github.com/{lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go => quic-go/quic-go/internal/utils/linkedlist/linkedlist.go} (58%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/utils/log.go (98%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/utils/minmax.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/utils/rand.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/utils/rtt_stats.go (92%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/utils/timer.go (94%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/ack_frame.go (85%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/wire/ack_frame_pool.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/ack_range.go (82%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/connection_close_frame.go (70%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/crypto_frame.go (86%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/data_blocked_frame.go (53%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/datagram_frame.go (73%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/wire/extended_header.go create mode 100644 vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go create mode 100644 vendor/github.com/quic-go/quic-go/internal/wire/handshake_done_frame.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/header.go (60%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/interface.go (53%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/log.go (96%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/max_data_frame.go (55%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/max_stream_data_frame.go (67%) rename vendor/github.com/{lucas-clemente => 
quic-go}/quic-go/internal/wire/max_streams_frame.go (61%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/new_connection_id_frame.go (73%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/new_token_frame.go (69%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/path_challenge_frame.go (69%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/path_response_frame.go (68%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/wire/ping_frame.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/pool.go (90%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/reset_stream_frame.go (68%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/retire_connection_id_frame.go (63%) create mode 100644 vendor/github.com/quic-go/quic-go/internal/wire/short_header.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/stop_sending_frame.go (66%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/stream_data_blocked_frame.go (68%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/stream_frame.go (83%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/streams_blocked_frame.go (60%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/transport_parameters.go (74%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/internal/wire/version_negotiation.go (50%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/logging/frame.go (97%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/logging/interface.go (77%) create mode 100644 vendor/github.com/quic-go/quic-go/logging/mockgen.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/logging/multiplex.go (80%) create mode 100644 vendor/github.com/quic-go/quic-go/logging/null_tracer.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/logging/packet_header.go (83%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/logging/types.go (100%) create mode 100644 vendor/github.com/quic-go/quic-go/mockgen.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/mtu_discoverer.go (89%) create mode 100644 vendor/github.com/quic-go/quic-go/multiplexer.go create mode 100644 vendor/github.com/quic-go/quic-go/packet_handler_map.go create mode 100644 vendor/github.com/quic-go/quic-go/packet_packer.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/packet_unpacker.go (55%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/quicvarint/io.go (100%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/quicvarint/varint.go (73%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/receive_stream.go (86%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/retransmission_queue.go (83%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/send_conn.go (66%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/send_queue.go (93%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/send_stream.go (81%) create mode 100644 vendor/github.com/quic-go/quic-go/server.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/stream.go (89%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/streams_map.go (90%) rename vendor/github.com/{lucas-clemente/quic-go/streams_map_incoming_uni.go => quic-go/quic-go/streams_map_incoming.go} (76%) rename 
vendor/github.com/{lucas-clemente/quic-go/streams_map_outgoing_uni.go => quic-go/quic-go/streams_map_outgoing.go} (73%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn.go (95%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn_df.go (90%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn_df_linux.go (88%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn_df_windows.go (95%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn_helper_darwin.go (95%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn_helper_freebsd.go (93%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn_helper_linux.go (96%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn_no_oob.go (87%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn_oob.go (87%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/sys_conn_windows.go (97%) rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/token_store.go (84%) create mode 100644 vendor/github.com/quic-go/quic-go/tools.go create mode 100644 vendor/github.com/quic-go/quic-go/transport.go rename vendor/github.com/{lucas-clemente => quic-go}/quic-go/window_update_queue.go (91%) rename vendor/{github.com/marten-seemann/qtls-go1-18 => golang.org/x/exp}/LICENSE (100%) create mode 100644 vendor/golang.org/x/exp/PATENTS create mode 100644 vendor/golang.org/x/exp/constraints/constraints.go create mode 100644 vendor/golang.org/x/mod/modfile/print.go create mode 100644 vendor/golang.org/x/mod/modfile/read.go create mode 100644 vendor/golang.org/x/mod/modfile/rule.go create mode 100644 vendor/golang.org/x/mod/modfile/work.go delete mode 100644 vendor/gopkg.in/tomb.v1/LICENSE delete mode 100644 vendor/gopkg.in/tomb.v1/README.md delete mode 100644 vendor/gopkg.in/tomb.v1/tomb.go diff --git a/connection/quic.go b/connection/quic.go index e5e47d2ff07..e211b87b1f4 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -16,8 +16,8 @@ import ( "time" "github.com/google/uuid" - "github.com/lucas-clemente/quic-go" "github.com/pkg/errors" + "github.com/quic-go/quic-go" "github.com/rs/zerolog" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -67,6 +67,7 @@ type QUICConnection struct { // NewQUICConnection returns a new instance of QUICConnection. 
func NewQUICConnection( + ctx context.Context, quicConfig *quic.Config, edgeAddr net.Addr, localAddr net.IP, @@ -83,7 +84,7 @@ func NewQUICConnection( return nil, err } - session, err := quic.Dial(udpConn, edgeAddr, edgeAddr.String(), tlsConfig, quicConfig) + session, err := quic.Dial(ctx, udpConn, edgeAddr, tlsConfig, quicConfig) if err != nil { // close the udp server socket in case of error connecting to the edge udpConn.Close() diff --git a/connection/quic_test.go b/connection/quic_test.go index d1dcaa6839a..ce8bd3711dc 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -16,8 +16,8 @@ import ( "github.com/gobwas/ws/wsutil" "github.com/google/uuid" - "github.com/lucas-clemente/quic-go" "github.com/pkg/errors" + "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,9 +32,8 @@ import ( var ( testTLSServerConfig = quicpogs.GenerateTLSConfig() testQUICConfig = &quic.Config{ - ConnectionIDLength: 16, - KeepAlivePeriod: 5 * time.Second, - EnableDatagrams: true, + KeepAlivePeriod: 5 * time.Second, + EnableDatagrams: true, } ) @@ -43,13 +42,6 @@ var _ ReadWriteAcker = (*streamReadWriteAcker)(nil) // TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol. // It also serves as a demonstration for communication with the QUIC connection started by a cloudflared. func TestQUICServer(t *testing.T) { - // Start a UDP Listener for QUIC. - udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") - require.NoError(t, err) - udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr) - require.NoError(t, err) - defer udpListener.Close() - // This is simply a sample websocket frame message. wsBuf := &bytes.Buffer{} wsutil.WriteClientBinary(wsBuf, []byte("Hello")) @@ -145,8 +137,14 @@ func TestQUICServer(t *testing.T) { test := test // capture range variable t.Run(test.desc, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - - quicListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig) + // Start a UDP Listener for QUIC. 
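Editor's note on the dial-side change shown in the connection/quic.go hunk above: in the quic-go fork this patch vendors, quic.Dial takes a context.Context as its first argument and no longer takes a separate host string (the server name comes from the tls.Config), which is why NewQUICConnection now accepts a ctx and threads it into the dial. A minimal sketch of the new call shape, assuming only the signatures visible in this patch; dialEdge and its parameters are illustrative and not part of cloudflared:

    // Illustrative only (not part of this patch): the dial-side shape after the
    // migration, assuming the quic.Dial signature shown in connection/quic.go.
    package main

    import (
        "context"
        "crypto/tls"
        "net"

        "github.com/quic-go/quic-go"
    )

    // dialEdge is a hypothetical helper mirroring the NewQUICConnection change:
    // the caller's ctx now bounds the QUIC handshake, and the old host-string
    // argument is gone (the SNI is taken from tlsConfig instead).
    func dialEdge(ctx context.Context, edgeAddr net.Addr, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.Connection, error) {
        udpConn, err := net.ListenUDP("udp", nil)
        if err != nil {
            return nil, err
        }
        // Old API: quic.Dial(udpConn, edgeAddr, edgeAddr.String(), tlsConfig, quicConfig)
        conn, err := quic.Dial(ctx, udpConn, edgeAddr, tlsConfig, quicConfig)
        if err != nil {
            // Close the UDP socket if the handshake fails, as the patch does.
            udpConn.Close()
            return nil, err
        }
        return conn, nil
    }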
+ udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr) + require.NoError(t, err) + defer udpListener.Close() + quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16} + quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig) require.NoError(t, err) serverDone := make(chan struct{}) @@ -187,7 +185,7 @@ func (fakeControlStream) IsStopped() bool { func quicServer( ctx context.Context, t *testing.T, - listener quic.Listener, + listener *quic.Listener, dest string, connectionType quicpogs.ConnectionType, metadata []quicpogs.Metadata, @@ -713,7 +711,10 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU } // Start a mock httpProxy log := zerolog.New(os.Stdout) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() qc, err := NewQUICConnection( + ctx, testQUICConfig, udpListenerAddr, nil, diff --git a/datagramsession/session.go b/datagramsession/session.go index 7e0131daeb2..396ee140f29 100644 --- a/datagramsession/session.go +++ b/datagramsession/session.go @@ -51,7 +51,7 @@ type Session struct { func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (closedByRemote bool, err error) { go func() { - // QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 + // QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975 // This makes it safe to share readBuffer between iterations const maxPacketSize = 1500 readBuffer := make([]byte, maxPacketSize) diff --git a/go.mod b/go.mod index 7b1549c546f..75ebb1d6f73 100644 --- a/go.mod +++ b/go.mod @@ -20,13 +20,13 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.4.2 github.com/json-iterator/go v1.1.12 - github.com/lucas-clemente/quic-go v0.28.1 github.com/mattn/go-colorable v0.1.13 github.com/miekg/dns v1.1.50 github.com/mitchellh/go-homedir v1.1.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_model v0.2.0 + github.com/quic-go/quic-go v0.0.0-00010101000000-000000000000 github.com/rs/zerolog v1.20.0 github.com/stretchr/testify v1.8.1 github.com/urfave/cli/v2 v2.3.0 @@ -57,7 +57,6 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect - github.com/cheekybits/genny v1.0.0 // indirect github.com/cloudflare/circl v1.2.1-0.20220809205628-0a9554f37a47 // indirect github.com/coredns/caddy v1.1.1 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect @@ -72,29 +71,29 @@ require ( github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 // indirect github.com/gobwas/pool v0.2.1 // indirect + github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.2 // indirect + github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect github.com/klauspost/compress v1.15.11 // indirect github.com/kr/text v0.2.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect - github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect - 
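Editor's note on the listener-side change in the connection/quic_test.go hunk above: ConnectionIDLength is no longer a quic.Config field; it moves onto the new quic.Transport, which wraps the UDP socket, and Transport.Listen returns a *quic.Listener (hence the quicServer signature change). A minimal server-side sketch under the same assumptions; the function and variable names are illustrative, not from the test:

    // Illustrative only: listening via the quic.Transport API used in the
    // updated test, assuming the field and method names shown in this patch.
    package main

    import (
        "context"
        "crypto/tls"
        "net"

        "github.com/quic-go/quic-go"
    )

    func listenAndAccept(ctx context.Context, tlsConf *tls.Config) (quic.Connection, error) {
        udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
        if err != nil {
            return nil, err
        }
        udpConn, err := net.ListenUDP("udp", udpAddr)
        if err != nil {
            return nil, err
        }
        // ConnectionIDLength now lives on the Transport rather than on quic.Config.
        tr := &quic.Transport{Conn: udpConn, ConnectionIDLength: 16}
        // Listen returns a *quic.Listener, matching the new quicServer parameter type.
        ln, err := tr.Listen(tlsConf, &quic.Config{EnableDatagrams: true})
        if err != nil {
            udpConn.Close()
            return nil, err
        }
        return ln.Accept(ctx)
    }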
github.com/marten-seemann/qtls-go1-17 v0.1.2 // indirect - github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect - github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 // indirect github.com/mattn/go-isatty v0.0.16 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/nxadm/tail v1.4.8 // indirect - github.com/onsi/ginkgo v1.16.5 // indirect + github.com/onsi/ginkgo/v2 v2.4.0 // indirect github.com/onsi/gomega v1.23.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect + github.com/quic-go/qtls-go1-19 v0.3.2 // indirect + github.com/quic-go/qtls-go1-20 v0.2.2 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.8.0 // indirect golang.org/x/oauth2 v0.4.0 // indirect golang.org/x/text v0.9.0 // indirect @@ -103,26 +102,21 @@ require ( google.golang.org/genproto v0.0.0-20221202195650-67e5cbc046fd // indirect google.golang.org/grpc v1.51.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect - gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) replace github.com/urfave/cli/v2 => github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d -replace github.com/lucas-clemente/quic-go => github.com/chungthuang/quic-go v0.27.1-0.20220809135021-ca330f1dec9f - // Avoid 'CVE-2022-21698' replace github.com/prometheus/golang_client => github.com/prometheus/golang_client v1.12.1 replace gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1 +replace github.com/quic-go/quic-go => github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7 + // Post-quantum tunnel RTG-1339 replace ( - // Branches go1.18 go1.19 go1.20 on github.com/cloudflare/qtls-pq - github.com/marten-seemann/qtls-go1-18 => github.com/cloudflare/qtls-pq v0.0.0-20230103171413-e7a2fb559a0e - github.com/marten-seemann/qtls-go1-19 => github.com/cloudflare/qtls-pq v0.0.0-20230103171656-05e84f90909e - github.com/marten-seemann/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20230215110727-8b4e1699c2a8 - github.com/quic-go/qtls-go1-18 => github.com/cloudflare/qtls-pq v0.0.0-20230103171413-e7a2fb559a0e - github.com/quic-go/qtls-go1-19 => github.com/cloudflare/qtls-pq v0.0.0-20230103171656-05e84f90909e - github.com/quic-go/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20230215110727-8b4e1699c2a8 + // Branches go1.19 go1.20 on github.com/cloudflare/qtls-pq + github.com/quic-go/qtls-go1-19 => github.com/cloudflare/qtls-pq v0.0.0-20230320123031-3faac1a945b2 + github.com/quic-go/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633 ) diff --git a/go.sum b/go.sum index 4b0ec975eb6..8ad33209705 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= cloud.google.com/go v0.38.0/go.mod 
h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= @@ -55,12 +53,7 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.22.1/go.mod h1:S8N1cAStu7BOeFfE8KAQzmyyLkK8p/vmRq6kuBTW58Y= -dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= -dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= -dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= -git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.2.0 h1:Rt8g24XnyGTyglgET/PRUNlrUeu9F5L+7FilkXfZgs0= github.com/BurntSushi/toml v1.2.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= @@ -71,7 +64,6 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= -github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apparentlymart/go-cidr v1.1.0 h1:2mAhrMoF+nhXqxTzSZMUzDHkLjmIHC+Zzn4tdgBZjnU= github.com/apparentlymart/go-cidr v1.1.0/go.mod h1:EBcsNrHc3zQeuaeCeCtQruQm+n9/YjEn/vI25Lg7Gwc= @@ -79,8 +71,6 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= -github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/bwesterb/go-ristretto v1.2.2/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= github.com/cenkalti/backoff/v4 v4.1.2/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -90,10 +80,6 @@ github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghf github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2 
h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= -github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= -github.com/chungthuang/quic-go v0.27.1-0.20220809135021-ca330f1dec9f h1:UWC3XjwZzocdNCzzXxq9j/1SdHMZXhcTOsh/+gNRBUQ= -github.com/chungthuang/quic-go v0.27.1-0.20220809135021-ca330f1dec9f/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -104,10 +90,10 @@ github.com/cloudflare/circl v1.2.1-0.20220809205628-0a9554f37a47 h1:YzpECHxZ9TzO github.com/cloudflare/circl v1.2.1-0.20220809205628-0a9554f37a47/go.mod h1:qhx8gBILsYlbam7h09SvHDSkjpe3TfLA7b/z4rxJvkE= github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc h1:Dvk3ySBsOm5EviLx6VCyILnafPcQinXGP5jbTdHUJgE= github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc/go.mod h1:HlgKKR8V5a1wroIDDIz3/A+T+9Janfq+7n1P5sEFdi0= -github.com/cloudflare/qtls-pq v0.0.0-20230103171413-e7a2fb559a0e h1:frfo+L0qloEb6Vj+qjS4pbAYSJQZAlUnKZu0uJoErac= -github.com/cloudflare/qtls-pq v0.0.0-20230103171413-e7a2fb559a0e/go.mod h1:mW0BgKFFDAiSmOdUwoORtjo0V2vqw5QzVYRtKQqw/Jg= -github.com/cloudflare/qtls-pq v0.0.0-20230103171656-05e84f90909e h1:RtQDXvDi0PK3EonP0v7zkE5/rApK4MsgRATCdD+ughg= -github.com/cloudflare/qtls-pq v0.0.0-20230103171656-05e84f90909e/go.mod h1:aIsWqC0WXyUiUxBl/RfxAjDyWE9CCLqvSMnCMTd/+bc= +github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633 h1:ZTub2XMOBpxyBiJf6Q+UKqAi07yt1rZmFitriHvFd8M= +github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633/go.mod h1:j/igSUc4PgBMayIsBGjAFu2i7g663rm6kZrKy4htb7E= +github.com/cloudflare/qtls-pq v0.0.0-20230320123031-3faac1a945b2 h1:0/KuLjh9lBMiXlooAdwoo+FbLVD5DABtquB0ImEFOK0= +github.com/cloudflare/qtls-pq v0.0.0-20230320123031-3faac1a945b2/go.mod h1:XzuZIjv4mF5cM205RHHW1d60PQtWGwMR6jx38YKuYHs= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= @@ -123,7 +109,6 @@ github.com/coredns/coredns v1.10.0 h1:jCfuWsBjTs0dapkkhISfPCzn5LqvSRtrFtaf/Tjj4D github.com/coredns/coredns v1.10.0/go.mod h1:CIfRU5TgpuoIiJBJ4XrofQzfFQpPFh32ERpUevrSlaw= github.com/coreos/go-oidc/v3 v3.4.0 h1:xz7elHb/LDwm/ERpwHd+5nb7wFHL32rsr6bBOgaeu6g= github.com/coreos/go-oidc/v3 v3.4.0/go.mod h1:eHUXhZtXPQLgEaDrOVTgwbgmz1xGOkJNye6h3zkD2Pw= -github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -134,7 +119,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7 h1:qxyoKKPXmPsbvT7SZTcvhEgUaZhEttk4f6u8rIawKj0= +github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -157,8 +143,6 @@ github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 h1:E2s37DuLxFhQD github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870/go.mod h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= -github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/getsentry/raven-go v0.2.0 h1:no+xWJRb5ZI7eE8TWgIq1jLulQiIoLG0IfYxv5JYMGs= @@ -170,10 +154,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= -github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0= github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= -github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -222,7 +204,6 @@ github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0L github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod 
h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= @@ -269,8 +250,6 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= -github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= @@ -291,6 +270,7 @@ github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -298,8 +278,6 @@ github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99a87aa/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8= -github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= -github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0= @@ -308,12 +286,9 @@ github.com/googleapis/gax-go/v2 v2.2.0/go.mod h1:as02EH8zWkzwUoLbBaFeQ+arQaj/Oth github.com/googleapis/gax-go/v2 v2.3.0/go.mod h1:b8LNqSzNabLiUpXKkY7HAR5jr6bIT99EXz9pXxye9YM= github.com/googleapis/gax-go/v2 v2.4.0/go.mod h1:XOTVJ59hdnfJLIP/dh8n5CGryZR2LxK9wbMD5+iXC6c= github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod 
h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= -github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 h1:BZHcxBETFHIdVyhyEfOvn/RdU/QGdLI4y34qQGjGWO0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= @@ -321,12 +296,10 @@ github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1 github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d h1:PRDnysJ9dF1vUMmEzBu6aHQeUluSQy4eWH3RsSSy/vI= github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= -github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -349,7 +322,6 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= @@ -357,13 +329,6 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= -github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= -github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= -github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= -github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= -github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ= -github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod 
h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= @@ -371,7 +336,6 @@ github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peK github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= @@ -385,26 +349,12 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= -github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= -github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= -github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= -github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= -github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/onsi/ginkgo/v2 v2.4.0 h1:+Ig9nvqgS5OBSACXNk15PLdp0U9XPYROt9CFzVdFGIs= +github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo= github.com/onsi/gomega v1.23.0 h1:/oxKu9c2HVap+F3PfKort2Hw5DEU+HGlW8n+tguWsys= github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= -github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pelletier/go-toml/v2 v2.0.5 
h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg= github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= @@ -416,7 +366,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= -github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= @@ -429,14 +378,12 @@ github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1: github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= github.com/prometheus/common v0.37.0 h1:ccBbHCgIiT9uSoFY0vX8H3zsNR5eLt17/RQLUvn8pXE= github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA= -github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= @@ -449,39 +396,13 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= -github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod 
h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= -github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= -github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= -github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= -github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= -github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= -github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= -github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= -github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= -github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= -github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= -github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= -github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= -github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= -github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= -github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= -github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= -github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= -github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= -github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= -github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= -github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -497,20 +418,16 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod 
h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tinylib/msgp v1.1.2 h1:gWmO7n0Ys2RBEb7GPYB9Ujq8Mk5p2U08lRnmMcGy6BQ= github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= -github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= -github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -536,17 +453,12 @@ go.opentelemetry.io/proto/otlp v0.15.0 h1:h0bKrvdrT/9sBwEJ6iWUqT/N/xPcS66bL4u3is go.opentelemetry.io/proto/otlp v0.15.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U= go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0= go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q= -go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= -golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= @@ -561,9 
+473,10 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= +golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -591,14 +504,10 @@ golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -617,7 +526,6 @@ golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod 
h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= @@ -630,7 +538,6 @@ golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -646,8 +553,6 @@ golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfS golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -670,7 +575,6 @@ golang.org/x/oauth2 v0.0.0-20220608161450-d0670ef3b1eb/go.mod h1:jaDAt6Dkxork7Lm golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094/go.mod h1:h4gKUeWbJ4rQPri7E0u6Gs4e9Ri2zaLxzw5DI5XGrYg= golang.org/x/oauth2 v0.4.0 h1:NF0gk8LVPg1Ml7SSbGyySuoxdsXitj7TvgvuRxIMc/M= golang.org/x/oauth2 v0.4.0/go.mod h1:RznEsdpjGAINPTOF0UH/t+xJ75L18YO3Ho6Pyn+uRec= -golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -687,12 +591,9 @@ golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -700,10 +601,8 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -719,7 +618,6 @@ golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -729,7 +627,6 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -782,13 +679,10 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod 
h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= @@ -833,7 +727,6 @@ golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82u golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -851,9 +744,6 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= -google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= @@ -894,8 +784,6 @@ google.golang.org/api v0.78.0/go.mod h1:1Sg78yoMLOhlQTeF+ARBoytAcH1NNyyl390YMy6r google.golang.org/api v0.80.0/go.mod h1:xY3nI94gbvBrE0J6NHXhxOmW97HG7Khjkku6AFB3Hyg= google.golang.org/api v0.84.0/go.mod h1:NTsGnUFJMYROtiquksZHBWtHfeMC7iYthki7Eq3pa8o= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= 
-google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= @@ -904,10 +792,6 @@ google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCID google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= -google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -987,9 +871,6 @@ google.golang.org/genproto v0.0.0-20220608133413-ed9918b62aac/go.mod h1:KEWEmljW google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA= google.golang.org/genproto v0.0.0-20221202195650-67e5cbc046fd h1:OjndDrsik+Gt+e6fs45z9AxiewiKyLKYpA45W5Kpkks= google.golang.org/genproto v0.0.0-20221202195650-67e5cbc046fd/go.mod h1:cTsE614GARnxrLsqKREzmNYJACSWWpAWdNMwnD7c2BE= -google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= -google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= -google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1050,14 +931,10 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/coreos/go-oidc.v2 v2.2.1 h1:MY5SZClJ7vhjKfr64a4nHAOV/c3WH2gB9BMrR64J1Mc= gopkg.in/coreos/go-oidc.v2 v2.2.1/go.mod h1:fYaTe2FS96wZZwR17YTDHwG+Mw6fmyqJNxN2eNCGPCI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 
v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -1069,8 +946,6 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= @@ -1083,7 +958,5 @@ nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0 rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= -sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= -sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= zombiezen.com/go/capnproto2 v2.18.0+incompatible h1:mwfXZniffG5mXokQGHUJWGnqIBggoPfT/CEwon9Yess= zombiezen.com/go/capnproto2 v2.18.0+incompatible/go.mod h1:XO5Pr2SbXgqZwn0m0Ru54QBqpOf4K5AYBO+8LAOBQEQ= diff --git a/quic/conversion.go b/quic/conversion.go index 302f97481eb..3a7e3279990 100644 --- a/quic/conversion.go +++ b/quic/conversion.go @@ -4,7 +4,7 @@ import ( "strconv" "time" - "github.com/lucas-clemente/quic-go/logging" + "github.com/quic-go/quic-go/logging" ) func perspectiveString(p logging.Perspective) string { @@ -28,7 +28,7 @@ func durationToPromGauge(duration time.Duration) float64 { return float64(duration.Milliseconds()) } -// Helper to convert https://pkg.go.dev/github.com/lucas-clemente/quic-go@v0.23.0/logging#PacketType into string +// Helper to convert https://pkg.go.dev/github.com/quic-go/quic-go@v0.23.0/logging#PacketType into string func packetTypeString(pt logging.PacketType) string { switch pt { case logging.PacketTypeInitial: @@ -52,7 +52,7 @@ func packetTypeString(pt logging.PacketType) string { } } -// Helper to convert https://pkg.go.dev/github.com/lucas-clemente/quic-go@v0.23.0/logging#PacketDropReason into string +// Helper to convert https://pkg.go.dev/github.com/quic-go/quic-go@v0.23.0/logging#PacketDropReason into string func packetDropReasonString(reason logging.PacketDropReason) string { switch reason { case logging.PacketDropKeyUnavailable: @@ -82,7 +82,7 @@ func packetDropReasonString(reason logging.PacketDropReason) string { } } -// Helper to convert https://pkg.go.dev/github.com/lucas-clemente/quic-go@v0.23.0/logging#PacketLossReason into string +// Helper to convert https://pkg.go.dev/github.com/quic-go/quic-go@v0.23.0/logging#PacketLossReason into string func packetLossReasonString(reason logging.PacketLossReason) string { switch reason 
{ case logging.PacketLossReorderingThreshold: diff --git a/quic/datagram.go b/quic/datagram.go index c4637dc0ec2..8b1f8a92c0b 100644 --- a/quic/datagram.go +++ b/quic/datagram.go @@ -5,8 +5,8 @@ import ( "fmt" "github.com/google/uuid" - "github.com/lucas-clemente/quic-go" "github.com/pkg/errors" + "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/packet" diff --git a/quic/datagram_test.go b/quic/datagram_test.go index 54307c79a75..5576fb31dd4 100644 --- a/quic/datagram_test.go +++ b/quic/datagram_test.go @@ -16,7 +16,7 @@ import ( "github.com/google/gopacket/layers" "github.com/google/uuid" - "github.com/lucas-clemente/quic-go" + "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/stretchr/testify/require" "golang.org/x/net/icmp" @@ -180,8 +180,10 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi InsecureSkipVerify: true, NextProtos: []string{"argotunnel"}, } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // Establish quic connection - quicSession, err := quic.DialAddrEarly(quicListener.Addr().String(), tlsClientConfig, quicConfig) + quicSession, err := quic.DialAddrEarly(ctx, quicListener.Addr().String(), tlsClientConfig, quicConfig) require.NoError(t, err) defer quicSession.CloseWithError(0, "") @@ -264,7 +266,7 @@ func validateTracingSpans(t *testing.T, receivedPacket Packet, expectedSpan *Tra require.Equal(t, tracingSpans, expectedSpan) } -func newQUICListener(t *testing.T, config *quic.Config) quic.Listener { +func newQUICListener(t *testing.T, config *quic.Config) *quic.Listener { // Create a simple tls config. tlsConfig := generateTLSConfig() diff --git a/quic/datagramv2.go b/quic/datagramv2.go index 010126189eb..e4320bcad9d 100644 --- a/quic/datagramv2.go +++ b/quic/datagramv2.go @@ -4,8 +4,8 @@ import ( "context" "fmt" - "github.com/lucas-clemente/quic-go" "github.com/pkg/errors" + "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/packet" diff --git a/quic/metrics.go b/quic/metrics.go index 900ac172b3e..90c131c0fe7 100644 --- a/quic/metrics.go +++ b/quic/metrics.go @@ -3,8 +3,8 @@ package quic import ( "sync" - "github.com/lucas-clemente/quic-go/logging" "github.com/prometheus/client_golang/prometheus" + "github.com/quic-go/quic-go/logging" ) const ( diff --git a/quic/param_windows.go b/quic/param_windows.go index 13b29ed013d..fc12bc566fc 100644 --- a/quic/param_windows.go +++ b/quic/param_windows.go @@ -3,9 +3,9 @@ package quic const ( - // Due to https://github.com/lucas-clemente/quic-go/issues/3273, MTU discovery is disabled on Windows - // 1220 is the default value https://github.com/lucas-clemente/quic-go/blob/84e03e59760ceee37359688871bb0688fcc4e98f/internal/protocol/params.go#L138 + // Due to https://github.com/quic-go/quic-go/issues/3273, MTU discovery is disabled on Windows + // 1220 is the default value https://github.com/quic-go/quic-go/blob/84e03e59760ceee37359688871bb0688fcc4e98f/internal/protocol/params.go#L138 MaxDatagramFrameSize = 1220 - // 3 more bytes are reserved at https://github.com/lucas-clemente/quic-go/blob/v0.24.0/internal/wire/datagram_frame.go#L61 + // 3 more bytes are reserved at https://github.com/quic-go/quic-go/blob/v0.24.0/internal/wire/datagram_frame.go#L61 maxDatagramPayloadSize = MaxDatagramFrameSize - 3 - sessionIDLen - typeIDLen ) diff --git a/quic/safe_stream.go b/quic/safe_stream.go index e19faa5ca77..6fa9825986c 100644 --- a/quic/safe_stream.go +++ b/quic/safe_stream.go @@ -4,7 +4,7 @@ import 
( "sync" "time" - "github.com/lucas-clemente/quic-go" + "github.com/quic-go/quic-go" ) type SafeStreamCloser struct { diff --git a/quic/safe_stream_test.go b/quic/safe_stream_test.go index b4c096b2cba..7cfecc84608 100644 --- a/quic/safe_stream_test.go +++ b/quic/safe_stream_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/lucas-clemente/quic-go" + "github.com/quic-go/quic-go" "github.com/stretchr/testify/require" ) @@ -56,7 +56,9 @@ func quicClient(t *testing.T, addr net.Addr) { InsecureSkipVerify: true, NextProtos: []string{"argotunnel"}, } - session, err := quic.DialAddr(addr.String(), tlsConf, testQUICConfig) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + session, err := quic.DialAddr(ctx, addr.String(), tlsConf, testQUICConfig) require.NoError(t, err) var wg sync.WaitGroup diff --git a/quic/tracing.go b/quic/tracing.go index ca14100f02f..10a59b9abd0 100644 --- a/quic/tracing.go +++ b/quic/tracing.go @@ -5,7 +5,7 @@ import ( "net" "time" - "github.com/lucas-clemente/quic-go/logging" + "github.com/quic-go/quic-go/logging" "github.com/rs/zerolog" ) @@ -21,14 +21,15 @@ type tracerConfig struct { index uint8 } -func NewClientTracer(logger *zerolog.Logger, index uint8) logging.Tracer { - return &tracer{ +func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer { + t := &tracer{ logger: logger, config: &tracerConfig{ isClient: true, index: index, }, } + return t.TracerForConnection } func NewServerTracer(logger *zerolog.Logger) logging.Tracer { @@ -47,7 +48,10 @@ func (t *tracer) TracerForConnection(_ctx context.Context, _p logging.Perspectiv return newConnTracer(newServiceCollector()) } -func (*tracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {} +func (*tracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) { +} +func (*tracer) SentVersionNegotiationPacket(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { +} func (*tracer) DroppedPacket(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) { } @@ -82,7 +86,7 @@ func (ct *connTracer) ReceivedPacket(hdr *logging.ExtendedHeader, size logging.B ct.metricsCollector.receivedPackets(size) } -func (ct *connTracer) BufferedPacket(pt logging.PacketType) { +func (ct *connTracer) BufferedPacket(pt logging.PacketType, size logging.ByteCount) { ct.metricsCollector.bufferedPackets(pt) } @@ -110,12 +114,24 @@ func (ct *connTracer) ReceivedTransportParameters(parameters *logging.TransportP func (ct *connTracer) RestoredTransportParameters(parameters *logging.TransportParameters) { } -func (ct *connTracer) ReceivedVersionNegotiationPacket(header *logging.Header, numbers []logging.VersionNumber) { +func (ct *connTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { +} + +func (ct *connTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) { +} + +func (ct *connTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { } func (ct *connTracer) ReceivedRetry(header *logging.Header) { } +func (ct *connTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, frames []logging.Frame) { +} + +func (ct *connTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, 
frames []logging.Frame) { +} + func (ct *connTracer) AcknowledgedPacket(level logging.EncryptionLevel, number logging.PacketNumber) { } diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index d345fe7c40b..60e915bf104 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -7,7 +7,7 @@ import ( "strings" "time" - "github.com/lucas-clemente/quic-go" + "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/connection" diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 1b2e132fd74..d2abce0769b 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -11,8 +11,8 @@ import ( "time" "github.com/google/uuid" - "github.com/lucas-clemente/quic-go" "github.com/pkg/errors" + "github.com/quic-go/quic-go" "github.com/rs/zerolog" "golang.org/x/sync/errgroup" @@ -605,6 +605,7 @@ func (e *EdgeTunnelServer) serveQUIC( } quicConn, err := connection.NewQUICConnection( + ctx, quicConfig, edgeAddr, e.edgeBindAddr, diff --git a/supervisor/tunnel_test.go b/supervisor/tunnel_test.go index 9a0e9653ed9..d83b5084c5e 100644 --- a/supervisor/tunnel_test.go +++ b/supervisor/tunnel_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/lucas-clemente/quic-go" + "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" diff --git a/vendor/github.com/cheekybits/genny/.gitignore b/vendor/github.com/cheekybits/genny/.gitignore deleted file mode 100644 index c62d148c2bb..00000000000 --- a/vendor/github.com/cheekybits/genny/.gitignore +++ /dev/null @@ -1,26 +0,0 @@ -# Compiled Object files, Static and Dynamic libs (Shared Objects) -*.o -*.a -*.so - -# Folders -_obj -_test - -# Architecture specific extensions/prefixes -*.[568vq] -[568vq].out - -*.cgo1.go -*.cgo2.c -_cgo_defun.c -_cgo_gotypes.go -_cgo_export.* - -_testmain.go - -*.exe -*.test -*.prof - -genny diff --git a/vendor/github.com/cheekybits/genny/.travis.yml b/vendor/github.com/cheekybits/genny/.travis.yml deleted file mode 100644 index 78ba5f2d102..00000000000 --- a/vendor/github.com/cheekybits/genny/.travis.yml +++ /dev/null @@ -1,6 +0,0 @@ -language: go - -go: - - 1.7 - - 1.8 - - 1.9 diff --git a/vendor/github.com/cheekybits/genny/LICENSE b/vendor/github.com/cheekybits/genny/LICENSE deleted file mode 100644 index 519d7f22729..00000000000 --- a/vendor/github.com/cheekybits/genny/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2014 cheekybits - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
- diff --git a/vendor/github.com/cheekybits/genny/README.md b/vendor/github.com/cheekybits/genny/README.md deleted file mode 100644 index 64a28ac724d..00000000000 --- a/vendor/github.com/cheekybits/genny/README.md +++ /dev/null @@ -1,245 +0,0 @@ -# genny - Generics for Go - -[![Build Status](https://travis-ci.org/cheekybits/genny.svg?branch=master)](https://travis-ci.org/cheekybits/genny) [![GoDoc](https://godoc.org/github.com/cheekybits/genny/parse?status.png)](http://godoc.org/github.com/cheekybits/genny/parse) - -Install: - -``` -go get github.com/cheekybits/genny -``` - -===== - -(pron. Jenny) by Mat Ryer ([@matryer](https://twitter.com/matryer)) and Tyler Bunnell ([@TylerJBunnell](https://twitter.com/TylerJBunnell)). - -Until the Go core team include support for [generics in Go](http://golang.org/doc/faq#generics), `genny` is a code-generation generics solution. It allows you write normal buildable and testable Go code which, when processed by the `genny gen` tool, will replace the generics with specific types. - - * Generic code is valid Go code - * Generic code compiles and can be tested - * Use `stdin` and `stdout` or specify in and out files - * Supports Go 1.4's [go generate](http://tip.golang.org/doc/go1.4#gogenerate) - * Multiple specific types will generate every permutation - * Use `BUILTINS` and `NUMBERS` wildtype to generate specific code for all built-in (and number) Go types - * Function names and comments also get updated - -## Library - -We have started building a [library of common things](https://github.com/cheekybits/gennylib), and you can use `genny get` to generate the specific versions you need. - -For example: `genny get maps/concurrentmap.go "KeyType=BUILTINS ValueType=BUILTINS"` will print out generated code for all types for a concurrent map. Any file in the library may be generated locally in this way using all the same options given to `genny gen`. - -## Usage - -``` -genny [{flags}] gen "{types}" - -gen - generates type specific code from generic code. -get - fetch a generic template from the online library and gen it. - -{flags} - (optional) Command line flags (see below) -{types} - (required) Specific types for each generic type in the source -{types} format: {generic}={specific}[,another][ {generic2}={specific2}] - -Examples: - Generic=Specific - Generic1=Specific1 Generic2=Specific2 - Generic1=Specific1,Specific2 Generic2=Specific3,Specific4 - -Flags: - -in="": file to parse instead of stdin - -out="": file to save output to instead of stdout - -pkg="": package name for generated files -``` - - * Comma separated type lists will generate code for each type - -### Flags - - * `-in` - specify the input file (rather than using stdin) - * `-out` - specify the output file (rather than using stdout) - -### go generate - -To use Go 1.4's `go generate` capability, insert the following comment in your source code file: - -``` -//go:generate genny -in=$GOFILE -out=gen-$GOFILE gen "KeyType=string,int ValueType=string,int" -``` - - * Start the line with `//go:generate ` - * Use the `-in` and `-out` flags to specify the files to work on - * Use the `genny` command as usual after the flags - -Now, running `go generate` (in a shell) for the package will cause the generic versions of the files to be generated. 
- - * The output file will be overwritten, so it's safe to call `go generate` many times - * Use `$GOFILE` to refer to the current file - * The `//go:generate` line will be removed from the output - -To see a real example of how to use `genny` with `go generate`, look in the [example/go-generate directory](https://github.com/cheekybits/genny/tree/master/examples/go-generate). - -## How it works - -Define your generic types using the special `generic.Type` placeholder type: - -```go -type KeyType generic.Type -type ValueType generic.Type -``` - - * You can use as many as you like - * Give them meaningful names - -Then write the generic code referencing the types as your normally would: - -```go -func SetValueTypeForKeyType(key KeyType, value ValueType) { /* ... */ } -``` - - * Generic type names will also be replaced in comments and function names (see Real example below) - -Since `generic.Type` is a real Go type, your code will compile, and you can even write unit tests against your generic code. - -#### Generating specific versions - -Pass the file through the `genny gen` tool with the specific types as the argument: - -``` -cat generic.go | genny gen "KeyType=string ValueType=interface{}" -``` - -The output will be the complete Go source file with the generic types replaced with the types specified in the arguments. - -## Real example - -Given [this generic Go code](https://github.com/cheekybits/genny/tree/master/examples/queue) which compiles and is tested: - -```go -package queue - -import "github.com/cheekybits/genny/generic" - -// NOTE: this is how easy it is to define a generic type -type Something generic.Type - -// SomethingQueue is a queue of Somethings. -type SomethingQueue struct { - items []Something -} - -func NewSomethingQueue() *SomethingQueue { - return &SomethingQueue{items: make([]Something, 0)} -} -func (q *SomethingQueue) Push(item Something) { - q.items = append(q.items, item) -} -func (q *SomethingQueue) Pop() Something { - item := q.items[0] - q.items = q.items[1:] - return item -} -``` - -When `genny gen` is invoked like this: - -``` -cat source.go | genny gen "Something=string" -``` - -It outputs: - -```go -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package queue - -// StringQueue is a queue of Strings. -type StringQueue struct { - items []string -} - -func NewStringQueue() *StringQueue { - return &StringQueue{items: make([]string, 0)} -} -func (q *StringQueue) Push(item string) { - q.items = append(q.items, item) -} -func (q *StringQueue) Pop() string { - item := q.items[0] - q.items = q.items[1:] - return item -} -``` - -To get a _something_ for every built-in Go type plus one of your own types, you could run: - -``` -cat source.go | genny gen "Something=BUILTINS,*MyType" -``` - -#### More examples - -Check out the [test code files](https://github.com/cheekybits/genny/tree/master/parse/test) for more real examples. 
- -## Writing test code - -Once you have defined a generic type with some code worth testing: - -```go -package slice - -import ( - "log" - "reflect" - - "github.com/stretchr/gogen/generic" -) - -type MyType generic.Type - -func EnsureMyTypeSlice(objectOrSlice interface{}) []MyType { - log.Printf("%v", reflect.TypeOf(objectOrSlice)) - switch obj := objectOrSlice.(type) { - case []MyType: - log.Println(" returning it untouched") - return obj - case MyType: - log.Println(" wrapping in slice") - return []MyType{obj} - default: - panic("ensure slice needs MyType or []MyType") - } -} -``` - -You can treat it like any normal Go type in your test code: - -```go -func TestEnsureMyTypeSlice(t *testing.T) { - - myType := new(MyType) - slice := EnsureMyTypeSlice(myType) - if assert.NotNil(t, slice) { - assert.Equal(t, slice[0], myType) - } - - slice = EnsureMyTypeSlice(slice) - log.Printf("%#v", slice[0]) - if assert.NotNil(t, slice) { - assert.Equal(t, slice[0], myType) - } - -} -``` - -### Understanding what `generic.Type` is - -Because `generic.Type` is an empty interface type (literally `interface{}`) every other type will be considered to be a `generic.Type` if you are switching on the type of an object. Of course, once the specific versions are generated, this issue goes away but it's worth knowing when you are writing your tests against generic code. - -### Contributions - - * See the [API documentation for the parse package](http://godoc.org/github.com/cheekybits/genny/parse) - * Please do TDD - * All input welcome diff --git a/vendor/github.com/cheekybits/genny/doc.go b/vendor/github.com/cheekybits/genny/doc.go deleted file mode 100644 index 4c31e22bc4a..00000000000 --- a/vendor/github.com/cheekybits/genny/doc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package main is the command line tool for Genny. -package main diff --git a/vendor/github.com/cheekybits/genny/generic/doc.go b/vendor/github.com/cheekybits/genny/generic/doc.go deleted file mode 100644 index 3bd6c869c0f..00000000000 --- a/vendor/github.com/cheekybits/genny/generic/doc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package generic contains the generic marker types. -package generic diff --git a/vendor/github.com/cheekybits/genny/generic/generic.go b/vendor/github.com/cheekybits/genny/generic/generic.go deleted file mode 100644 index 04a2306cbf1..00000000000 --- a/vendor/github.com/cheekybits/genny/generic/generic.go +++ /dev/null @@ -1,13 +0,0 @@ -package generic - -// Type is the placeholder type that indicates a generic value. -// When genny is executed, variables of this type will be replaced with -// references to the specific types. -// var GenericType generic.Type -type Type interface{} - -// Number is the placehoder type that indiccates a generic numerical value. -// When genny is executed, variables of this type will be replaced with -// references to the specific types. 
-// var GenericType generic.Number -type Number float64 diff --git a/vendor/github.com/cheekybits/genny/main.go b/vendor/github.com/cheekybits/genny/main.go deleted file mode 100644 index fe06a6c034a..00000000000 --- a/vendor/github.com/cheekybits/genny/main.go +++ /dev/null @@ -1,154 +0,0 @@ -package main - -import ( - "bytes" - "flag" - "fmt" - "io" - "io/ioutil" - "net/http" - "os" - "strings" - - "github.com/cheekybits/genny/out" - "github.com/cheekybits/genny/parse" -) - -/* - - source | genny gen [-in=""] [-out=""] [-pkg=""] "KeyType=string,int ValueType=string,int" - -*/ - -const ( - _ = iota - exitcodeInvalidArgs - exitcodeInvalidTypeSet - exitcodeStdinFailed - exitcodeGenFailed - exitcodeGetFailed - exitcodeSourceFileInvalid - exitcodeDestFileFailed -) - -func main() { - var ( - in = flag.String("in", "", "file to parse instead of stdin") - out = flag.String("out", "", "file to save output to instead of stdout") - pkgName = flag.String("pkg", "", "package name for generated files") - prefix = "https://github.com/metabition/gennylib/raw/master/" - ) - flag.Parse() - args := flag.Args() - - if len(args) < 2 { - usage() - os.Exit(exitcodeInvalidArgs) - } - - if strings.ToLower(args[0]) != "gen" && strings.ToLower(args[0]) != "get" { - usage() - os.Exit(exitcodeInvalidArgs) - } - - // parse the typesets - var setsArg = args[1] - if strings.ToLower(args[0]) == "get" { - setsArg = args[2] - } - typeSets, err := parse.TypeSet(setsArg) - if err != nil { - fatal(exitcodeInvalidTypeSet, err) - } - - outWriter := newWriter(*out) - - if strings.ToLower(args[0]) == "get" { - if len(args) != 3 { - fmt.Println("not enough arguments to get") - usage() - os.Exit(exitcodeInvalidArgs) - } - r, err := http.Get(prefix + args[1]) - if err != nil { - fatal(exitcodeGetFailed, err) - } - b, err := ioutil.ReadAll(r.Body) - if err != nil { - fatal(exitcodeGetFailed, err) - } - r.Body.Close() - br := bytes.NewReader(b) - err = gen(*in, *pkgName, br, typeSets, outWriter) - } else if len(*in) > 0 { - var file *os.File - file, err = os.Open(*in) - if err != nil { - fatal(exitcodeSourceFileInvalid, err) - } - defer file.Close() - err = gen(*in, *pkgName, file, typeSets, outWriter) - } else { - var source []byte - source, err = ioutil.ReadAll(os.Stdin) - if err != nil { - fatal(exitcodeStdinFailed, err) - } - reader := bytes.NewReader(source) - err = gen("stdin", *pkgName, reader, typeSets, outWriter) - } - - // do the work - if err != nil { - fatal(exitcodeGenFailed, err) - } - -} - -func usage() { - fmt.Fprintln(os.Stderr, `usage: genny [{flags}] gen "{types}" - -gen - generates type specific code from generic code. -get - fetch a generic template from the online library and gen it. - -{flags} - (optional) Command line flags (see below) -{types} - (required) Specific types for each generic type in the source -{types} format: {generic}={specific}[,another][ {generic2}={specific2}] - -Examples: - Generic=Specific - Generic1=Specific1 Generic2=Specific2 - Generic1=Specific1,Specific2 Generic2=Specific3,Specific4 - -Flags:`) - flag.PrintDefaults() -} - -func newWriter(fileName string) io.Writer { - if fileName == "" { - return os.Stdout - } - lf := &out.LazyFile{FileName: fileName} - defer lf.Close() - return lf -} - -func fatal(code int, a ...interface{}) { - fmt.Println(a...) - os.Exit(code) -} - -// gen performs the generic generation. 
-func gen(filename, pkgName string, in io.ReadSeeker, typesets []map[string]string, out io.Writer) error { - - var output []byte - var err error - - output, err = parse.Generics(filename, pkgName, in, typesets) - if err != nil { - return err - } - - out.Write(output) - return nil -} diff --git a/vendor/github.com/cheekybits/genny/out/lazy_file.go b/vendor/github.com/cheekybits/genny/out/lazy_file.go deleted file mode 100644 index 7c8815f5f97..00000000000 --- a/vendor/github.com/cheekybits/genny/out/lazy_file.go +++ /dev/null @@ -1,38 +0,0 @@ -package out - -import ( - "os" - "path" -) - -// LazyFile is an io.WriteCloser which defers creation of the file it is supposed to write in -// till the first call to its write function in order to prevent creation of file, if no write -// is supposed to happen. -type LazyFile struct { - // FileName is path to the file to which genny will write. - FileName string - file *os.File -} - -// Close closes the file if it is created. Returns nil if no file is created. -func (lw *LazyFile) Close() error { - if lw.file != nil { - return lw.file.Close() - } - return nil -} - -// Write writes to the specified file and creates the file first time it is called. -func (lw *LazyFile) Write(p []byte) (int, error) { - if lw.file == nil { - err := os.MkdirAll(path.Dir(lw.FileName), 0755) - if err != nil { - return 0, err - } - lw.file, err = os.Create(lw.FileName) - if err != nil { - return 0, err - } - } - return lw.file.Write(p) -} diff --git a/vendor/github.com/cheekybits/genny/parse/builtins.go b/vendor/github.com/cheekybits/genny/parse/builtins.go deleted file mode 100644 index e02995444c8..00000000000 --- a/vendor/github.com/cheekybits/genny/parse/builtins.go +++ /dev/null @@ -1,41 +0,0 @@ -package parse - -// Builtins contains a slice of all built-in Go types. -var Builtins = []string{ - "bool", - "byte", - "complex128", - "complex64", - "error", - "float32", - "float64", - "int", - "int16", - "int32", - "int64", - "int8", - "rune", - "string", - "uint", - "uint16", - "uint32", - "uint64", - "uint8", - "uintptr", -} - -// Numbers contains a slice of all built-in number types. -var Numbers = []string{ - "float32", - "float64", - "int", - "int16", - "int32", - "int64", - "int8", - "uint", - "uint16", - "uint32", - "uint64", - "uint8", -} diff --git a/vendor/github.com/cheekybits/genny/parse/doc.go b/vendor/github.com/cheekybits/genny/parse/doc.go deleted file mode 100644 index 1be4fed8b47..00000000000 --- a/vendor/github.com/cheekybits/genny/parse/doc.go +++ /dev/null @@ -1,14 +0,0 @@ -// Package parse contains the generic code generation capabilities -// that power genny. -// -// genny gen "{types}" -// -// gen - generates type specific code (to stdout) from generic code (via stdin) -// -// {types} - (required) Specific types for each generic type in the source -// {types} format: {generic}={specific}[,another][ {generic2}={specific2}] -// Examples: -// Generic=Specific -// Generic1=Specific1 Generic2=Specific2 -// Generic1=Specific1,Specific2 Generic2=Specific3,Specific4 -package parse diff --git a/vendor/github.com/cheekybits/genny/parse/errors.go b/vendor/github.com/cheekybits/genny/parse/errors.go deleted file mode 100644 index ab812bf9083..00000000000 --- a/vendor/github.com/cheekybits/genny/parse/errors.go +++ /dev/null @@ -1,47 +0,0 @@ -package parse - -import ( - "errors" -) - -// errMissingSpecificType represents an error when a generic type is not -// satisfied by a specific type. 
-type errMissingSpecificType struct { - GenericType string -} - -// Error gets a human readable string describing this error. -func (e errMissingSpecificType) Error() string { - return "Missing specific type for '" + e.GenericType + "' generic type" -} - -// errImports represents an error from goimports. -type errImports struct { - Err error -} - -// Error gets a human readable string describing this error. -func (e errImports) Error() string { - return "Failed to goimports the generated code: " + e.Err.Error() -} - -// errSource represents an error with the source file. -type errSource struct { - Err error -} - -// Error gets a human readable string describing this error. -func (e errSource) Error() string { - return "Failed to parse source file: " + e.Err.Error() -} - -type errBadTypeArgs struct { - Message string - Arg string -} - -func (e errBadTypeArgs) Error() string { - return "\"" + e.Arg + "\" is bad: " + e.Message -} - -var errMissingTypeInformation = errors.New("No type arguments were specified and no \"// +gogen\" tag was found in the source.") diff --git a/vendor/github.com/cheekybits/genny/parse/parse.go b/vendor/github.com/cheekybits/genny/parse/parse.go deleted file mode 100644 index 08eb48b1136..00000000000 --- a/vendor/github.com/cheekybits/genny/parse/parse.go +++ /dev/null @@ -1,298 +0,0 @@ -package parse - -import ( - "bufio" - "bytes" - "fmt" - "go/ast" - "go/parser" - "go/scanner" - "go/token" - "io" - "os" - "strings" - "unicode" - - "golang.org/x/tools/imports" -) - -var header = []byte(` - -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -`) - -var ( - packageKeyword = []byte("package") - importKeyword = []byte("import") - openBrace = []byte("(") - closeBrace = []byte(")") - genericPackage = "generic" - genericType = "generic.Type" - genericNumber = "generic.Number" - linefeed = "\r\n" -) -var unwantedLinePrefixes = [][]byte{ - []byte("//go:generate genny "), -} - -func subIntoLiteral(lit, typeTemplate, specificType string) string { - if lit == typeTemplate { - return specificType - } - if !strings.Contains(lit, typeTemplate) { - return lit - } - specificLg := wordify(specificType, true) - specificSm := wordify(specificType, false) - result := strings.Replace(lit, typeTemplate, specificLg, -1) - if strings.HasPrefix(result, specificLg) && !isExported(lit) { - return strings.Replace(result, specificLg, specificSm, 1) - } - return result -} - -func subTypeIntoComment(line, typeTemplate, specificType string) string { - var subbed string - for _, w := range strings.Fields(line) { - subbed = subbed + subIntoLiteral(w, typeTemplate, specificType) + " " - } - return subbed -} - -// Does the heavy lifting of taking a line of our code and -// sbustituting a type into there for our generic type -func subTypeIntoLine(line, typeTemplate, specificType string) string { - src := []byte(line) - var s scanner.Scanner - fset := token.NewFileSet() - file := fset.AddFile("", fset.Base(), len(src)) - s.Init(file, src, nil, scanner.ScanComments) - output := "" - for { - _, tok, lit := s.Scan() - if tok == token.EOF { - break - } else if tok == token.COMMENT { - subbed := subTypeIntoComment(lit, typeTemplate, specificType) - output = output + subbed + " " - } else if tok.IsLiteral() { - subbed := subIntoLiteral(lit, typeTemplate, specificType) - output = output + subbed + " " - } else { - output = output + tok.String() + " " - } - } - return output -} - -// typeSet looks like 
"KeyType: int, ValueType: string" -func generateSpecific(filename string, in io.ReadSeeker, typeSet map[string]string) ([]byte, error) { - - // ensure we are at the beginning of the file - in.Seek(0, os.SEEK_SET) - - // parse the source file - fs := token.NewFileSet() - file, err := parser.ParseFile(fs, filename, in, 0) - if err != nil { - return nil, &errSource{Err: err} - } - - // make sure every generic.Type is represented in the types - // argument. - for _, decl := range file.Decls { - switch it := decl.(type) { - case *ast.GenDecl: - for _, spec := range it.Specs { - ts, ok := spec.(*ast.TypeSpec) - if !ok { - continue - } - switch tt := ts.Type.(type) { - case *ast.SelectorExpr: - if name, ok := tt.X.(*ast.Ident); ok { - if name.Name == genericPackage { - if _, ok := typeSet[ts.Name.Name]; !ok { - return nil, &errMissingSpecificType{GenericType: ts.Name.Name} - } - } - } - } - } - } - } - - in.Seek(0, os.SEEK_SET) - - var buf bytes.Buffer - - comment := "" - scanner := bufio.NewScanner(in) - for scanner.Scan() { - - line := scanner.Text() - - // does this line contain generic.Type? - if strings.Contains(line, genericType) || strings.Contains(line, genericNumber) { - comment = "" - continue - } - - for t, specificType := range typeSet { - if strings.Contains(line, t) { - newLine := subTypeIntoLine(line, t, specificType) - line = newLine - } - } - - if comment != "" { - buf.WriteString(makeLine(comment)) - comment = "" - } - - // is this line a comment? - // TODO: should we handle /* */ comments? - if strings.HasPrefix(line, "//") { - // record this line to print later - comment = line - continue - } - - // write the line - buf.WriteString(makeLine(line)) - } - - // write it out - return buf.Bytes(), nil -} - -// Generics parses the source file and generates the bytes replacing the -// generic types for the keys map with the specific types (its value). -func Generics(filename, pkgName string, in io.ReadSeeker, typeSets []map[string]string) ([]byte, error) { - - totalOutput := header - - for _, typeSet := range typeSets { - - // generate the specifics - parsed, err := generateSpecific(filename, in, typeSet) - if err != nil { - return nil, err - } - - totalOutput = append(totalOutput, parsed...) - - } - - // clean up the code line by line - packageFound := false - insideImportBlock := false - var cleanOutputLines []string - scanner := bufio.NewScanner(bytes.NewReader(totalOutput)) - for scanner.Scan() { - - // end of imports block? 
- if insideImportBlock { - if bytes.HasSuffix(scanner.Bytes(), closeBrace) { - insideImportBlock = false - } - continue - } - - if bytes.HasPrefix(scanner.Bytes(), packageKeyword) { - if packageFound { - continue - } else { - packageFound = true - } - } else if bytes.HasPrefix(scanner.Bytes(), importKeyword) { - if bytes.HasSuffix(scanner.Bytes(), openBrace) { - insideImportBlock = true - } - continue - } - - // check all unwantedLinePrefixes - and skip them - skipline := false - for _, prefix := range unwantedLinePrefixes { - if bytes.HasPrefix(scanner.Bytes(), prefix) { - skipline = true - continue - } - } - - if skipline { - continue - } - - cleanOutputLines = append(cleanOutputLines, makeLine(scanner.Text())) - } - - cleanOutput := strings.Join(cleanOutputLines, "") - - output := []byte(cleanOutput) - var err error - - // change package name - if pkgName != "" { - output = changePackage(bytes.NewReader([]byte(output)), pkgName) - } - // fix the imports - output, err = imports.Process(filename, output, nil) - if err != nil { - return nil, &errImports{Err: err} - } - - return output, nil -} - -func makeLine(s string) string { - return fmt.Sprintln(strings.TrimRight(s, linefeed)) -} - -// isAlphaNumeric gets whether the rune is alphanumeric or _. -func isAlphaNumeric(r rune) bool { - return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) -} - -// wordify turns a type into a nice word for function and type -// names etc. -func wordify(s string, exported bool) string { - s = strings.TrimRight(s, "{}") - s = strings.TrimLeft(s, "*&") - s = strings.Replace(s, ".", "", -1) - if !exported { - return s - } - return strings.ToUpper(string(s[0])) + s[1:] -} - -func changePackage(r io.Reader, pkgName string) []byte { - var out bytes.Buffer - sc := bufio.NewScanner(r) - done := false - - for sc.Scan() { - s := sc.Text() - - if !done && strings.HasPrefix(s, "package") { - parts := strings.Split(s, " ") - parts[1] = pkgName - s = strings.Join(parts, " ") - done = true - } - - fmt.Fprintln(&out, s) - } - return out.Bytes() -} - -func isExported(lit string) bool { - if len(lit) == 0 { - return false - } - return unicode.IsUpper(rune(lit[0])) -} diff --git a/vendor/github.com/cheekybits/genny/parse/typesets.go b/vendor/github.com/cheekybits/genny/parse/typesets.go deleted file mode 100644 index c30b97289ad..00000000000 --- a/vendor/github.com/cheekybits/genny/parse/typesets.go +++ /dev/null @@ -1,89 +0,0 @@ -package parse - -import "strings" - -const ( - typeSep = " " - keyValueSep = "=" - valuesSep = "," - builtins = "BUILTINS" - numbers = "NUMBERS" -) - -// TypeSet turns a type string into a []map[string]string -// that can be given to parse.Generics for it to do its magic. -// -// Acceptable args are: -// -// Person=man -// Person=man Animal=dog -// Person=man Animal=dog Animal2=cat -// Person=man,woman Animal=dog,cat -// Person=man,woman,child Animal=dog,cat Place=london,paris -func TypeSet(arg string) ([]map[string]string, error) { - - types := make(map[string][]string) - var keys []string - for _, pair := range strings.Split(arg, typeSep) { - segs := strings.Split(pair, keyValueSep) - if len(segs) != 2 { - return nil, &errBadTypeArgs{Arg: arg, Message: "Generic=Specific expected"} - } - key := segs[0] - keys = append(keys, key) - types[key] = make([]string, 0) - for _, t := range strings.Split(segs[1], valuesSep) { - if t == builtins { - types[key] = append(types[key], Builtins...) - } else if t == numbers { - types[key] = append(types[key], Numbers...) 
- } else { - types[key] = append(types[key], t) - } - } - } - - cursors := make(map[string]int) - for _, key := range keys { - cursors[key] = 0 - } - - outChan := make(chan map[string]string) - go func() { - buildTypeSet(keys, 0, cursors, types, outChan) - close(outChan) - }() - - var typeSets []map[string]string - for typeSet := range outChan { - typeSets = append(typeSets, typeSet) - } - - return typeSets, nil - -} - -func buildTypeSet(keys []string, keyI int, cursors map[string]int, types map[string][]string, out chan<- map[string]string) { - key := keys[keyI] - for cursors[key] < len(types[key]) { - if keyI < len(keys)-1 { - buildTypeSet(keys, keyI+1, copycursors(cursors), types, out) - } else { - // build the typeset for this combination - ts := make(map[string]string) - for k, vals := range types { - ts[k] = vals[cursors[k]] - } - out <- ts - } - cursors[key]++ - } -} - -func copycursors(source map[string]int) map[string]int { - copy := make(map[string]int) - for k, v := range source { - copy[k] = v - } - return copy -} diff --git a/vendor/github.com/golang/mock/AUTHORS b/vendor/github.com/golang/mock/AUTHORS new file mode 100644 index 00000000000..660b8ccc8ae --- /dev/null +++ b/vendor/github.com/golang/mock/AUTHORS @@ -0,0 +1,12 @@ +# This is the official list of GoMock authors for copyright purposes. +# This file is distinct from the CONTRIBUTORS files. +# See the latter for an explanation. + +# Names should be added to this file as +# Name or Organization +# The email address is not required for organizations. + +# Please keep the list sorted. + +Alex Reece +Google Inc. diff --git a/vendor/github.com/golang/mock/CONTRIBUTORS b/vendor/github.com/golang/mock/CONTRIBUTORS new file mode 100644 index 00000000000..def849cab1b --- /dev/null +++ b/vendor/github.com/golang/mock/CONTRIBUTORS @@ -0,0 +1,37 @@ +# This is the official list of people who can contribute (and typically +# have contributed) code to the gomock repository. +# The AUTHORS file lists the copyright holders; this file +# lists people. For example, Google employees are listed here +# but not in AUTHORS, because Google holds the copyright. +# +# The submission process automatically checks to make sure +# that people submitting code are listed in this file (by email address). +# +# Names should be added to this file only after verifying that +# the individual or the individual's organization has agreed to +# the appropriate Contributor License Agreement, found here: +# +# http://code.google.com/legal/individual-cla-v1.0.html +# http://code.google.com/legal/corporate-cla-v1.0.html +# +# The agreement for individuals can be filled out on the web. +# +# When adding J Random Contributor's name to this file, +# either J's name or J's organization's name should be +# added to the AUTHORS file, depending on whether the +# individual or corporate CLA was used. + +# Names should be added to this file like so: +# Name +# +# An entry with two email addresses specifies that the +# first address should be used in the submit logs and +# that the second address should be recognized as the +# same person when interacting with Rietveld. + +# Please keep the list sorted. 
+ +Aaron Jacobs +Alex Reece +David Symonds +Ryan Barrett diff --git a/vendor/github.com/golang/mock/LICENSE b/vendor/github.com/golang/mock/LICENSE new file mode 100644 index 00000000000..d6456956733 --- /dev/null +++ b/vendor/github.com/golang/mock/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. 
Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/golang/mock/mockgen/mockgen.go b/vendor/github.com/golang/mock/mockgen/mockgen.go new file mode 100644 index 00000000000..50487070e3e --- /dev/null +++ b/vendor/github.com/golang/mock/mockgen/mockgen.go @@ -0,0 +1,701 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// MockGen generates mock implementations of Go interfaces. +package main + +// TODO: This does not support recursive embedded interfaces. +// TODO: This does not support embedding package-local interfaces in a separate file. + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "go/token" + "io" + "io/ioutil" + "log" + "os" + "os/exec" + "path" + "path/filepath" + "sort" + "strconv" + "strings" + "unicode" + + "github.com/golang/mock/mockgen/model" + + "golang.org/x/mod/modfile" + toolsimports "golang.org/x/tools/imports" +) + +const ( + gomockImportPath = "github.com/golang/mock/gomock" +) + +var ( + version = "" + commit = "none" + date = "unknown" +) + +var ( + source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.") + destination = flag.String("destination", "", "Output file; defaults to stdout.") + mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.") + packageOut = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.") + selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. 
Setting this flag will then tell mockgen which import to exclude.") + writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.") + copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header") + + debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") + showVersion = flag.Bool("version", false, "Print version.") +) + +func main() { + flag.Usage = usage + flag.Parse() + + if *showVersion { + printVersion() + return + } + + var pkg *model.Package + var err error + var packageName string + if *source != "" { + pkg, err = sourceMode(*source) + } else { + if flag.NArg() != 2 { + usage() + log.Fatal("Expected exactly two arguments") + } + packageName = flag.Arg(0) + interfaces := strings.Split(flag.Arg(1), ",") + if packageName == "." { + dir, err := os.Getwd() + if err != nil { + log.Fatalf("Get current directory failed: %v", err) + } + packageName, err = packageNameOfDir(dir) + if err != nil { + log.Fatalf("Parse package name failed: %v", err) + } + } + pkg, err = reflectMode(packageName, interfaces) + } + if err != nil { + log.Fatalf("Loading input failed: %v", err) + } + + if *debugParser { + pkg.Print(os.Stdout) + return + } + + dst := os.Stdout + if len(*destination) > 0 { + if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil { + log.Fatalf("Unable to create directory: %v", err) + } + f, err := os.Create(*destination) + if err != nil { + log.Fatalf("Failed opening destination file: %v", err) + } + defer f.Close() + dst = f + } + + outputPackageName := *packageOut + if outputPackageName == "" { + // pkg.Name in reflect mode is the base name of the import path, + // which might have characters that are illegal to have in package names. + outputPackageName = "mock_" + sanitize(pkg.Name) + } + + // outputPackagePath represents the fully qualified name of the package of + // the generated code. Its purposes are to prevent the module from importing + // itself and to prevent qualifying type names that come from its own + // package (i.e. if there is a type called X then we want to print "X" not + // "package.X" since "package" is this package). This can happen if the mock + // is output into an already existing package. 
+ outputPackagePath := *selfPackage + if outputPackagePath == "" && *destination != "" { + dstPath, err := filepath.Abs(filepath.Dir(*destination)) + if err == nil { + pkgPath, err := parsePackageImport(dstPath) + if err == nil { + outputPackagePath = pkgPath + } else { + log.Println("Unable to infer -self_package from destination file path:", err) + } + } else { + log.Println("Unable to determine destination file path:", err) + } + } + + g := new(generator) + if *source != "" { + g.filename = *source + } else { + g.srcPackage = packageName + g.srcInterfaces = flag.Arg(1) + } + g.destination = *destination + + if *mockNames != "" { + g.mockNames = parseMockNames(*mockNames) + } + if *copyrightFile != "" { + header, err := ioutil.ReadFile(*copyrightFile) + if err != nil { + log.Fatalf("Failed reading copyright file: %v", err) + } + + g.copyrightHeader = string(header) + } + if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil { + log.Fatalf("Failed generating mock: %v", err) + } + if _, err := dst.Write(g.Output()); err != nil { + log.Fatalf("Failed writing to destination: %v", err) + } +} + +func parseMockNames(names string) map[string]string { + mocksMap := make(map[string]string) + for _, kv := range strings.Split(names, ",") { + parts := strings.SplitN(kv, "=", 2) + if len(parts) != 2 || parts[1] == "" { + log.Fatalf("bad mock names spec: %v", kv) + } + mocksMap[parts[0]] = parts[1] + } + return mocksMap +} + +func usage() { + _, _ = io.WriteString(os.Stderr, usageText) + flag.PrintDefaults() +} + +const usageText = `mockgen has two modes of operation: source and reflect. + +Source mode generates mock interfaces from a source file. +It is enabled by using the -source flag. Other flags that +may be useful in this mode are -imports and -aux_files. +Example: + mockgen -source=foo.go [other options] + +Reflect mode generates mock interfaces by building a program +that uses reflection to understand interfaces. It is enabled +by passing two non-flag arguments: an import path, and a +comma-separated list of symbols. +Example: + mockgen database/sql/driver Conn,Driver + +` + +type generator struct { + buf bytes.Buffer + indent string + mockNames map[string]string // may be empty + filename string // may be empty + destination string // may be empty + srcPackage, srcInterfaces string // may be empty + copyrightHeader string + + packageMap map[string]string // map from import path to package name +} + +func (g *generator) p(format string, args ...interface{}) { + fmt.Fprintf(&g.buf, g.indent+format+"\n", args...) +} + +func (g *generator) in() { + g.indent += "\t" +} + +func (g *generator) out() { + if len(g.indent) > 0 { + g.indent = g.indent[0 : len(g.indent)-1] + } +} + +// sanitize cleans up a string to make a suitable package name. 
+func sanitize(s string) string { + t := "" + for _, r := range s { + if t == "" { + if unicode.IsLetter(r) || r == '_' { + t += string(r) + continue + } + } else { + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { + t += string(r) + continue + } + } + t += "_" + } + if t == "_" { + t = "x" + } + return t +} + +func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error { + if outputPkgName != pkg.Name && *selfPackage == "" { + // reset outputPackagePath if it's not passed in through -self_package + outputPackagePath = "" + } + + if g.copyrightHeader != "" { + lines := strings.Split(g.copyrightHeader, "\n") + for _, line := range lines { + g.p("// %s", line) + } + g.p("") + } + + g.p("// Code generated by MockGen. DO NOT EDIT.") + if g.filename != "" { + g.p("// Source: %v", g.filename) + } else { + g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) + } + g.p("") + + // Get all required imports, and generate unique names for them all. + im := pkg.Imports() + im[gomockImportPath] = true + + // Only import reflect if it's used. We only use reflect in mocked methods + // so only import if any of the mocked interfaces have methods. + for _, intf := range pkg.Interfaces { + if len(intf.Methods) > 0 { + im["reflect"] = true + break + } + } + + // Sort keys to make import alias generation predictable + sortedPaths := make([]string, len(im)) + x := 0 + for pth := range im { + sortedPaths[x] = pth + x++ + } + sort.Strings(sortedPaths) + + packagesName := createPackageMap(sortedPaths) + + g.packageMap = make(map[string]string, len(im)) + localNames := make(map[string]bool, len(im)) + for _, pth := range sortedPaths { + base, ok := packagesName[pth] + if !ok { + base = sanitize(path.Base(pth)) + } + + // Local names for an imported package can usually be the basename of the import path. + // A couple of situations don't permit that, such as duplicate local names + // (e.g. importing "html/template" and "text/template"), or where the basename is + // a keyword (e.g. "foo/case"). + // try base0, base1, ... + pkgName := base + i := 0 + for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() { + pkgName = base + strconv.Itoa(i) + i++ + } + + // Avoid importing package if source pkg == output pkg + if pth == pkg.PkgPath && outputPackagePath == pkg.PkgPath { + continue + } + + g.packageMap[pth] = pkgName + localNames[pkgName] = true + } + + if *writePkgComment { + g.p("// Package %v is a generated GoMock package.", outputPkgName) + } + g.p("package %v", outputPkgName) + g.p("") + g.p("import (") + g.in() + for pkgPath, pkgName := range g.packageMap { + if pkgPath == outputPackagePath { + continue + } + g.p("%v %q", pkgName, pkgPath) + } + for _, pkgPath := range pkg.DotImports { + g.p(". %q", pkgPath) + } + g.out() + g.p(")") + + for _, intf := range pkg.Interfaces { + if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { + return err + } + } + + return nil +} + +// The name of the mock type to use for the given interface identifier. 
+func (g *generator) mockName(typeName string) string { + if mockName, ok := g.mockNames[typeName]; ok { + return mockName + } + + return "Mock" + typeName +} + +func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { + mockType := g.mockName(intf.Name) + + g.p("") + g.p("// %v is a mock of %v interface.", mockType, intf.Name) + g.p("type %v struct {", mockType) + g.in() + g.p("ctrl *gomock.Controller") + g.p("recorder *%vMockRecorder", mockType) + g.out() + g.p("}") + g.p("") + + g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType) + g.p("type %vMockRecorder struct {", mockType) + g.in() + g.p("mock *%v", mockType) + g.out() + g.p("}") + g.p("") + + g.p("// New%v creates a new mock instance.", mockType) + g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType) + g.in() + g.p("mock := &%v{ctrl: ctrl}", mockType) + g.p("mock.recorder = &%vMockRecorder{mock}", mockType) + g.p("return mock") + g.out() + g.p("}") + g.p("") + + // XXX: possible name collision here if someone has EXPECT in their interface. + g.p("// EXPECT returns an object that allows the caller to indicate expected use.") + g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType) + g.in() + g.p("return m.recorder") + g.out() + g.p("}") + + g.GenerateMockMethods(mockType, intf, outputPackagePath) + + return nil +} + +type byMethodName []*model.Method + +func (b byMethodName) Len() int { return len(b) } +func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } +func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name } + +func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) { + sort.Sort(byMethodName(intf.Methods)) + for _, m := range intf.Methods { + g.p("") + _ = g.GenerateMockMethod(mockType, m, pkgOverride) + g.p("") + _ = g.GenerateMockRecorderMethod(mockType, m) + } +} + +func makeArgString(argNames, argTypes []string) string { + args := make([]string, len(argNames)) + for i, name := range argNames { + // specify the type only once for consecutive args of the same type + if i+1 < len(argTypes) && argTypes[i] == argTypes[i+1] { + args[i] = name + } else { + args[i] = name + " " + argTypes[i] + } + } + return strings.Join(args, ", ") +} + +// GenerateMockMethod generates a mock method implementation. +// If non-empty, pkgOverride is the package in which unqualified types reside. +func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error { + argNames := g.getArgNames(m) + argTypes := g.getArgTypes(m, pkgOverride) + argString := makeArgString(argNames, argTypes) + + rets := make([]string, len(m.Out)) + for i, p := range m.Out { + rets[i] = p.Type.String(g.packageMap, pkgOverride) + } + retString := strings.Join(rets, ", ") + if len(rets) > 1 { + retString = "(" + retString + ")" + } + if retString != "" { + retString = " " + retString + } + + ia := newIdentifierAllocator(argNames) + idRecv := ia.allocateIdentifier("m") + + g.p("// %v mocks base method.", m.Name) + g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString) + g.in() + g.p("%s.ctrl.T.Helper()", idRecv) + + var callArgs string + if m.Variadic == nil { + if len(argNames) > 0 { + callArgs = ", " + strings.Join(argNames, ", ") + } + } else { + // Non-trivial. The generated code must build a []interface{}, + // but the variadic argument may be any type. 
+ idVarArgs := ia.allocateIdentifier("varargs") + idVArg := ia.allocateIdentifier("a") + g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", ")) + g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1]) + g.in() + g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg) + g.out() + g.p("}") + callArgs = ", " + idVarArgs + "..." + } + if len(m.Out) == 0 { + g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs) + } else { + idRet := ia.allocateIdentifier("ret") + g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs) + + // Go does not allow "naked" type assertions on nil values, so we use the two-value form here. + // The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T. + // Happily, this coincides with the semantics we want here. + retNames := make([]string, len(rets)) + for i, t := range rets { + retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i)) + g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t) + } + g.p("return " + strings.Join(retNames, ", ")) + } + + g.out() + g.p("}") + return nil +} + +func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error { + argNames := g.getArgNames(m) + + var argString string + if m.Variadic == nil { + argString = strings.Join(argNames, ", ") + } else { + argString = strings.Join(argNames[:len(argNames)-1], ", ") + } + if argString != "" { + argString += " interface{}" + } + + if m.Variadic != nil { + if argString != "" { + argString += ", " + } + argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1]) + } + + ia := newIdentifierAllocator(argNames) + idRecv := ia.allocateIdentifier("mr") + + g.p("// %v indicates an expected call of %v.", m.Name, m.Name) + g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString) + g.in() + g.p("%s.mock.ctrl.T.Helper()", idRecv) + + var callArgs string + if m.Variadic == nil { + if len(argNames) > 0 { + callArgs = ", " + strings.Join(argNames, ", ") + } + } else { + if len(argNames) == 1 { + // Easy: just use ... to push the arguments through. + callArgs = ", " + argNames[0] + "..." + } else { + // Hard: create a temporary slice. + idVarArgs := ia.allocateIdentifier("varargs") + g.p("%s := append([]interface{}{%s}, %s...)", + idVarArgs, + strings.Join(argNames[:len(argNames)-1], ", "), + argNames[len(argNames)-1]) + callArgs = ", " + idVarArgs + "..." 
+ } + } + g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs) + + g.out() + g.p("}") + return nil +} + +func (g *generator) getArgNames(m *model.Method) []string { + argNames := make([]string, len(m.In)) + for i, p := range m.In { + name := p.Name + if name == "" || name == "_" { + name = fmt.Sprintf("arg%d", i) + } + argNames[i] = name + } + if m.Variadic != nil { + name := m.Variadic.Name + if name == "" { + name = fmt.Sprintf("arg%d", len(m.In)) + } + argNames = append(argNames, name) + } + return argNames +} + +func (g *generator) getArgTypes(m *model.Method, pkgOverride string) []string { + argTypes := make([]string, len(m.In)) + for i, p := range m.In { + argTypes[i] = p.Type.String(g.packageMap, pkgOverride) + } + if m.Variadic != nil { + argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride)) + } + return argTypes +} + +type identifierAllocator map[string]struct{} + +func newIdentifierAllocator(taken []string) identifierAllocator { + a := make(identifierAllocator, len(taken)) + for _, s := range taken { + a[s] = struct{}{} + } + return a +} + +func (o identifierAllocator) allocateIdentifier(want string) string { + id := want + for i := 2; ; i++ { + if _, ok := o[id]; !ok { + o[id] = struct{}{} + return id + } + id = want + "_" + strconv.Itoa(i) + } +} + +// Output returns the generator's output, formatted in the standard Go style. +func (g *generator) Output() []byte { + src, err := toolsimports.Process(g.destination, g.buf.Bytes(), nil) + if err != nil { + log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String()) + } + return src +} + +// createPackageMap returns a map of import path to package name +// for specified importPaths. +func createPackageMap(importPaths []string) map[string]string { + var pkg struct { + Name string + ImportPath string + } + pkgMap := make(map[string]string) + b := bytes.NewBuffer(nil) + args := []string{"list", "-json"} + args = append(args, importPaths...) + cmd := exec.Command("go", args...) 
+ cmd.Stdout = b + cmd.Run() + dec := json.NewDecoder(b) + for dec.More() { + err := dec.Decode(&pkg) + if err != nil { + log.Printf("failed to decode 'go list' output: %v", err) + continue + } + pkgMap[pkg.ImportPath] = pkg.Name + } + return pkgMap +} + +func printVersion() { + if version != "" { + fmt.Printf("v%s\nCommit: %s\nDate: %s\n", version, commit, date) + } else { + printModuleVersion() + } +} + +// parseImportPackage get package import path via source file +// an alternative implementation is to use: +// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir} +// pkgs, err := packages.Load(cfg, "file="+source) +// However, it will call "go list" and slow down the performance +func parsePackageImport(srcDir string) (string, error) { + moduleMode := os.Getenv("GO111MODULE") + // trying to find the module + if moduleMode != "off" { + currentDir := srcDir + for { + dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod")) + if os.IsNotExist(err) { + if currentDir == filepath.Dir(currentDir) { + // at the root + break + } + currentDir = filepath.Dir(currentDir) + continue + } else if err != nil { + return "", err + } + modulePath := modfile.ModulePath(dat) + return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil + } + } + // fall back to GOPATH mode + goPaths := os.Getenv("GOPATH") + if goPaths == "" { + return "", fmt.Errorf("GOPATH is not set") + } + goPathList := strings.Split(goPaths, string(os.PathListSeparator)) + for _, goPath := range goPathList { + sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator) + if strings.HasPrefix(srcDir, sourceRoot) { + return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil + } + } + return "", errOutsideGoPath +} diff --git a/vendor/github.com/golang/mock/mockgen/model/model.go b/vendor/github.com/golang/mock/mockgen/model/model.go new file mode 100644 index 00000000000..2c6a62ceb26 --- /dev/null +++ b/vendor/github.com/golang/mock/mockgen/model/model.go @@ -0,0 +1,495 @@ +// Copyright 2012 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package model contains the data model necessary for generating mock implementations. +package model + +import ( + "encoding/gob" + "fmt" + "io" + "reflect" + "strings" +) + +// pkgPath is the importable path for package model +const pkgPath = "github.com/golang/mock/mockgen/model" + +// Package is a Go package. It may be a subset. +type Package struct { + Name string + PkgPath string + Interfaces []*Interface + DotImports []string +} + +// Print writes the package name and its exported interfaces. +func (pkg *Package) Print(w io.Writer) { + _, _ = fmt.Fprintf(w, "package %s\n", pkg.Name) + for _, intf := range pkg.Interfaces { + intf.Print(w) + } +} + +// Imports returns the imports needed by the Package as a set of import paths. 
+func (pkg *Package) Imports() map[string]bool { + im := make(map[string]bool) + for _, intf := range pkg.Interfaces { + intf.addImports(im) + } + return im +} + +// Interface is a Go interface. +type Interface struct { + Name string + Methods []*Method +} + +// Print writes the interface name and its methods. +func (intf *Interface) Print(w io.Writer) { + _, _ = fmt.Fprintf(w, "interface %s\n", intf.Name) + for _, m := range intf.Methods { + m.Print(w) + } +} + +func (intf *Interface) addImports(im map[string]bool) { + for _, m := range intf.Methods { + m.addImports(im) + } +} + +// AddMethod adds a new method, de-duplicating by method name. +func (intf *Interface) AddMethod(m *Method) { + for _, me := range intf.Methods { + if me.Name == m.Name { + return + } + } + intf.Methods = append(intf.Methods, m) +} + +// Method is a single method of an interface. +type Method struct { + Name string + In, Out []*Parameter + Variadic *Parameter // may be nil +} + +// Print writes the method name and its signature. +func (m *Method) Print(w io.Writer) { + _, _ = fmt.Fprintf(w, " - method %s\n", m.Name) + if len(m.In) > 0 { + _, _ = fmt.Fprintf(w, " in:\n") + for _, p := range m.In { + p.Print(w) + } + } + if m.Variadic != nil { + _, _ = fmt.Fprintf(w, " ...:\n") + m.Variadic.Print(w) + } + if len(m.Out) > 0 { + _, _ = fmt.Fprintf(w, " out:\n") + for _, p := range m.Out { + p.Print(w) + } + } +} + +func (m *Method) addImports(im map[string]bool) { + for _, p := range m.In { + p.Type.addImports(im) + } + if m.Variadic != nil { + m.Variadic.Type.addImports(im) + } + for _, p := range m.Out { + p.Type.addImports(im) + } +} + +// Parameter is an argument or return parameter of a method. +type Parameter struct { + Name string // may be empty + Type Type +} + +// Print writes a method parameter. +func (p *Parameter) Print(w io.Writer) { + n := p.Name + if n == "" { + n = `""` + } + _, _ = fmt.Fprintf(w, " - %v: %v\n", n, p.Type.String(nil, "")) +} + +// Type is a Go type. +type Type interface { + String(pm map[string]string, pkgOverride string) string + addImports(im map[string]bool) +} + +func init() { + gob.Register(&ArrayType{}) + gob.Register(&ChanType{}) + gob.Register(&FuncType{}) + gob.Register(&MapType{}) + gob.Register(&NamedType{}) + gob.Register(&PointerType{}) + + // Call gob.RegisterName to make sure it has the consistent name registered + // for both gob decoder and encoder. + // + // For a non-pointer type, gob.Register will try to get package full path by + // calling rt.PkgPath() for a name to register. If your project has vendor + // directory, it is possible that PkgPath will get a path like this: + // ../../../vendor/github.com/golang/mock/mockgen/model + gob.RegisterName(pkgPath+".PredeclaredType", PredeclaredType("")) +} + +// ArrayType is an array or slice type. +type ArrayType struct { + Len int // -1 for slices, >= 0 for arrays + Type Type +} + +func (at *ArrayType) String(pm map[string]string, pkgOverride string) string { + s := "[]" + if at.Len > -1 { + s = fmt.Sprintf("[%d]", at.Len) + } + return s + at.Type.String(pm, pkgOverride) +} + +func (at *ArrayType) addImports(im map[string]bool) { at.Type.addImports(im) } + +// ChanType is a channel type. 
+type ChanType struct { + Dir ChanDir // 0, 1 or 2 + Type Type +} + +func (ct *ChanType) String(pm map[string]string, pkgOverride string) string { + s := ct.Type.String(pm, pkgOverride) + if ct.Dir == RecvDir { + return "<-chan " + s + } + if ct.Dir == SendDir { + return "chan<- " + s + } + return "chan " + s +} + +func (ct *ChanType) addImports(im map[string]bool) { ct.Type.addImports(im) } + +// ChanDir is a channel direction. +type ChanDir int + +// Constants for channel directions. +const ( + RecvDir ChanDir = 1 + SendDir ChanDir = 2 +) + +// FuncType is a function type. +type FuncType struct { + In, Out []*Parameter + Variadic *Parameter // may be nil +} + +func (ft *FuncType) String(pm map[string]string, pkgOverride string) string { + args := make([]string, len(ft.In)) + for i, p := range ft.In { + args[i] = p.Type.String(pm, pkgOverride) + } + if ft.Variadic != nil { + args = append(args, "..."+ft.Variadic.Type.String(pm, pkgOverride)) + } + rets := make([]string, len(ft.Out)) + for i, p := range ft.Out { + rets[i] = p.Type.String(pm, pkgOverride) + } + retString := strings.Join(rets, ", ") + if nOut := len(ft.Out); nOut == 1 { + retString = " " + retString + } else if nOut > 1 { + retString = " (" + retString + ")" + } + return "func(" + strings.Join(args, ", ") + ")" + retString +} + +func (ft *FuncType) addImports(im map[string]bool) { + for _, p := range ft.In { + p.Type.addImports(im) + } + if ft.Variadic != nil { + ft.Variadic.Type.addImports(im) + } + for _, p := range ft.Out { + p.Type.addImports(im) + } +} + +// MapType is a map type. +type MapType struct { + Key, Value Type +} + +func (mt *MapType) String(pm map[string]string, pkgOverride string) string { + return "map[" + mt.Key.String(pm, pkgOverride) + "]" + mt.Value.String(pm, pkgOverride) +} + +func (mt *MapType) addImports(im map[string]bool) { + mt.Key.addImports(im) + mt.Value.addImports(im) +} + +// NamedType is an exported type in a package. +type NamedType struct { + Package string // may be empty + Type string +} + +func (nt *NamedType) String(pm map[string]string, pkgOverride string) string { + if pkgOverride == nt.Package { + return nt.Type + } + prefix := pm[nt.Package] + if prefix != "" { + return prefix + "." + nt.Type + } + + return nt.Type +} + +func (nt *NamedType) addImports(im map[string]bool) { + if nt.Package != "" { + im[nt.Package] = true + } +} + +// PointerType is a pointer to another type. +type PointerType struct { + Type Type +} + +func (pt *PointerType) String(pm map[string]string, pkgOverride string) string { + return "*" + pt.Type.String(pm, pkgOverride) +} +func (pt *PointerType) addImports(im map[string]bool) { pt.Type.addImports(im) } + +// PredeclaredType is a predeclared type such as "int". +type PredeclaredType string + +func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) } +func (pt PredeclaredType) addImports(map[string]bool) {} + +// The following code is intended to be called by the program generated by ../reflect.go. + +// InterfaceFromInterfaceType returns a pointer to an interface for the +// given reflection interface type. +func InterfaceFromInterfaceType(it reflect.Type) (*Interface, error) { + if it.Kind() != reflect.Interface { + return nil, fmt.Errorf("%v is not an interface", it) + } + intf := &Interface{} + + for i := 0; i < it.NumMethod(); i++ { + mt := it.Method(i) + // TODO: need to skip unexported methods? or just raise an error? 
+ m := &Method{ + Name: mt.Name, + } + + var err error + m.In, m.Variadic, m.Out, err = funcArgsFromType(mt.Type) + if err != nil { + return nil, err + } + + intf.AddMethod(m) + } + + return intf, nil +} + +// t's Kind must be a reflect.Func. +func funcArgsFromType(t reflect.Type) (in []*Parameter, variadic *Parameter, out []*Parameter, err error) { + nin := t.NumIn() + if t.IsVariadic() { + nin-- + } + var p *Parameter + for i := 0; i < nin; i++ { + p, err = parameterFromType(t.In(i)) + if err != nil { + return + } + in = append(in, p) + } + if t.IsVariadic() { + p, err = parameterFromType(t.In(nin).Elem()) + if err != nil { + return + } + variadic = p + } + for i := 0; i < t.NumOut(); i++ { + p, err = parameterFromType(t.Out(i)) + if err != nil { + return + } + out = append(out, p) + } + return +} + +func parameterFromType(t reflect.Type) (*Parameter, error) { + tt, err := typeFromType(t) + if err != nil { + return nil, err + } + return &Parameter{Type: tt}, nil +} + +var errorType = reflect.TypeOf((*error)(nil)).Elem() + +var byteType = reflect.TypeOf(byte(0)) + +func typeFromType(t reflect.Type) (Type, error) { + // Hack workaround for https://golang.org/issue/3853. + // This explicit check should not be necessary. + if t == byteType { + return PredeclaredType("byte"), nil + } + + if imp := t.PkgPath(); imp != "" { + return &NamedType{ + Package: impPath(imp), + Type: t.Name(), + }, nil + } + + // only unnamed or predeclared types after here + + // Lots of types have element types. Let's do the parsing and error checking for all of them. + var elemType Type + switch t.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: + var err error + elemType, err = typeFromType(t.Elem()) + if err != nil { + return nil, err + } + } + + switch t.Kind() { + case reflect.Array: + return &ArrayType{ + Len: t.Len(), + Type: elemType, + }, nil + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String: + return PredeclaredType(t.Kind().String()), nil + case reflect.Chan: + var dir ChanDir + switch t.ChanDir() { + case reflect.RecvDir: + dir = RecvDir + case reflect.SendDir: + dir = SendDir + } + return &ChanType{ + Dir: dir, + Type: elemType, + }, nil + case reflect.Func: + in, variadic, out, err := funcArgsFromType(t) + if err != nil { + return nil, err + } + return &FuncType{ + In: in, + Out: out, + Variadic: variadic, + }, nil + case reflect.Interface: + // Two special interfaces. + if t.NumMethod() == 0 { + return PredeclaredType("interface{}"), nil + } + if t == errorType { + return PredeclaredType("error"), nil + } + case reflect.Map: + kt, err := typeFromType(t.Key()) + if err != nil { + return nil, err + } + return &MapType{ + Key: kt, + Value: elemType, + }, nil + case reflect.Ptr: + return &PointerType{ + Type: elemType, + }, nil + case reflect.Slice: + return &ArrayType{ + Len: -1, + Type: elemType, + }, nil + case reflect.Struct: + if t.NumField() == 0 { + return PredeclaredType("struct{}"), nil + } + } + + // TODO: Struct, UnsafePointer + return nil, fmt.Errorf("can't yet turn %v (%v) into a model.Type", t, t.Kind()) +} + +// impPath sanitizes the package path returned by `PkgPath` method of a reflect Type so that +// it is importable. PkgPath might return a path that includes "vendor". 
These paths do not +// compile, so we need to remove everything up to and including "/vendor/". +// See https://github.com/golang/go/issues/12019. +func impPath(imp string) string { + if strings.HasPrefix(imp, "vendor/") { + imp = "/" + imp + } + if i := strings.LastIndex(imp, "/vendor/"); i != -1 { + imp = imp[i+len("/vendor/"):] + } + return imp +} + +// ErrorInterface represent built-in error interface. +var ErrorInterface = Interface{ + Name: "error", + Methods: []*Method{ + { + Name: "Error", + Out: []*Parameter{ + { + Name: "", + Type: PredeclaredType("string"), + }, + }, + }, + }, +} diff --git a/vendor/github.com/golang/mock/mockgen/parse.go b/vendor/github.com/golang/mock/mockgen/parse.go new file mode 100644 index 00000000000..bf6902cd5b3 --- /dev/null +++ b/vendor/github.com/golang/mock/mockgen/parse.go @@ -0,0 +1,644 @@ +// Copyright 2012 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +// This file contains the model construction by parsing source files. + +import ( + "errors" + "flag" + "fmt" + "go/ast" + "go/build" + "go/importer" + "go/parser" + "go/token" + "go/types" + "io/ioutil" + "log" + "path" + "path/filepath" + "strconv" + "strings" + + "github.com/golang/mock/mockgen/model" +) + +var ( + imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.") + auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") +) + +// sourceMode generates mocks via source file. +func sourceMode(source string) (*model.Package, error) { + srcDir, err := filepath.Abs(filepath.Dir(source)) + if err != nil { + return nil, fmt.Errorf("failed getting source directory: %v", err) + } + + packageImport, err := parsePackageImport(srcDir) + if err != nil { + return nil, err + } + + fs := token.NewFileSet() + file, err := parser.ParseFile(fs, source, nil, 0) + if err != nil { + return nil, fmt.Errorf("failed parsing source file %v: %v", source, err) + } + + p := &fileParser{ + fileSet: fs, + imports: make(map[string]importedPackage), + importedInterfaces: make(map[string]map[string]*ast.InterfaceType), + auxInterfaces: make(map[string]map[string]*ast.InterfaceType), + srcDir: srcDir, + } + + // Handle -imports. + dotImports := make(map[string]bool) + if *imports != "" { + for _, kv := range strings.Split(*imports, ",") { + eq := strings.Index(kv, "=") + k, v := kv[:eq], kv[eq+1:] + if k == "." { + dotImports[v] = true + } else { + p.imports[k] = importedPkg{path: v} + } + } + } + + // Handle -aux_files. 
+ if err := p.parseAuxFiles(*auxFiles); err != nil { + return nil, err + } + p.addAuxInterfacesFromFile(packageImport, file) // this file + + pkg, err := p.parseFile(packageImport, file) + if err != nil { + return nil, err + } + for pkgPath := range dotImports { + pkg.DotImports = append(pkg.DotImports, pkgPath) + } + return pkg, nil +} + +type importedPackage interface { + Path() string + Parser() *fileParser +} + +type importedPkg struct { + path string + parser *fileParser +} + +func (i importedPkg) Path() string { return i.path } +func (i importedPkg) Parser() *fileParser { return i.parser } + +// duplicateImport is a bit of a misnomer. Currently the parser can't +// handle cases of multi-file packages importing different packages +// under the same name. Often these imports would not be problematic, +// so this type lets us defer raising an error unless the package name +// is actually used. +type duplicateImport struct { + name string + duplicates []string +} + +func (d duplicateImport) Error() string { + return fmt.Sprintf("%q is ambiguous because of duplicate imports: %v", d.name, d.duplicates) +} + +func (d duplicateImport) Path() string { log.Fatal(d.Error()); return "" } +func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil } + +type fileParser struct { + fileSet *token.FileSet + imports map[string]importedPackage // package name => imported package + importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface + + auxFiles []*ast.File + auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface + + srcDir string +} + +func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error { + ps := p.fileSet.Position(pos) + format = "%s:%d:%d: " + format + args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...) + return fmt.Errorf(format, args...) +} + +func (p *fileParser) parseAuxFiles(auxFiles string) error { + auxFiles = strings.TrimSpace(auxFiles) + if auxFiles == "" { + return nil + } + for _, kv := range strings.Split(auxFiles, ",") { + parts := strings.SplitN(kv, "=", 2) + if len(parts) != 2 { + return fmt.Errorf("bad aux file spec: %v", kv) + } + pkg, fpath := parts[0], parts[1] + + file, err := parser.ParseFile(p.fileSet, fpath, nil, 0) + if err != nil { + return err + } + p.auxFiles = append(p.auxFiles, file) + p.addAuxInterfacesFromFile(pkg, file) + } + return nil +} + +func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) { + if _, ok := p.auxInterfaces[pkg]; !ok { + p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType) + } + for ni := range iterInterfaces(file) { + p.auxInterfaces[pkg][ni.name.Name] = ni.it + } +} + +// parseFile loads all file imports and auxiliary files import into the +// fileParser, parses all file interfaces and returns package model. +func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) { + allImports, dotImports := importsOfFile(file) + // Don't stomp imports provided by -imports. Those should take precedence. + for pkg, pkgI := range allImports { + if _, ok := p.imports[pkg]; !ok { + p.imports[pkg] = pkgI + } + } + // Add imports from auxiliary files, which might be needed for embedded interfaces. + // Don't stomp any other imports. 
+ for _, f := range p.auxFiles { + auxImports, _ := importsOfFile(f) + for pkg, pkgI := range auxImports { + if _, ok := p.imports[pkg]; !ok { + p.imports[pkg] = pkgI + } + } + } + + var is []*model.Interface + for ni := range iterInterfaces(file) { + i, err := p.parseInterface(ni.name.String(), importPath, ni.it) + if err != nil { + return nil, err + } + is = append(is, i) + } + return &model.Package{ + Name: file.Name.String(), + PkgPath: importPath, + Interfaces: is, + DotImports: dotImports, + }, nil +} + +// parsePackage loads package specified by path, parses it and returns +// a new fileParser with the parsed imports and interfaces. +func (p *fileParser) parsePackage(path string) (*fileParser, error) { + newP := &fileParser{ + fileSet: token.NewFileSet(), + imports: make(map[string]importedPackage), + importedInterfaces: make(map[string]map[string]*ast.InterfaceType), + auxInterfaces: make(map[string]map[string]*ast.InterfaceType), + srcDir: p.srcDir, + } + + var pkgs map[string]*ast.Package + if imp, err := build.Import(path, newP.srcDir, build.FindOnly); err != nil { + return nil, err + } else if pkgs, err = parser.ParseDir(newP.fileSet, imp.Dir, nil, 0); err != nil { + return nil, err + } + + for _, pkg := range pkgs { + file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates) + if _, ok := newP.importedInterfaces[path]; !ok { + newP.importedInterfaces[path] = make(map[string]*ast.InterfaceType) + } + for ni := range iterInterfaces(file) { + newP.importedInterfaces[path][ni.name.Name] = ni.it + } + imports, _ := importsOfFile(file) + for pkgName, pkgI := range imports { + newP.imports[pkgName] = pkgI + } + } + return newP, nil +} + +func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) { + iface := &model.Interface{Name: name} + for _, field := range it.Methods.List { + switch v := field.Type.(type) { + case *ast.FuncType: + if nn := len(field.Names); nn != 1 { + return nil, fmt.Errorf("expected one name for interface %v, got %d", iface.Name, nn) + } + m := &model.Method{ + Name: field.Names[0].String(), + } + var err error + m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v) + if err != nil { + return nil, err + } + iface.AddMethod(m) + case *ast.Ident: + // Embedded interface in this package. + embeddedIfaceType := p.auxInterfaces[pkg][v.String()] + if embeddedIfaceType == nil { + embeddedIfaceType = p.importedInterfaces[pkg][v.String()] + } + + var embeddedIface *model.Interface + if embeddedIfaceType != nil { + var err error + embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType) + if err != nil { + return nil, err + } + } else { + // This is built-in error interface. + if v.String() == model.ErrorInterface.Name { + embeddedIface = &model.ErrorInterface + } else { + return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String()) + } + } + // Copy the methods. + for _, m := range embeddedIface.Methods { + iface.AddMethod(m) + } + case *ast.SelectorExpr: + // Embedded interface in another package. 
+ filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() + embeddedPkg, ok := p.imports[filePkg] + if !ok { + return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg) + } + + var embeddedIface *model.Interface + var err error + embeddedIfaceType := p.auxInterfaces[filePkg][sel] + if embeddedIfaceType != nil { + embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType) + if err != nil { + return nil, err + } + } else { + path := embeddedPkg.Path() + parser := embeddedPkg.Parser() + if parser == nil { + ip, err := p.parsePackage(path) + if err != nil { + return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err) + } + parser = ip + p.imports[filePkg] = importedPkg{ + path: embeddedPkg.Path(), + parser: parser, + } + } + if embeddedIfaceType = parser.importedInterfaces[path][sel]; embeddedIfaceType == nil { + return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) + } + embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType) + if err != nil { + return nil, err + } + } + // Copy the methods. + // TODO: apply shadowing rules. + for _, m := range embeddedIface.Methods { + iface.AddMethod(m) + } + default: + return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) + } + } + return iface, nil +} + +func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) { + if f.Params != nil { + regParams := f.Params.List + if isVariadic(f) { + n := len(regParams) + varParams := regParams[n-1:] + regParams = regParams[:n-1] + vp, err := p.parseFieldList(pkg, varParams) + if err != nil { + return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err) + } + variadic = vp[0] + } + inParam, err = p.parseFieldList(pkg, regParams) + if err != nil { + return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err) + } + } + if f.Results != nil { + outParam, err = p.parseFieldList(pkg, f.Results.List) + if err != nil { + return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err) + } + } + return +} + +func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) { + nf := 0 + for _, f := range fields { + nn := len(f.Names) + if nn == 0 { + nn = 1 // anonymous parameter + } + nf += nn + } + if nf == 0 { + return nil, nil + } + ps := make([]*model.Parameter, nf) + i := 0 // destination index + for _, f := range fields { + t, err := p.parseType(pkg, f.Type) + if err != nil { + return nil, err + } + + if len(f.Names) == 0 { + // anonymous arg + ps[i] = &model.Parameter{Type: t} + i++ + continue + } + for _, name := range f.Names { + ps[i] = &model.Parameter{Name: name.Name, Type: t} + i++ + } + } + return ps, nil +} + +func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { + switch v := typ.(type) { + case *ast.ArrayType: + ln := -1 + if v.Len != nil { + var value string + switch val := v.Len.(type) { + case (*ast.BasicLit): + value = val.Value + case (*ast.Ident): + // when the length is a const defined locally + value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value + case (*ast.SelectorExpr): + // when the length is a const defined in an external package + usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X)) + if err != nil { + return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err) + } + ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, 
val.Sel.Name) + if err != nil { + return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err) + } + value = ev.Value.String() + } + + x, err := strconv.Atoi(value) + if err != nil { + return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err) + } + ln = x + } + t, err := p.parseType(pkg, v.Elt) + if err != nil { + return nil, err + } + return &model.ArrayType{Len: ln, Type: t}, nil + case *ast.ChanType: + t, err := p.parseType(pkg, v.Value) + if err != nil { + return nil, err + } + var dir model.ChanDir + if v.Dir == ast.SEND { + dir = model.SendDir + } + if v.Dir == ast.RECV { + dir = model.RecvDir + } + return &model.ChanType{Dir: dir, Type: t}, nil + case *ast.Ellipsis: + // assume we're parsing a variadic argument + return p.parseType(pkg, v.Elt) + case *ast.FuncType: + in, variadic, out, err := p.parseFunc(pkg, v) + if err != nil { + return nil, err + } + return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil + case *ast.Ident: + if v.IsExported() { + // `pkg` may be an aliased imported pkg + // if so, patch the import w/ the fully qualified import + maybeImportedPkg, ok := p.imports[pkg] + if ok { + pkg = maybeImportedPkg.Path() + } + // assume type in this package + return &model.NamedType{Package: pkg, Type: v.Name}, nil + } + + // assume predeclared type + return model.PredeclaredType(v.Name), nil + case *ast.InterfaceType: + if v.Methods != nil && len(v.Methods.List) > 0 { + return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types") + } + return model.PredeclaredType("interface{}"), nil + case *ast.MapType: + key, err := p.parseType(pkg, v.Key) + if err != nil { + return nil, err + } + value, err := p.parseType(pkg, v.Value) + if err != nil { + return nil, err + } + return &model.MapType{Key: key, Value: value}, nil + case *ast.SelectorExpr: + pkgName := v.X.(*ast.Ident).String() + pkg, ok := p.imports[pkgName] + if !ok { + return nil, p.errorf(v.Pos(), "unknown package %q", pkgName) + } + return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil + case *ast.StarExpr: + t, err := p.parseType(pkg, v.X) + if err != nil { + return nil, err + } + return &model.PointerType{Type: t}, nil + case *ast.StructType: + if v.Fields != nil && len(v.Fields.List) > 0 { + return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types") + } + return model.PredeclaredType("struct{}"), nil + case *ast.ParenExpr: + return p.parseType(pkg, v.X) + } + + return nil, fmt.Errorf("don't know how to parse type %T", typ) +} + +// importsOfFile returns a map of package name to import path +// of the imports in file. +func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) { + var importPaths []string + for _, is := range file.Imports { + if is.Name != nil { + continue + } + importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes + importPaths = append(importPaths, importPath) + } + packagesName := createPackageMap(importPaths) + normalImports = make(map[string]importedPackage) + dotImports = make([]string, 0) + for _, is := range file.Imports { + var pkgName string + importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes + + if is.Name != nil { + // Named imports are always certain. + if is.Name.Name == "_" { + continue + } + pkgName = is.Name.Name + } else { + pkg, ok := packagesName[importPath] + if !ok { + // Fallback to import path suffix. Note that this is uncertain. 
+ _, last := path.Split(importPath) + // If the last path component has dots, the first dot-delimited + // field is used as the name. + pkgName = strings.SplitN(last, ".", 2)[0] + } else { + pkgName = pkg + } + } + + if pkgName == "." { + dotImports = append(dotImports, importPath) + } else { + if pkg, ok := normalImports[pkgName]; ok { + switch p := pkg.(type) { + case duplicateImport: + normalImports[pkgName] = duplicateImport{ + name: p.name, + duplicates: append([]string{importPath}, p.duplicates...), + } + case importedPkg: + normalImports[pkgName] = duplicateImport{ + name: pkgName, + duplicates: []string{p.path, importPath}, + } + } + } else { + normalImports[pkgName] = importedPkg{path: importPath} + } + } + } + return +} + +type namedInterface struct { + name *ast.Ident + it *ast.InterfaceType +} + +// Create an iterator over all interfaces in file. +func iterInterfaces(file *ast.File) <-chan namedInterface { + ch := make(chan namedInterface) + go func() { + for _, decl := range file.Decls { + gd, ok := decl.(*ast.GenDecl) + if !ok || gd.Tok != token.TYPE { + continue + } + for _, spec := range gd.Specs { + ts, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + it, ok := ts.Type.(*ast.InterfaceType) + if !ok { + continue + } + + ch <- namedInterface{ts.Name, it} + } + } + close(ch) + }() + return ch +} + +// isVariadic returns whether the function is variadic. +func isVariadic(f *ast.FuncType) bool { + nargs := len(f.Params.List) + if nargs == 0 { + return false + } + _, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis) + return ok +} + +// packageNameOfDir get package import path via dir +func packageNameOfDir(srcDir string) (string, error) { + files, err := ioutil.ReadDir(srcDir) + if err != nil { + log.Fatal(err) + } + + var goFilePath string + for _, file := range files { + if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") { + goFilePath = file.Name() + break + } + } + if goFilePath == "" { + return "", fmt.Errorf("go source file not found %s", srcDir) + } + + packageImport, err := parsePackageImport(srcDir) + if err != nil { + return "", err + } + return packageImport, nil +} + +var errOutsideGoPath = errors.New("source directory is outside GOPATH") diff --git a/vendor/github.com/golang/mock/mockgen/reflect.go b/vendor/github.com/golang/mock/mockgen/reflect.go new file mode 100644 index 00000000000..e24efce0ba6 --- /dev/null +++ b/vendor/github.com/golang/mock/mockgen/reflect.go @@ -0,0 +1,256 @@ +// Copyright 2012 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +// This file contains the model construction by reflection. 
+ +import ( + "bytes" + "encoding/gob" + "flag" + "fmt" + "go/build" + "io" + "io/ioutil" + "log" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "text/template" + + "github.com/golang/mock/mockgen/model" +) + +var ( + progOnly = flag.Bool("prog_only", false, "(reflect mode) Only generate the reflection program; write it to stdout and exit.") + execOnly = flag.String("exec_only", "", "(reflect mode) If set, execute this reflection program.") + buildFlags = flag.String("build_flags", "", "(reflect mode) Additional flags for go build.") +) + +// reflectMode generates mocks via reflection on an interface. +func reflectMode(importPath string, symbols []string) (*model.Package, error) { + if *execOnly != "" { + return run(*execOnly) + } + + program, err := writeProgram(importPath, symbols) + if err != nil { + return nil, err + } + + if *progOnly { + if _, err := os.Stdout.Write(program); err != nil { + return nil, err + } + os.Exit(0) + } + + wd, _ := os.Getwd() + + // Try to run the reflection program in the current working directory. + if p, err := runInDir(program, wd); err == nil { + return p, nil + } + + // Try to run the program in the same directory as the input package. + if p, err := build.Import(importPath, wd, build.FindOnly); err == nil { + dir := p.Dir + if p, err := runInDir(program, dir); err == nil { + return p, nil + } + } + + // Try to run it in a standard temp directory. + return runInDir(program, "") +} + +func writeProgram(importPath string, symbols []string) ([]byte, error) { + var program bytes.Buffer + data := reflectData{ + ImportPath: importPath, + Symbols: symbols, + } + if err := reflectProgram.Execute(&program, &data); err != nil { + return nil, err + } + return program.Bytes(), nil +} + +// run the given program and parse the output as a model.Package. +func run(program string) (*model.Package, error) { + f, err := ioutil.TempFile("", "") + if err != nil { + return nil, err + } + + filename := f.Name() + defer os.Remove(filename) + if err := f.Close(); err != nil { + return nil, err + } + + // Run the program. + cmd := exec.Command(program, "-output", filename) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return nil, err + } + + f, err = os.Open(filename) + if err != nil { + return nil, err + } + + // Process output. + var pkg model.Package + if err := gob.NewDecoder(f).Decode(&pkg); err != nil { + return nil, err + } + + if err := f.Close(); err != nil { + return nil, err + } + + return &pkg, nil +} + +// runInDir writes the given program into the given dir, runs it there, and +// parses the output as a model.Package. +func runInDir(program []byte, dir string) (*model.Package, error) { + // We use TempDir instead of TempFile so we can control the filename. + tmpDir, err := ioutil.TempDir(dir, "gomock_reflect_") + if err != nil { + return nil, err + } + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + log.Printf("failed to remove temp directory: %s", err) + } + }() + const progSource = "prog.go" + var progBinary = "prog.bin" + if runtime.GOOS == "windows" { + // Windows won't execute a program unless it has a ".exe" suffix. + progBinary += ".exe" + } + + if err := ioutil.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil { + return nil, err + } + + cmdArgs := []string{} + cmdArgs = append(cmdArgs, "build") + if *buildFlags != "" { + cmdArgs = append(cmdArgs, strings.Split(*buildFlags, " ")...) 
+ } + cmdArgs = append(cmdArgs, "-o", progBinary, progSource) + + // Build the program. + buf := bytes.NewBuffer(nil) + cmd := exec.Command("go", cmdArgs...) + cmd.Dir = tmpDir + cmd.Stdout = os.Stdout + cmd.Stderr = io.MultiWriter(os.Stderr, buf) + if err := cmd.Run(); err != nil { + sErr := buf.String() + if strings.Contains(sErr, `cannot find package "."`) && + strings.Contains(sErr, "github.com/golang/mock/mockgen/model") { + fmt.Fprint(os.Stderr, "Please reference the steps in the README to fix this error:\n\thttps://github.com/golang/mock#reflect-vendoring-error.") + return nil, err + } + return nil, err + } + + return run(filepath.Join(tmpDir, progBinary)) +} + +type reflectData struct { + ImportPath string + Symbols []string +} + +// This program reflects on an interface value, and prints the +// gob encoding of a model.Package to standard output. +// JSON doesn't work because of the model.Type interface. +var reflectProgram = template.Must(template.New("program").Parse(` +package main + +import ( + "encoding/gob" + "flag" + "fmt" + "os" + "path" + "reflect" + + "github.com/golang/mock/mockgen/model" + + pkg_ {{printf "%q" .ImportPath}} +) + +var output = flag.String("output", "", "The output file name, or empty to use stdout.") + +func main() { + flag.Parse() + + its := []struct{ + sym string + typ reflect.Type + }{ + {{range .Symbols}} + { {{printf "%q" .}}, reflect.TypeOf((*pkg_.{{.}})(nil)).Elem()}, + {{end}} + } + pkg := &model.Package{ + // NOTE: This behaves contrary to documented behaviour if the + // package name is not the final component of the import path. + // The reflect package doesn't expose the package name, though. + Name: path.Base({{printf "%q" .ImportPath}}), + } + + for _, it := range its { + intf, err := model.InterfaceFromInterfaceType(it.typ) + if err != nil { + fmt.Fprintf(os.Stderr, "Reflection: %v\n", err) + os.Exit(1) + } + intf.Name = it.sym + pkg.Interfaces = append(pkg.Interfaces, intf) + } + + outfile := os.Stdout + if len(*output) != 0 { + var err error + outfile, err = os.Create(*output) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to open output file %q", *output) + } + defer func() { + if err := outfile.Close(); err != nil { + fmt.Fprintf(os.Stderr, "failed to close output file %q", *output) + os.Exit(1) + } + }() + } + + if err := gob.NewEncoder(outfile).Encode(pkg); err != nil { + fmt.Fprintf(os.Stderr, "gob encode: %v\n", err) + os.Exit(1) + } +} +`)) diff --git a/vendor/github.com/golang/mock/mockgen/version.1.11.go b/vendor/github.com/golang/mock/mockgen/version.1.11.go new file mode 100644 index 00000000000..e6b25db2384 --- /dev/null +++ b/vendor/github.com/golang/mock/mockgen/version.1.11.go @@ -0,0 +1,26 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// +build !go1.12 + +package main + +import ( + "log" +) + +func printModuleVersion() { + log.Printf("No version information is available for Mockgen compiled with " + + "version 1.11") +} diff --git a/vendor/github.com/golang/mock/mockgen/version.1.12.go b/vendor/github.com/golang/mock/mockgen/version.1.12.go new file mode 100644 index 00000000000..ad121ae63c5 --- /dev/null +++ b/vendor/github.com/golang/mock/mockgen/version.1.12.go @@ -0,0 +1,35 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// +build go1.12 + +package main + +import ( + "fmt" + "log" + "runtime/debug" +) + +func printModuleVersion() { + if bi, exists := debug.ReadBuildInfo(); exists { + fmt.Println(bi.Main.Version) + } else { + log.Printf("No version information found. Make sure to use " + + "GO111MODULE=on when running 'go get' in order to use specific " + + "version of the binary.") + } + +} diff --git a/vendor/github.com/google/pprof/AUTHORS b/vendor/github.com/google/pprof/AUTHORS new file mode 100644 index 00000000000..fd736cb1cfb --- /dev/null +++ b/vendor/github.com/google/pprof/AUTHORS @@ -0,0 +1,7 @@ +# This is the official list of pprof authors for copyright purposes. +# This file is distinct from the CONTRIBUTORS files. +# See the latter for an explanation. +# Names should be added to this file as: +# Name or Organization +# The email address is not required for organizations. +Google Inc. \ No newline at end of file diff --git a/vendor/github.com/google/pprof/CONTRIBUTORS b/vendor/github.com/google/pprof/CONTRIBUTORS new file mode 100644 index 00000000000..8c8c37d2c8f --- /dev/null +++ b/vendor/github.com/google/pprof/CONTRIBUTORS @@ -0,0 +1,16 @@ +# People who have agreed to one of the CLAs and can contribute patches. +# The AUTHORS file lists the copyright holders; this file +# lists people. For example, Google employees are listed here +# but not in AUTHORS, because Google holds the copyright. +# +# https://developers.google.com/open-source/cla/individual +# https://developers.google.com/open-source/cla/corporate +# +# Names should be added to this file as: +# Name +Raul Silvera +Tipp Moseley +Hyoun Kyu Cho +Martin Spier +Taco de Wolff +Andrew Hunter diff --git a/vendor/github.com/google/pprof/LICENSE b/vendor/github.com/google/pprof/LICENSE new file mode 100644 index 00000000000..d6456956733 --- /dev/null +++ b/vendor/github.com/google/pprof/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/vendor/github.com/google/pprof/profile/encode.go b/vendor/github.com/google/pprof/profile/encode.go new file mode 100644 index 00000000000..ab7f03ae267 --- /dev/null +++ b/vendor/github.com/google/pprof/profile/encode.go @@ -0,0 +1,567 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package profile + +import ( + "errors" + "sort" +) + +func (p *Profile) decoder() []decoder { + return profileDecoder +} + +// preEncode populates the unexported fields to be used by encode +// (with suffix X) from the corresponding exported fields. The +// exported fields are cleared up to facilitate testing. +func (p *Profile) preEncode() { + strings := make(map[string]int) + addString(strings, "") + + for _, st := range p.SampleType { + st.typeX = addString(strings, st.Type) + st.unitX = addString(strings, st.Unit) + } + + for _, s := range p.Sample { + s.labelX = nil + var keys []string + for k := range s.Label { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + vs := s.Label[k] + for _, v := range vs { + s.labelX = append(s.labelX, + label{ + keyX: addString(strings, k), + strX: addString(strings, v), + }, + ) + } + } + var numKeys []string + for k := range s.NumLabel { + numKeys = append(numKeys, k) + } + sort.Strings(numKeys) + for _, k := range numKeys { + keyX := addString(strings, k) + vs := s.NumLabel[k] + units := s.NumUnit[k] + for i, v := range vs { + var unitX int64 + if len(units) != 0 { + unitX = addString(strings, units[i]) + } + s.labelX = append(s.labelX, + label{ + keyX: keyX, + numX: v, + unitX: unitX, + }, + ) + } + } + s.locationIDX = make([]uint64, len(s.Location)) + for i, loc := range s.Location { + s.locationIDX[i] = loc.ID + } + } + + for _, m := range p.Mapping { + m.fileX = addString(strings, m.File) + m.buildIDX = addString(strings, m.BuildID) + } + + for _, l := range p.Location { + for i, ln := range l.Line { + if ln.Function != nil { + l.Line[i].functionIDX = ln.Function.ID + } else { + l.Line[i].functionIDX = 0 + } + } + if l.Mapping != nil { + l.mappingIDX = l.Mapping.ID + } else { + l.mappingIDX = 0 + } + } + for _, f := range p.Function { + f.nameX = addString(strings, f.Name) + f.systemNameX = addString(strings, f.SystemName) + f.filenameX = addString(strings, f.Filename) + } + + p.dropFramesX = addString(strings, p.DropFrames) + p.keepFramesX = addString(strings, p.KeepFrames) + + if pt := p.PeriodType; pt != nil { + pt.typeX = addString(strings, pt.Type) + pt.unitX = addString(strings, pt.Unit) + } + + p.commentX = nil + for _, c := range p.Comments { + p.commentX = append(p.commentX, addString(strings, c)) + } + + p.defaultSampleTypeX = addString(strings, p.DefaultSampleType) + + p.stringTable = make([]string, len(strings)) + for s, i := range strings { + p.stringTable[i] = s + } +} + +func (p *Profile) encode(b *buffer) { + for _, x := range p.SampleType { + encodeMessage(b, 1, x) + } + for _, x := range p.Sample { + encodeMessage(b, 2, x) + } + for _, x := 
range p.Mapping { + encodeMessage(b, 3, x) + } + for _, x := range p.Location { + encodeMessage(b, 4, x) + } + for _, x := range p.Function { + encodeMessage(b, 5, x) + } + encodeStrings(b, 6, p.stringTable) + encodeInt64Opt(b, 7, p.dropFramesX) + encodeInt64Opt(b, 8, p.keepFramesX) + encodeInt64Opt(b, 9, p.TimeNanos) + encodeInt64Opt(b, 10, p.DurationNanos) + if pt := p.PeriodType; pt != nil && (pt.typeX != 0 || pt.unitX != 0) { + encodeMessage(b, 11, p.PeriodType) + } + encodeInt64Opt(b, 12, p.Period) + encodeInt64s(b, 13, p.commentX) + encodeInt64(b, 14, p.defaultSampleTypeX) +} + +var profileDecoder = []decoder{ + nil, // 0 + // repeated ValueType sample_type = 1 + func(b *buffer, m message) error { + x := new(ValueType) + pp := m.(*Profile) + pp.SampleType = append(pp.SampleType, x) + return decodeMessage(b, x) + }, + // repeated Sample sample = 2 + func(b *buffer, m message) error { + x := new(Sample) + pp := m.(*Profile) + pp.Sample = append(pp.Sample, x) + return decodeMessage(b, x) + }, + // repeated Mapping mapping = 3 + func(b *buffer, m message) error { + x := new(Mapping) + pp := m.(*Profile) + pp.Mapping = append(pp.Mapping, x) + return decodeMessage(b, x) + }, + // repeated Location location = 4 + func(b *buffer, m message) error { + x := new(Location) + x.Line = make([]Line, 0, 8) // Pre-allocate Line buffer + pp := m.(*Profile) + pp.Location = append(pp.Location, x) + err := decodeMessage(b, x) + var tmp []Line + x.Line = append(tmp, x.Line...) // Shrink to allocated size + return err + }, + // repeated Function function = 5 + func(b *buffer, m message) error { + x := new(Function) + pp := m.(*Profile) + pp.Function = append(pp.Function, x) + return decodeMessage(b, x) + }, + // repeated string string_table = 6 + func(b *buffer, m message) error { + err := decodeStrings(b, &m.(*Profile).stringTable) + if err != nil { + return err + } + if m.(*Profile).stringTable[0] != "" { + return errors.New("string_table[0] must be ''") + } + return nil + }, + // int64 drop_frames = 7 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).dropFramesX) }, + // int64 keep_frames = 8 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).keepFramesX) }, + // int64 time_nanos = 9 + func(b *buffer, m message) error { + if m.(*Profile).TimeNanos != 0 { + return errConcatProfile + } + return decodeInt64(b, &m.(*Profile).TimeNanos) + }, + // int64 duration_nanos = 10 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).DurationNanos) }, + // ValueType period_type = 11 + func(b *buffer, m message) error { + x := new(ValueType) + pp := m.(*Profile) + pp.PeriodType = x + return decodeMessage(b, x) + }, + // int64 period = 12 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).Period) }, + // repeated int64 comment = 13 + func(b *buffer, m message) error { return decodeInt64s(b, &m.(*Profile).commentX) }, + // int64 defaultSampleType = 14 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).defaultSampleTypeX) }, +} + +// postDecode takes the unexported fields populated by decode (with +// suffix X) and populates the corresponding exported fields. +// The unexported fields are cleared up to facilitate testing. 
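The preEncode/encode and decoder-table machinery above is what backs the package's public Parse and Write entry points. A brief usage sketch, assuming a hypothetical cpu.pprof input file and the vendored import path:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/google/pprof/profile"
)

func main() {
	in, err := os.Open("cpu.pprof") // hypothetical input profile
	if err != nil {
		log.Fatal(err)
	}
	defer in.Close()

	// For profile.proto inputs, Parse runs the decoder table and then
	// postDecode to rebuild exported fields from the string table and IDs.
	p, err := profile.Parse(in)
	if err != nil {
		log.Fatal(err)
	}
	for _, st := range p.SampleType {
		fmt.Printf("sample type: %s/%s\n", st.Type, st.Unit)
	}

	out, err := os.Create("copy.pprof")
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()
	// Write goes the other way: preEncode/encode, then gzip-compressed output.
	if err := p.Write(out); err != nil {
		log.Fatal(err)
	}
}
```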
+func (p *Profile) postDecode() error { + var err error + mappings := make(map[uint64]*Mapping, len(p.Mapping)) + mappingIds := make([]*Mapping, len(p.Mapping)+1) + for _, m := range p.Mapping { + m.File, err = getString(p.stringTable, &m.fileX, err) + m.BuildID, err = getString(p.stringTable, &m.buildIDX, err) + if m.ID < uint64(len(mappingIds)) { + mappingIds[m.ID] = m + } else { + mappings[m.ID] = m + } + } + + functions := make(map[uint64]*Function, len(p.Function)) + functionIds := make([]*Function, len(p.Function)+1) + for _, f := range p.Function { + f.Name, err = getString(p.stringTable, &f.nameX, err) + f.SystemName, err = getString(p.stringTable, &f.systemNameX, err) + f.Filename, err = getString(p.stringTable, &f.filenameX, err) + if f.ID < uint64(len(functionIds)) { + functionIds[f.ID] = f + } else { + functions[f.ID] = f + } + } + + locations := make(map[uint64]*Location, len(p.Location)) + locationIds := make([]*Location, len(p.Location)+1) + for _, l := range p.Location { + if id := l.mappingIDX; id < uint64(len(mappingIds)) { + l.Mapping = mappingIds[id] + } else { + l.Mapping = mappings[id] + } + l.mappingIDX = 0 + for i, ln := range l.Line { + if id := ln.functionIDX; id != 0 { + l.Line[i].functionIDX = 0 + if id < uint64(len(functionIds)) { + l.Line[i].Function = functionIds[id] + } else { + l.Line[i].Function = functions[id] + } + } + } + if l.ID < uint64(len(locationIds)) { + locationIds[l.ID] = l + } else { + locations[l.ID] = l + } + } + + for _, st := range p.SampleType { + st.Type, err = getString(p.stringTable, &st.typeX, err) + st.Unit, err = getString(p.stringTable, &st.unitX, err) + } + + for _, s := range p.Sample { + labels := make(map[string][]string, len(s.labelX)) + numLabels := make(map[string][]int64, len(s.labelX)) + numUnits := make(map[string][]string, len(s.labelX)) + for _, l := range s.labelX { + var key, value string + key, err = getString(p.stringTable, &l.keyX, err) + if l.strX != 0 { + value, err = getString(p.stringTable, &l.strX, err) + labels[key] = append(labels[key], value) + } else if l.numX != 0 || l.unitX != 0 { + numValues := numLabels[key] + units := numUnits[key] + if l.unitX != 0 { + var unit string + unit, err = getString(p.stringTable, &l.unitX, err) + units = padStringArray(units, len(numValues)) + numUnits[key] = append(units, unit) + } + numLabels[key] = append(numLabels[key], l.numX) + } + } + if len(labels) > 0 { + s.Label = labels + } + if len(numLabels) > 0 { + s.NumLabel = numLabels + for key, units := range numUnits { + if len(units) > 0 { + numUnits[key] = padStringArray(units, len(numLabels[key])) + } + } + s.NumUnit = numUnits + } + s.Location = make([]*Location, len(s.locationIDX)) + for i, lid := range s.locationIDX { + if lid < uint64(len(locationIds)) { + s.Location[i] = locationIds[lid] + } else { + s.Location[i] = locations[lid] + } + } + s.locationIDX = nil + } + + p.DropFrames, err = getString(p.stringTable, &p.dropFramesX, err) + p.KeepFrames, err = getString(p.stringTable, &p.keepFramesX, err) + + if pt := p.PeriodType; pt == nil { + p.PeriodType = &ValueType{} + } + + if pt := p.PeriodType; pt != nil { + pt.Type, err = getString(p.stringTable, &pt.typeX, err) + pt.Unit, err = getString(p.stringTable, &pt.unitX, err) + } + + for _, i := range p.commentX { + var c string + c, err = getString(p.stringTable, &i, err) + p.Comments = append(p.Comments, c) + } + + p.commentX = nil + p.DefaultSampleType, err = getString(p.stringTable, &p.defaultSampleTypeX, err) + p.stringTable = nil + return err +} + +// 
padStringArray pads arr with enough empty strings to make arr +// length l when arr's length is less than l. +func padStringArray(arr []string, l int) []string { + if l <= len(arr) { + return arr + } + return append(arr, make([]string, l-len(arr))...) +} + +func (p *ValueType) decoder() []decoder { + return valueTypeDecoder +} + +func (p *ValueType) encode(b *buffer) { + encodeInt64Opt(b, 1, p.typeX) + encodeInt64Opt(b, 2, p.unitX) +} + +var valueTypeDecoder = []decoder{ + nil, // 0 + // optional int64 type = 1 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*ValueType).typeX) }, + // optional int64 unit = 2 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*ValueType).unitX) }, +} + +func (p *Sample) decoder() []decoder { + return sampleDecoder +} + +func (p *Sample) encode(b *buffer) { + encodeUint64s(b, 1, p.locationIDX) + encodeInt64s(b, 2, p.Value) + for _, x := range p.labelX { + encodeMessage(b, 3, x) + } +} + +var sampleDecoder = []decoder{ + nil, // 0 + // repeated uint64 location = 1 + func(b *buffer, m message) error { return decodeUint64s(b, &m.(*Sample).locationIDX) }, + // repeated int64 value = 2 + func(b *buffer, m message) error { return decodeInt64s(b, &m.(*Sample).Value) }, + // repeated Label label = 3 + func(b *buffer, m message) error { + s := m.(*Sample) + n := len(s.labelX) + s.labelX = append(s.labelX, label{}) + return decodeMessage(b, &s.labelX[n]) + }, +} + +func (p label) decoder() []decoder { + return labelDecoder +} + +func (p label) encode(b *buffer) { + encodeInt64Opt(b, 1, p.keyX) + encodeInt64Opt(b, 2, p.strX) + encodeInt64Opt(b, 3, p.numX) + encodeInt64Opt(b, 4, p.unitX) +} + +var labelDecoder = []decoder{ + nil, // 0 + // optional int64 key = 1 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*label).keyX) }, + // optional int64 str = 2 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*label).strX) }, + // optional int64 num = 3 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*label).numX) }, + // optional int64 num = 4 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*label).unitX) }, +} + +func (p *Mapping) decoder() []decoder { + return mappingDecoder +} + +func (p *Mapping) encode(b *buffer) { + encodeUint64Opt(b, 1, p.ID) + encodeUint64Opt(b, 2, p.Start) + encodeUint64Opt(b, 3, p.Limit) + encodeUint64Opt(b, 4, p.Offset) + encodeInt64Opt(b, 5, p.fileX) + encodeInt64Opt(b, 6, p.buildIDX) + encodeBoolOpt(b, 7, p.HasFunctions) + encodeBoolOpt(b, 8, p.HasFilenames) + encodeBoolOpt(b, 9, p.HasLineNumbers) + encodeBoolOpt(b, 10, p.HasInlineFrames) +} + +var mappingDecoder = []decoder{ + nil, // 0 + func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).ID) }, // optional uint64 id = 1 + func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).Start) }, // optional uint64 memory_offset = 2 + func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).Limit) }, // optional uint64 memory_limit = 3 + func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).Offset) }, // optional uint64 file_offset = 4 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Mapping).fileX) }, // optional int64 filename = 5 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Mapping).buildIDX) }, // optional int64 build_id = 6 + func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasFunctions) }, // optional bool has_functions = 7 + func(b *buffer, m message) error { return decodeBool(b, 
&m.(*Mapping).HasFilenames) }, // optional bool has_filenames = 8 + func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasLineNumbers) }, // optional bool has_line_numbers = 9 + func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasInlineFrames) }, // optional bool has_inline_frames = 10 +} + +func (p *Location) decoder() []decoder { + return locationDecoder +} + +func (p *Location) encode(b *buffer) { + encodeUint64Opt(b, 1, p.ID) + encodeUint64Opt(b, 2, p.mappingIDX) + encodeUint64Opt(b, 3, p.Address) + for i := range p.Line { + encodeMessage(b, 4, &p.Line[i]) + } + encodeBoolOpt(b, 5, p.IsFolded) +} + +var locationDecoder = []decoder{ + nil, // 0 + func(b *buffer, m message) error { return decodeUint64(b, &m.(*Location).ID) }, // optional uint64 id = 1; + func(b *buffer, m message) error { return decodeUint64(b, &m.(*Location).mappingIDX) }, // optional uint64 mapping_id = 2; + func(b *buffer, m message) error { return decodeUint64(b, &m.(*Location).Address) }, // optional uint64 address = 3; + func(b *buffer, m message) error { // repeated Line line = 4 + pp := m.(*Location) + n := len(pp.Line) + pp.Line = append(pp.Line, Line{}) + return decodeMessage(b, &pp.Line[n]) + }, + func(b *buffer, m message) error { return decodeBool(b, &m.(*Location).IsFolded) }, // optional bool is_folded = 5; +} + +func (p *Line) decoder() []decoder { + return lineDecoder +} + +func (p *Line) encode(b *buffer) { + encodeUint64Opt(b, 1, p.functionIDX) + encodeInt64Opt(b, 2, p.Line) +} + +var lineDecoder = []decoder{ + nil, // 0 + // optional uint64 function_id = 1 + func(b *buffer, m message) error { return decodeUint64(b, &m.(*Line).functionIDX) }, + // optional int64 line = 2 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Line).Line) }, +} + +func (p *Function) decoder() []decoder { + return functionDecoder +} + +func (p *Function) encode(b *buffer) { + encodeUint64Opt(b, 1, p.ID) + encodeInt64Opt(b, 2, p.nameX) + encodeInt64Opt(b, 3, p.systemNameX) + encodeInt64Opt(b, 4, p.filenameX) + encodeInt64Opt(b, 5, p.StartLine) +} + +var functionDecoder = []decoder{ + nil, // 0 + // optional uint64 id = 1 + func(b *buffer, m message) error { return decodeUint64(b, &m.(*Function).ID) }, + // optional int64 function_name = 2 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).nameX) }, + // optional int64 function_system_name = 3 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).systemNameX) }, + // repeated int64 filename = 4 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).filenameX) }, + // optional int64 start_line = 5 + func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).StartLine) }, +} + +func addString(strings map[string]int, s string) int64 { + i, ok := strings[s] + if !ok { + i = len(strings) + strings[s] = i + } + return int64(i) +} + +func getString(strings []string, strng *int64, err error) (string, error) { + if err != nil { + return "", err + } + s := int(*strng) + if s < 0 || s >= len(strings) { + return "", errMalformed + } + *strng = 0 + return strings[s], nil +} diff --git a/vendor/github.com/google/pprof/profile/filter.go b/vendor/github.com/google/pprof/profile/filter.go new file mode 100644 index 00000000000..ea8e66c68d2 --- /dev/null +++ b/vendor/github.com/google/pprof/profile/filter.go @@ -0,0 +1,270 @@ +// Copyright 2014 Google Inc. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package profile + +// Implements methods to filter samples from profiles. + +import "regexp" + +// FilterSamplesByName filters the samples in a profile and only keeps +// samples where at least one frame matches focus but none match ignore. +// Returns true is the corresponding regexp matched at least one sample. +func (p *Profile) FilterSamplesByName(focus, ignore, hide, show *regexp.Regexp) (fm, im, hm, hnm bool) { + focusOrIgnore := make(map[uint64]bool) + hidden := make(map[uint64]bool) + for _, l := range p.Location { + if ignore != nil && l.matchesName(ignore) { + im = true + focusOrIgnore[l.ID] = false + } else if focus == nil || l.matchesName(focus) { + fm = true + focusOrIgnore[l.ID] = true + } + + if hide != nil && l.matchesName(hide) { + hm = true + l.Line = l.unmatchedLines(hide) + if len(l.Line) == 0 { + hidden[l.ID] = true + } + } + if show != nil { + l.Line = l.matchedLines(show) + if len(l.Line) == 0 { + hidden[l.ID] = true + } else { + hnm = true + } + } + } + + s := make([]*Sample, 0, len(p.Sample)) + for _, sample := range p.Sample { + if focusedAndNotIgnored(sample.Location, focusOrIgnore) { + if len(hidden) > 0 { + var locs []*Location + for _, loc := range sample.Location { + if !hidden[loc.ID] { + locs = append(locs, loc) + } + } + if len(locs) == 0 { + // Remove sample with no locations (by not adding it to s). + continue + } + sample.Location = locs + } + s = append(s, sample) + } + } + p.Sample = s + + return +} + +// ShowFrom drops all stack frames above the highest matching frame and returns +// whether a match was found. If showFrom is nil it returns false and does not +// modify the profile. +// +// Example: consider a sample with frames [A, B, C, B], where A is the root. +// ShowFrom(nil) returns false and has frames [A, B, C, B]. +// ShowFrom(A) returns true and has frames [A, B, C, B]. +// ShowFrom(B) returns true and has frames [B, C, B]. +// ShowFrom(C) returns true and has frames [C, B]. +// ShowFrom(D) returns false and drops the sample because no frames remain. +func (p *Profile) ShowFrom(showFrom *regexp.Regexp) (matched bool) { + if showFrom == nil { + return false + } + // showFromLocs stores location IDs that matched ShowFrom. + showFromLocs := make(map[uint64]bool) + // Apply to locations. + for _, loc := range p.Location { + if filterShowFromLocation(loc, showFrom) { + showFromLocs[loc.ID] = true + matched = true + } + } + // For all samples, strip locations after the highest matching one. + s := make([]*Sample, 0, len(p.Sample)) + for _, sample := range p.Sample { + for i := len(sample.Location) - 1; i >= 0; i-- { + if showFromLocs[sample.Location[i].ID] { + sample.Location = sample.Location[:i+1] + s = append(s, sample) + break + } + } + } + p.Sample = s + return matched +} + +// filterShowFromLocation tests a showFrom regex against a location, removes +// lines after the last match and returns whether a match was found. If the +// mapping is matched, then all lines are kept. 
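FilterSamplesByName and ShowFrom above are part of the package's public API. A short sketch of ShowFrom applied to a parsed profile; the input path and the `main\.` pattern are illustrative only:

```go
package main

import (
	"fmt"
	"log"
	"os"
	"regexp"

	"github.com/google/pprof/profile"
)

func main() {
	f, err := os.Open("cpu.pprof") // hypothetical input file
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	p, err := profile.Parse(f)
	if err != nil {
		log.Fatal(err)
	}

	// Drop every frame above the highest frame matching "main\.", as described
	// for ShowFrom above; samples with no matching frame are removed entirely.
	matched := p.ShowFrom(regexp.MustCompile(`main\.`))
	fmt.Printf("matched=%v, %d samples kept\n", matched, len(p.Sample))
}
```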
+func filterShowFromLocation(loc *Location, showFrom *regexp.Regexp) bool { + if m := loc.Mapping; m != nil && showFrom.MatchString(m.File) { + return true + } + if i := loc.lastMatchedLineIndex(showFrom); i >= 0 { + loc.Line = loc.Line[:i+1] + return true + } + return false +} + +// lastMatchedLineIndex returns the index of the last line that matches a regex, +// or -1 if no match is found. +func (loc *Location) lastMatchedLineIndex(re *regexp.Regexp) int { + for i := len(loc.Line) - 1; i >= 0; i-- { + if fn := loc.Line[i].Function; fn != nil { + if re.MatchString(fn.Name) || re.MatchString(fn.Filename) { + return i + } + } + } + return -1 +} + +// FilterTagsByName filters the tags in a profile and only keeps +// tags that match show and not hide. +func (p *Profile) FilterTagsByName(show, hide *regexp.Regexp) (sm, hm bool) { + matchRemove := func(name string) bool { + matchShow := show == nil || show.MatchString(name) + matchHide := hide != nil && hide.MatchString(name) + + if matchShow { + sm = true + } + if matchHide { + hm = true + } + return !matchShow || matchHide + } + for _, s := range p.Sample { + for lab := range s.Label { + if matchRemove(lab) { + delete(s.Label, lab) + } + } + for lab := range s.NumLabel { + if matchRemove(lab) { + delete(s.NumLabel, lab) + } + } + } + return +} + +// matchesName returns whether the location matches the regular +// expression. It checks any available function names, file names, and +// mapping object filename. +func (loc *Location) matchesName(re *regexp.Regexp) bool { + for _, ln := range loc.Line { + if fn := ln.Function; fn != nil { + if re.MatchString(fn.Name) || re.MatchString(fn.Filename) { + return true + } + } + } + if m := loc.Mapping; m != nil && re.MatchString(m.File) { + return true + } + return false +} + +// unmatchedLines returns the lines in the location that do not match +// the regular expression. +func (loc *Location) unmatchedLines(re *regexp.Regexp) []Line { + if m := loc.Mapping; m != nil && re.MatchString(m.File) { + return nil + } + var lines []Line + for _, ln := range loc.Line { + if fn := ln.Function; fn != nil { + if re.MatchString(fn.Name) || re.MatchString(fn.Filename) { + continue + } + } + lines = append(lines, ln) + } + return lines +} + +// matchedLines returns the lines in the location that match +// the regular expression. +func (loc *Location) matchedLines(re *regexp.Regexp) []Line { + if m := loc.Mapping; m != nil && re.MatchString(m.File) { + return loc.Line + } + var lines []Line + for _, ln := range loc.Line { + if fn := ln.Function; fn != nil { + if !re.MatchString(fn.Name) && !re.MatchString(fn.Filename) { + continue + } + } + lines = append(lines, ln) + } + return lines +} + +// focusedAndNotIgnored looks up a slice of ids against a map of +// focused/ignored locations. The map only contains locations that are +// explicitly focused or ignored. Returns whether there is at least +// one focused location but no ignored locations. +func focusedAndNotIgnored(locs []*Location, m map[uint64]bool) bool { + var f bool + for _, loc := range locs { + if focus, focusOrIgnore := m[loc.ID]; focusOrIgnore { + if focus { + // Found focused location. Must keep searching in case there + // is an ignored one as well. + f = true + } else { + // Found ignored location. Can return false right away. 
+ return false + } + } + } + return f +} + +// TagMatch selects tags for filtering +type TagMatch func(s *Sample) bool + +// FilterSamplesByTag removes all samples from the profile, except +// those that match focus and do not match the ignore regular +// expression. +func (p *Profile) FilterSamplesByTag(focus, ignore TagMatch) (fm, im bool) { + samples := make([]*Sample, 0, len(p.Sample)) + for _, s := range p.Sample { + focused, ignored := true, false + if focus != nil { + focused = focus(s) + } + if ignore != nil { + ignored = ignore(s) + } + fm = fm || focused + im = im || ignored + if focused && !ignored { + samples = append(samples, s) + } + } + p.Sample = samples + return +} diff --git a/vendor/github.com/google/pprof/profile/index.go b/vendor/github.com/google/pprof/profile/index.go new file mode 100644 index 00000000000..bef1d60467c --- /dev/null +++ b/vendor/github.com/google/pprof/profile/index.go @@ -0,0 +1,64 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package profile + +import ( + "fmt" + "strconv" + "strings" +) + +// SampleIndexByName returns the appropriate index for a value of sample index. +// If numeric, it returns the number, otherwise it looks up the text in the +// profile sample types. +func (p *Profile) SampleIndexByName(sampleIndex string) (int, error) { + if sampleIndex == "" { + if dst := p.DefaultSampleType; dst != "" { + for i, t := range sampleTypes(p) { + if t == dst { + return i, nil + } + } + } + // By default select the last sample value + return len(p.SampleType) - 1, nil + } + if i, err := strconv.Atoi(sampleIndex); err == nil { + if i < 0 || i >= len(p.SampleType) { + return 0, fmt.Errorf("sample_index %s is outside the range [0..%d]", sampleIndex, len(p.SampleType)-1) + } + return i, nil + } + + // Remove the inuse_ prefix to support legacy pprof options + // "inuse_space" and "inuse_objects" for profiles containing types + // "space" and "objects". + noInuse := strings.TrimPrefix(sampleIndex, "inuse_") + for i, t := range p.SampleType { + if t.Type == sampleIndex || t.Type == noInuse { + return i, nil + } + } + + return 0, fmt.Errorf("sample_index %q must be one of: %v", sampleIndex, sampleTypes(p)) +} + +func sampleTypes(p *Profile) []string { + types := make([]string, len(p.SampleType)) + for i, t := range p.SampleType { + types[i] = t.Type + } + return types +} diff --git a/vendor/github.com/google/pprof/profile/legacy_java_profile.go b/vendor/github.com/google/pprof/profile/legacy_java_profile.go new file mode 100644 index 00000000000..91f45e53c6c --- /dev/null +++ b/vendor/github.com/google/pprof/profile/legacy_java_profile.go @@ -0,0 +1,315 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file implements parsers to convert java legacy profiles into +// the profile.proto format. + +package profile + +import ( + "bytes" + "fmt" + "io" + "path/filepath" + "regexp" + "strconv" + "strings" +) + +var ( + attributeRx = regexp.MustCompile(`([\w ]+)=([\w ]+)`) + javaSampleRx = regexp.MustCompile(` *(\d+) +(\d+) +@ +([ x0-9a-f]*)`) + javaLocationRx = regexp.MustCompile(`^\s*0x([[:xdigit:]]+)\s+(.*)\s*$`) + javaLocationFileLineRx = regexp.MustCompile(`^(.*)\s+\((.+):(-?[[:digit:]]+)\)$`) + javaLocationPathRx = regexp.MustCompile(`^(.*)\s+\((.*)\)$`) +) + +// javaCPUProfile returns a new Profile from profilez data. +// b is the profile bytes after the header, period is the profiling +// period, and parse is a function to parse 8-byte chunks from the +// profile in its native endianness. +func javaCPUProfile(b []byte, period int64, parse func(b []byte) (uint64, []byte)) (*Profile, error) { + p := &Profile{ + Period: period * 1000, + PeriodType: &ValueType{Type: "cpu", Unit: "nanoseconds"}, + SampleType: []*ValueType{{Type: "samples", Unit: "count"}, {Type: "cpu", Unit: "nanoseconds"}}, + } + var err error + var locs map[uint64]*Location + if b, locs, err = parseCPUSamples(b, parse, false, p); err != nil { + return nil, err + } + + if err = parseJavaLocations(b, locs, p); err != nil { + return nil, err + } + + // Strip out addresses for better merge. + if err = p.Aggregate(true, true, true, true, false); err != nil { + return nil, err + } + + return p, nil +} + +// parseJavaProfile returns a new profile from heapz or contentionz +// data. b is the profile bytes after the header. +func parseJavaProfile(b []byte) (*Profile, error) { + h := bytes.SplitAfterN(b, []byte("\n"), 2) + if len(h) < 2 { + return nil, errUnrecognized + } + + p := &Profile{ + PeriodType: &ValueType{}, + } + header := string(bytes.TrimSpace(h[0])) + + var err error + var pType string + switch header { + case "--- heapz 1 ---": + pType = "heap" + case "--- contentionz 1 ---": + pType = "contention" + default: + return nil, errUnrecognized + } + + if b, err = parseJavaHeader(pType, h[1], p); err != nil { + return nil, err + } + var locs map[uint64]*Location + if b, locs, err = parseJavaSamples(pType, b, p); err != nil { + return nil, err + } + if err = parseJavaLocations(b, locs, p); err != nil { + return nil, err + } + + // Strip out addresses for better merge. + if err = p.Aggregate(true, true, true, true, false); err != nil { + return nil, err + } + + return p, nil +} + +// parseJavaHeader parses the attribute section on a java profile and +// populates a profile. Returns the remainder of the buffer after all +// attributes. +func parseJavaHeader(pType string, b []byte, p *Profile) ([]byte, error) { + nextNewLine := bytes.IndexByte(b, byte('\n')) + for nextNewLine != -1 { + line := string(bytes.TrimSpace(b[0:nextNewLine])) + if line != "" { + h := attributeRx.FindStringSubmatch(line) + if h == nil { + // Not a valid attribute, exit. 
+ return b, nil + } + + attribute, value := strings.TrimSpace(h[1]), strings.TrimSpace(h[2]) + var err error + switch pType + "/" + attribute { + case "heap/format", "cpu/format", "contention/format": + if value != "java" { + return nil, errUnrecognized + } + case "heap/resolution": + p.SampleType = []*ValueType{ + {Type: "inuse_objects", Unit: "count"}, + {Type: "inuse_space", Unit: value}, + } + case "contention/resolution": + p.SampleType = []*ValueType{ + {Type: "contentions", Unit: "count"}, + {Type: "delay", Unit: value}, + } + case "contention/sampling period": + p.PeriodType = &ValueType{ + Type: "contentions", Unit: "count", + } + if p.Period, err = strconv.ParseInt(value, 0, 64); err != nil { + return nil, fmt.Errorf("failed to parse attribute %s: %v", line, err) + } + case "contention/ms since reset": + millis, err := strconv.ParseInt(value, 0, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse attribute %s: %v", line, err) + } + p.DurationNanos = millis * 1000 * 1000 + default: + return nil, errUnrecognized + } + } + // Grab next line. + b = b[nextNewLine+1:] + nextNewLine = bytes.IndexByte(b, byte('\n')) + } + return b, nil +} + +// parseJavaSamples parses the samples from a java profile and +// populates the Samples in a profile. Returns the remainder of the +// buffer after the samples. +func parseJavaSamples(pType string, b []byte, p *Profile) ([]byte, map[uint64]*Location, error) { + nextNewLine := bytes.IndexByte(b, byte('\n')) + locs := make(map[uint64]*Location) + for nextNewLine != -1 { + line := string(bytes.TrimSpace(b[0:nextNewLine])) + if line != "" { + sample := javaSampleRx.FindStringSubmatch(line) + if sample == nil { + // Not a valid sample, exit. + return b, locs, nil + } + + // Java profiles have data/fields inverted compared to other + // profile types. + var err error + value1, value2, value3 := sample[2], sample[1], sample[3] + addrs, err := parseHexAddresses(value3) + if err != nil { + return nil, nil, fmt.Errorf("malformed sample: %s: %v", line, err) + } + + var sloc []*Location + for _, addr := range addrs { + loc := locs[addr] + if locs[addr] == nil { + loc = &Location{ + Address: addr, + } + p.Location = append(p.Location, loc) + locs[addr] = loc + } + sloc = append(sloc, loc) + } + s := &Sample{ + Value: make([]int64, 2), + Location: sloc, + } + + if s.Value[0], err = strconv.ParseInt(value1, 0, 64); err != nil { + return nil, nil, fmt.Errorf("parsing sample %s: %v", line, err) + } + if s.Value[1], err = strconv.ParseInt(value2, 0, 64); err != nil { + return nil, nil, fmt.Errorf("parsing sample %s: %v", line, err) + } + + switch pType { + case "heap": + const javaHeapzSamplingRate = 524288 // 512K + if s.Value[0] == 0 { + return nil, nil, fmt.Errorf("parsing sample %s: second value must be non-zero", line) + } + s.NumLabel = map[string][]int64{"bytes": {s.Value[1] / s.Value[0]}} + s.Value[0], s.Value[1] = scaleHeapSample(s.Value[0], s.Value[1], javaHeapzSamplingRate) + case "contention": + if period := p.Period; period != 0 { + s.Value[0] = s.Value[0] * p.Period + s.Value[1] = s.Value[1] * p.Period + } + } + p.Sample = append(p.Sample, s) + } + // Grab next line. + b = b[nextNewLine+1:] + nextNewLine = bytes.IndexByte(b, byte('\n')) + } + return b, locs, nil +} + +// parseJavaLocations parses the location information in a java +// profile and populates the Locations in a profile. It uses the +// location addresses from the profile as both the ID of each +// location. 
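parseJavaSamples above relies on javaSampleRx to split each sample line into its two counters and the hex addresses, then swaps the first two fields because Java profiles store them inverted. A tiny self-contained sketch of that match; the sample line is invented:

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	// Same shape as javaSampleRx in the vendored parser: count, metric, "@", addresses.
	javaSampleRx := regexp.MustCompile(` *(\d+) +(\d+) +@ +([ x0-9a-f]*)`)
	line := "    5 1280 @ 0x40a1b2 0x40c3d4"
	m := javaSampleRx.FindStringSubmatch(line)
	if m == nil {
		fmt.Println("not a sample line")
		return
	}
	// The parser uses m[2] as value1 and m[1] as value2 when building the
	// Sample, reflecting the inverted field order in Java legacy profiles.
	fmt.Printf("value1=%s value2=%s addrs=%q\n", m[2], m[1], m[3])
}
```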
+func parseJavaLocations(b []byte, locs map[uint64]*Location, p *Profile) error { + r := bytes.NewBuffer(b) + fns := make(map[string]*Function) + for { + line, err := r.ReadString('\n') + if err != nil { + if err != io.EOF { + return err + } + if line == "" { + break + } + } + + if line = strings.TrimSpace(line); line == "" { + continue + } + + jloc := javaLocationRx.FindStringSubmatch(line) + if len(jloc) != 3 { + continue + } + addr, err := strconv.ParseUint(jloc[1], 16, 64) + if err != nil { + return fmt.Errorf("parsing sample %s: %v", line, err) + } + loc := locs[addr] + if loc == nil { + // Unused/unseen + continue + } + var lineFunc, lineFile string + var lineNo int64 + + if fileLine := javaLocationFileLineRx.FindStringSubmatch(jloc[2]); len(fileLine) == 4 { + // Found a line of the form: "function (file:line)" + lineFunc, lineFile = fileLine[1], fileLine[2] + if n, err := strconv.ParseInt(fileLine[3], 10, 64); err == nil && n > 0 { + lineNo = n + } + } else if filePath := javaLocationPathRx.FindStringSubmatch(jloc[2]); len(filePath) == 3 { + // If there's not a file:line, it's a shared library path. + // The path isn't interesting, so just give the .so. + lineFunc, lineFile = filePath[1], filepath.Base(filePath[2]) + } else if strings.Contains(jloc[2], "generated stub/JIT") { + lineFunc = "STUB" + } else { + // Treat whole line as the function name. This is used by the + // java agent for internal states such as "GC" or "VM". + lineFunc = jloc[2] + } + fn := fns[lineFunc] + + if fn == nil { + fn = &Function{ + Name: lineFunc, + SystemName: lineFunc, + Filename: lineFile, + } + fns[lineFunc] = fn + p.Function = append(p.Function, fn) + } + loc.Line = []Line{ + { + Function: fn, + Line: lineNo, + }, + } + loc.Address = 0 + } + + p.remapLocationIDs() + p.remapFunctionIDs() + p.remapMappingIDs() + + return nil +} diff --git a/vendor/github.com/google/pprof/profile/legacy_profile.go b/vendor/github.com/google/pprof/profile/legacy_profile.go new file mode 100644 index 00000000000..0c8f3bb5b71 --- /dev/null +++ b/vendor/github.com/google/pprof/profile/legacy_profile.go @@ -0,0 +1,1225 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file implements parsers to convert legacy profiles into the +// profile.proto format. 
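The legacy parsers in this file, such as parseGoCount, are what let the package read the plain-text output of Go's runtime profiles. A hedged end-to-end sketch, assuming the vendored import path and that a debug=1 goroutine dump is acceptable input:

```go
package main

import (
	"bytes"
	"fmt"
	"log"
	"runtime/pprof"

	"github.com/google/pprof/profile"
)

func main() {
	// debug=1 emits the legacy text format ("goroutine profile: total N ...")
	// rather than profile.proto; the legacy parsers convert it on Parse.
	var buf bytes.Buffer
	if err := pprof.Lookup("goroutine").WriteTo(&buf, 1); err != nil {
		log.Fatal(err)
	}
	p, err := profile.Parse(&buf)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("parsed %d samples of type %s\n", len(p.Sample), p.SampleType[0].Type)
}
```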
+ +package profile + +import ( + "bufio" + "bytes" + "fmt" + "io" + "math" + "regexp" + "strconv" + "strings" +) + +var ( + countStartRE = regexp.MustCompile(`\A(\S+) profile: total \d+\z`) + countRE = regexp.MustCompile(`\A(\d+) @(( 0x[0-9a-f]+)+)\z`) + + heapHeaderRE = regexp.MustCompile(`heap profile: *(\d+): *(\d+) *\[ *(\d+): *(\d+) *\] *@ *(heap[_a-z0-9]*)/?(\d*)`) + heapSampleRE = regexp.MustCompile(`(-?\d+): *(-?\d+) *\[ *(\d+): *(\d+) *] @([ x0-9a-f]*)`) + + contentionSampleRE = regexp.MustCompile(`(\d+) *(\d+) @([ x0-9a-f]*)`) + + hexNumberRE = regexp.MustCompile(`0x[0-9a-f]+`) + + growthHeaderRE = regexp.MustCompile(`heap profile: *(\d+): *(\d+) *\[ *(\d+): *(\d+) *\] @ growthz?`) + + fragmentationHeaderRE = regexp.MustCompile(`heap profile: *(\d+): *(\d+) *\[ *(\d+): *(\d+) *\] @ fragmentationz?`) + + threadzStartRE = regexp.MustCompile(`--- threadz \d+ ---`) + threadStartRE = regexp.MustCompile(`--- Thread ([[:xdigit:]]+) \(name: (.*)/(\d+)\) stack: ---`) + + // Regular expressions to parse process mappings. Support the format used by Linux /proc/.../maps and other tools. + // Recommended format: + // Start End object file name offset(optional) linker build id + // 0x40000-0x80000 /path/to/binary (@FF00) abc123456 + spaceDigits = `\s+[[:digit:]]+` + hexPair = `\s+[[:xdigit:]]+:[[:xdigit:]]+` + oSpace = `\s*` + // Capturing expressions. + cHex = `(?:0x)?([[:xdigit:]]+)` + cHexRange = `\s*` + cHex + `[\s-]?` + oSpace + cHex + `:?` + cSpaceString = `(?:\s+(\S+))?` + cSpaceHex = `(?:\s+([[:xdigit:]]+))?` + cSpaceAtOffset = `(?:\s+\(@([[:xdigit:]]+)\))?` + cPerm = `(?:\s+([-rwxp]+))?` + + procMapsRE = regexp.MustCompile(`^` + cHexRange + cPerm + cSpaceHex + hexPair + spaceDigits + cSpaceString) + briefMapsRE = regexp.MustCompile(`^` + cHexRange + cPerm + cSpaceString + cSpaceAtOffset + cSpaceHex) + + // Regular expression to parse log data, of the form: + // ... file:line] msg... + logInfoRE = regexp.MustCompile(`^[^\[\]]+:[0-9]+]\s`) +) + +func isSpaceOrComment(line string) bool { + trimmed := strings.TrimSpace(line) + return len(trimmed) == 0 || trimmed[0] == '#' +} + +// parseGoCount parses a Go count profile (e.g., threadcreate or +// goroutine) and returns a new Profile. +func parseGoCount(b []byte) (*Profile, error) { + s := bufio.NewScanner(bytes.NewBuffer(b)) + // Skip comments at the beginning of the file. + for s.Scan() && isSpaceOrComment(s.Text()) { + } + if err := s.Err(); err != nil { + return nil, err + } + m := countStartRE.FindStringSubmatch(s.Text()) + if m == nil { + return nil, errUnrecognized + } + profileType := m[1] + p := &Profile{ + PeriodType: &ValueType{Type: profileType, Unit: "count"}, + Period: 1, + SampleType: []*ValueType{{Type: profileType, Unit: "count"}}, + } + locations := make(map[uint64]*Location) + for s.Scan() { + line := s.Text() + if isSpaceOrComment(line) { + continue + } + if strings.HasPrefix(line, "---") { + break + } + m := countRE.FindStringSubmatch(line) + if m == nil { + return nil, errMalformed + } + n, err := strconv.ParseInt(m[1], 0, 64) + if err != nil { + return nil, errMalformed + } + fields := strings.Fields(m[2]) + locs := make([]*Location, 0, len(fields)) + for _, stk := range fields { + addr, err := strconv.ParseUint(stk, 0, 64) + if err != nil { + return nil, errMalformed + } + // Adjust all frames by -1 to land on top of the call instruction. 
+ addr-- + loc := locations[addr] + if loc == nil { + loc = &Location{ + Address: addr, + } + locations[addr] = loc + p.Location = append(p.Location, loc) + } + locs = append(locs, loc) + } + p.Sample = append(p.Sample, &Sample{ + Location: locs, + Value: []int64{n}, + }) + } + if err := s.Err(); err != nil { + return nil, err + } + + if err := parseAdditionalSections(s, p); err != nil { + return nil, err + } + return p, nil +} + +// remapLocationIDs ensures there is a location for each address +// referenced by a sample, and remaps the samples to point to the new +// location ids. +func (p *Profile) remapLocationIDs() { + seen := make(map[*Location]bool, len(p.Location)) + var locs []*Location + + for _, s := range p.Sample { + for _, l := range s.Location { + if seen[l] { + continue + } + l.ID = uint64(len(locs) + 1) + locs = append(locs, l) + seen[l] = true + } + } + p.Location = locs +} + +func (p *Profile) remapFunctionIDs() { + seen := make(map[*Function]bool, len(p.Function)) + var fns []*Function + + for _, l := range p.Location { + for _, ln := range l.Line { + fn := ln.Function + if fn == nil || seen[fn] { + continue + } + fn.ID = uint64(len(fns) + 1) + fns = append(fns, fn) + seen[fn] = true + } + } + p.Function = fns +} + +// remapMappingIDs matches location addresses with existing mappings +// and updates them appropriately. This is O(N*M), if this ever shows +// up as a bottleneck, evaluate sorting the mappings and doing a +// binary search, which would make it O(N*log(M)). +func (p *Profile) remapMappingIDs() { + // Some profile handlers will incorrectly set regions for the main + // executable if its section is remapped. Fix them through heuristics. + + if len(p.Mapping) > 0 { + // Remove the initial mapping if named '/anon_hugepage' and has a + // consecutive adjacent mapping. + if m := p.Mapping[0]; strings.HasPrefix(m.File, "/anon_hugepage") { + if len(p.Mapping) > 1 && m.Limit == p.Mapping[1].Start { + p.Mapping = p.Mapping[1:] + } + } + } + + // Subtract the offset from the start of the main mapping if it + // ends up at a recognizable start address. + if len(p.Mapping) > 0 { + const expectedStart = 0x400000 + if m := p.Mapping[0]; m.Start-m.Offset == expectedStart { + m.Start = expectedStart + m.Offset = 0 + } + } + + // Associate each location with an address to the corresponding + // mapping. Create fake mapping if a suitable one isn't found. + var fake *Mapping +nextLocation: + for _, l := range p.Location { + a := l.Address + if l.Mapping != nil || a == 0 { + continue + } + for _, m := range p.Mapping { + if m.Start <= a && a < m.Limit { + l.Mapping = m + continue nextLocation + } + } + // Work around legacy handlers failing to encode the first + // part of mappings split into adjacent ranges. + for _, m := range p.Mapping { + if m.Offset != 0 && m.Start-m.Offset <= a && a < m.Start { + m.Start -= m.Offset + m.Offset = 0 + l.Mapping = m + continue nextLocation + } + } + // If there is still no mapping, create a fake one. + // This is important for the Go legacy handler, which produced + // no mappings. + if fake == nil { + fake = &Mapping{ + ID: 1, + Limit: ^uint64(0), + } + p.Mapping = append(p.Mapping, fake) + } + l.Mapping = fake + } + + // Reset all mapping IDs. 
+ for i, m := range p.Mapping { + m.ID = uint64(i + 1) + } +} + +var cpuInts = []func([]byte) (uint64, []byte){ + get32l, + get32b, + get64l, + get64b, +} + +func get32l(b []byte) (uint64, []byte) { + if len(b) < 4 { + return 0, nil + } + return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24, b[4:] +} + +func get32b(b []byte) (uint64, []byte) { + if len(b) < 4 { + return 0, nil + } + return uint64(b[3]) | uint64(b[2])<<8 | uint64(b[1])<<16 | uint64(b[0])<<24, b[4:] +} + +func get64l(b []byte) (uint64, []byte) { + if len(b) < 8 { + return 0, nil + } + return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56, b[8:] +} + +func get64b(b []byte) (uint64, []byte) { + if len(b) < 8 { + return 0, nil + } + return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56, b[8:] +} + +// parseCPU parses a profilez legacy profile and returns a newly +// populated Profile. +// +// The general format for profilez samples is a sequence of words in +// binary format. The first words are a header with the following data: +// 1st word -- 0 +// 2nd word -- 3 +// 3rd word -- 0 if a c++ application, 1 if a java application. +// 4th word -- Sampling period (in microseconds). +// 5th word -- Padding. +func parseCPU(b []byte) (*Profile, error) { + var parse func([]byte) (uint64, []byte) + var n1, n2, n3, n4, n5 uint64 + for _, parse = range cpuInts { + var tmp []byte + n1, tmp = parse(b) + n2, tmp = parse(tmp) + n3, tmp = parse(tmp) + n4, tmp = parse(tmp) + n5, tmp = parse(tmp) + + if tmp != nil && n1 == 0 && n2 == 3 && n3 == 0 && n4 > 0 && n5 == 0 { + b = tmp + return cpuProfile(b, int64(n4), parse) + } + if tmp != nil && n1 == 0 && n2 == 3 && n3 == 1 && n4 > 0 && n5 == 0 { + b = tmp + return javaCPUProfile(b, int64(n4), parse) + } + } + return nil, errUnrecognized +} + +// cpuProfile returns a new Profile from C++ profilez data. +// b is the profile bytes after the header, period is the profiling +// period, and parse is a function to parse 8-byte chunks from the +// profile in its native endianness. +func cpuProfile(b []byte, period int64, parse func(b []byte) (uint64, []byte)) (*Profile, error) { + p := &Profile{ + Period: period * 1000, + PeriodType: &ValueType{Type: "cpu", Unit: "nanoseconds"}, + SampleType: []*ValueType{ + {Type: "samples", Unit: "count"}, + {Type: "cpu", Unit: "nanoseconds"}, + }, + } + var err error + if b, _, err = parseCPUSamples(b, parse, true, p); err != nil { + return nil, err + } + + // If *most* samples have the same second-to-the-bottom frame, it + // strongly suggests that it is an uninteresting artifact of + // measurement -- a stack frame pushed by the signal handler. The + // bottom frame is always correct as it is picked up from the signal + // structure, not the stack. Check if this is the case and if so, + // remove. + + // Remove up to two frames. + maxiter := 2 + // Allow one different sample for this many samples with the same + // second-to-last frame. 
+ similarSamples := 32 + margin := len(p.Sample) / similarSamples + + for iter := 0; iter < maxiter; iter++ { + addr1 := make(map[uint64]int) + for _, s := range p.Sample { + if len(s.Location) > 1 { + a := s.Location[1].Address + addr1[a] = addr1[a] + 1 + } + } + + for id1, count := range addr1 { + if count >= len(p.Sample)-margin { + // Found uninteresting frame, strip it out from all samples + for _, s := range p.Sample { + if len(s.Location) > 1 && s.Location[1].Address == id1 { + s.Location = append(s.Location[:1], s.Location[2:]...) + } + } + break + } + } + } + + if err := p.ParseMemoryMap(bytes.NewBuffer(b)); err != nil { + return nil, err + } + + cleanupDuplicateLocations(p) + return p, nil +} + +func cleanupDuplicateLocations(p *Profile) { + // The profile handler may duplicate the leaf frame, because it gets + // its address both from stack unwinding and from the signal + // context. Detect this and delete the duplicate, which has been + // adjusted by -1. The leaf address should not be adjusted as it is + // not a call. + for _, s := range p.Sample { + if len(s.Location) > 1 && s.Location[0].Address == s.Location[1].Address+1 { + s.Location = append(s.Location[:1], s.Location[2:]...) + } + } +} + +// parseCPUSamples parses a collection of profilez samples from a +// profile. +// +// profilez samples are a repeated sequence of stack frames of the +// form: +// 1st word -- The number of times this stack was encountered. +// 2nd word -- The size of the stack (StackSize). +// 3rd word -- The first address on the stack. +// ... +// StackSize + 2 -- The last address on the stack +// The last stack trace is of the form: +// 1st word -- 0 +// 2nd word -- 1 +// 3rd word -- 0 +// +// Addresses from stack traces may point to the next instruction after +// each call. Optionally adjust by -1 to land somewhere on the actual +// call (except for the leaf, which is not a call). +func parseCPUSamples(b []byte, parse func(b []byte) (uint64, []byte), adjust bool, p *Profile) ([]byte, map[uint64]*Location, error) { + locs := make(map[uint64]*Location) + for len(b) > 0 { + var count, nstk uint64 + count, b = parse(b) + nstk, b = parse(b) + if b == nil || nstk > uint64(len(b)/4) { + return nil, nil, errUnrecognized + } + var sloc []*Location + addrs := make([]uint64, nstk) + for i := 0; i < int(nstk); i++ { + addrs[i], b = parse(b) + } + + if count == 0 && nstk == 1 && addrs[0] == 0 { + // End of data marker + break + } + for i, addr := range addrs { + if adjust && i > 0 { + addr-- + } + loc := locs[addr] + if loc == nil { + loc = &Location{ + Address: addr, + } + locs[addr] = loc + p.Location = append(p.Location, loc) + } + sloc = append(sloc, loc) + } + p.Sample = append(p.Sample, + &Sample{ + Value: []int64{int64(count), int64(count) * p.Period}, + Location: sloc, + }) + } + // Reached the end without finding the EOD marker. + return b, locs, nil +} + +// parseHeap parses a heapz legacy or a growthz profile and +// returns a newly populated Profile. 
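As an illustration of the text format that parseHeap below consumes, a small sketch (assuming the github.com/google/pprof/profile import; the heapz text is made up) that feeds a hand-written heapz v2 input through ParseData, which is expected to fall back to this parser:

func parseHeapzExample() (*profile.Profile, error) {
    // Made-up heapz v2 header plus one sample line. The in-use (1: 8)
    // and allocated (2: 16) counts differ, so parseHeap keeps all four
    // alloc_*/inuse_* sample types; 524288 is the sampling parameter
    // recorded in the header.
    const heapz = "heap profile: 1: 8 [2: 16] @ heap/524288\n" +
        "1: 8 [2: 16] @ 0x1000 0x2000\n"
    return profile.ParseData([]byte(heapz))
}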
+func parseHeap(b []byte) (p *Profile, err error) { + s := bufio.NewScanner(bytes.NewBuffer(b)) + if !s.Scan() { + if err := s.Err(); err != nil { + return nil, err + } + return nil, errUnrecognized + } + p = &Profile{} + + sampling := "" + hasAlloc := false + + line := s.Text() + p.PeriodType = &ValueType{Type: "space", Unit: "bytes"} + if header := heapHeaderRE.FindStringSubmatch(line); header != nil { + sampling, p.Period, hasAlloc, err = parseHeapHeader(line) + if err != nil { + return nil, err + } + } else if header = growthHeaderRE.FindStringSubmatch(line); header != nil { + p.Period = 1 + } else if header = fragmentationHeaderRE.FindStringSubmatch(line); header != nil { + p.Period = 1 + } else { + return nil, errUnrecognized + } + + if hasAlloc { + // Put alloc before inuse so that default pprof selection + // will prefer inuse_space. + p.SampleType = []*ValueType{ + {Type: "alloc_objects", Unit: "count"}, + {Type: "alloc_space", Unit: "bytes"}, + {Type: "inuse_objects", Unit: "count"}, + {Type: "inuse_space", Unit: "bytes"}, + } + } else { + p.SampleType = []*ValueType{ + {Type: "objects", Unit: "count"}, + {Type: "space", Unit: "bytes"}, + } + } + + locs := make(map[uint64]*Location) + for s.Scan() { + line := strings.TrimSpace(s.Text()) + + if isSpaceOrComment(line) { + continue + } + + if isMemoryMapSentinel(line) { + break + } + + value, blocksize, addrs, err := parseHeapSample(line, p.Period, sampling, hasAlloc) + if err != nil { + return nil, err + } + + var sloc []*Location + for _, addr := range addrs { + // Addresses from stack traces point to the next instruction after + // each call. Adjust by -1 to land somewhere on the actual call. + addr-- + loc := locs[addr] + if locs[addr] == nil { + loc = &Location{ + Address: addr, + } + p.Location = append(p.Location, loc) + locs[addr] = loc + } + sloc = append(sloc, loc) + } + + p.Sample = append(p.Sample, &Sample{ + Value: value, + Location: sloc, + NumLabel: map[string][]int64{"bytes": {blocksize}}, + }) + } + if err := s.Err(); err != nil { + return nil, err + } + if err := parseAdditionalSections(s, p); err != nil { + return nil, err + } + return p, nil +} + +func parseHeapHeader(line string) (sampling string, period int64, hasAlloc bool, err error) { + header := heapHeaderRE.FindStringSubmatch(line) + if header == nil { + return "", 0, false, errUnrecognized + } + + if len(header[6]) > 0 { + if period, err = strconv.ParseInt(header[6], 10, 64); err != nil { + return "", 0, false, errUnrecognized + } + } + + if (header[3] != header[1] && header[3] != "0") || (header[4] != header[2] && header[4] != "0") { + hasAlloc = true + } + + switch header[5] { + case "heapz_v2", "heap_v2": + return "v2", period, hasAlloc, nil + case "heapprofile": + return "", 1, hasAlloc, nil + case "heap": + return "v2", period / 2, hasAlloc, nil + default: + return "", 0, false, errUnrecognized + } +} + +// parseHeapSample parses a single row from a heap profile into a new Sample. +func parseHeapSample(line string, rate int64, sampling string, includeAlloc bool) (value []int64, blocksize int64, addrs []uint64, err error) { + sampleData := heapSampleRE.FindStringSubmatch(line) + if len(sampleData) != 6 { + return nil, 0, nil, fmt.Errorf("unexpected number of sample values: got %d, want 6", len(sampleData)) + } + + // This is a local-scoped helper function to avoid needing to pass + // around rate, sampling and many return parameters. 
+ addValues := func(countString, sizeString string, label string) error { + count, err := strconv.ParseInt(countString, 10, 64) + if err != nil { + return fmt.Errorf("malformed sample: %s: %v", line, err) + } + size, err := strconv.ParseInt(sizeString, 10, 64) + if err != nil { + return fmt.Errorf("malformed sample: %s: %v", line, err) + } + if count == 0 && size != 0 { + return fmt.Errorf("%s count was 0 but %s bytes was %d", label, label, size) + } + if count != 0 { + blocksize = size / count + if sampling == "v2" { + count, size = scaleHeapSample(count, size, rate) + } + } + value = append(value, count, size) + return nil + } + + if includeAlloc { + if err := addValues(sampleData[3], sampleData[4], "allocation"); err != nil { + return nil, 0, nil, err + } + } + + if err := addValues(sampleData[1], sampleData[2], "inuse"); err != nil { + return nil, 0, nil, err + } + + addrs, err = parseHexAddresses(sampleData[5]) + if err != nil { + return nil, 0, nil, fmt.Errorf("malformed sample: %s: %v", line, err) + } + + return value, blocksize, addrs, nil +} + +// parseHexAddresses extracts hex numbers from a string, attempts to convert +// each to an unsigned 64-bit number and returns the resulting numbers as a +// slice, or an error if the string contains hex numbers which are too large to +// handle (which means a malformed profile). +func parseHexAddresses(s string) ([]uint64, error) { + hexStrings := hexNumberRE.FindAllString(s, -1) + var addrs []uint64 + for _, s := range hexStrings { + if addr, err := strconv.ParseUint(s, 0, 64); err == nil { + addrs = append(addrs, addr) + } else { + return nil, fmt.Errorf("failed to parse as hex 64-bit number: %s", s) + } + } + return addrs, nil +} + +// scaleHeapSample adjusts the data from a heapz Sample to +// account for its probability of appearing in the collected +// data. heapz profiles are a sampling of the memory allocations +// requests in a program. We estimate the unsampled value by dividing +// each collected sample by its probability of appearing in the +// profile. heapz v2 profiles rely on a poisson process to determine +// which samples to collect, based on the desired average collection +// rate R. The probability of a sample of size S to appear in that +// profile is 1-exp(-S/R). +func scaleHeapSample(count, size, rate int64) (int64, int64) { + if count == 0 || size == 0 { + return 0, 0 + } + + if rate <= 1 { + // if rate==1 all samples were collected so no adjustment is needed. + // if rate<1 treat as unknown and skip scaling. + return count, size + } + + avgSize := float64(size) / float64(count) + scale := 1 / (1 - math.Exp(-avgSize/float64(rate))) + + return int64(float64(count) * scale), int64(float64(size) * scale) +} + +// parseContention parses a mutex or contention profile. There are 2 cases: +// "--- contentionz " for legacy C++ profiles (and backwards compatibility) +// "--- mutex:" or "--- contention:" for profiles generated by the Go runtime. 
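Before the contention parser below, a quick numeric check of the heapz scaling above (made-up numbers): for an average sample size S much smaller than the rate R, 1-exp(-S/R) is approximately S/R, so each sampled allocation stands in for roughly R/S similar ones.

package main

import (
    "fmt"
    "math"
)

func main() {
    const rate = 524288.0      // R: average bytes between heap samples (made up)
    count, size := 1.0, 1024.0 // one sampled 1 KiB allocation
    avg := size / count
    scale := 1 / (1 - math.Exp(-avg/rate))
    // scale is roughly 512 here, so this single sampled allocation is
    // estimated to represent about 512 similar ones (~0.5 MiB).
    fmt.Printf("estimated objects: %.1f, estimated bytes: %.1f\n",
        count*scale, size*scale)
}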
+func parseContention(b []byte) (*Profile, error) { + s := bufio.NewScanner(bytes.NewBuffer(b)) + if !s.Scan() { + if err := s.Err(); err != nil { + return nil, err + } + return nil, errUnrecognized + } + + switch l := s.Text(); { + case strings.HasPrefix(l, "--- contentionz "): + case strings.HasPrefix(l, "--- mutex:"): + case strings.HasPrefix(l, "--- contention:"): + default: + return nil, errUnrecognized + } + + p := &Profile{ + PeriodType: &ValueType{Type: "contentions", Unit: "count"}, + Period: 1, + SampleType: []*ValueType{ + {Type: "contentions", Unit: "count"}, + {Type: "delay", Unit: "nanoseconds"}, + }, + } + + var cpuHz int64 + // Parse text of the form "attribute = value" before the samples. + const delimiter = "=" + for s.Scan() { + line := s.Text() + if line = strings.TrimSpace(line); isSpaceOrComment(line) { + continue + } + if strings.HasPrefix(line, "---") { + break + } + attr := strings.SplitN(line, delimiter, 2) + if len(attr) != 2 { + break + } + key, val := strings.TrimSpace(attr[0]), strings.TrimSpace(attr[1]) + var err error + switch key { + case "cycles/second": + if cpuHz, err = strconv.ParseInt(val, 0, 64); err != nil { + return nil, errUnrecognized + } + case "sampling period": + if p.Period, err = strconv.ParseInt(val, 0, 64); err != nil { + return nil, errUnrecognized + } + case "ms since reset": + ms, err := strconv.ParseInt(val, 0, 64) + if err != nil { + return nil, errUnrecognized + } + p.DurationNanos = ms * 1000 * 1000 + case "format": + // CPP contentionz profiles don't have format. + return nil, errUnrecognized + case "resolution": + // CPP contentionz profiles don't have resolution. + return nil, errUnrecognized + case "discarded samples": + default: + return nil, errUnrecognized + } + } + if err := s.Err(); err != nil { + return nil, err + } + + locs := make(map[uint64]*Location) + for { + line := strings.TrimSpace(s.Text()) + if strings.HasPrefix(line, "---") { + break + } + if !isSpaceOrComment(line) { + value, addrs, err := parseContentionSample(line, p.Period, cpuHz) + if err != nil { + return nil, err + } + var sloc []*Location + for _, addr := range addrs { + // Addresses from stack traces point to the next instruction after + // each call. Adjust by -1 to land somewhere on the actual call. + addr-- + loc := locs[addr] + if locs[addr] == nil { + loc = &Location{ + Address: addr, + } + p.Location = append(p.Location, loc) + locs[addr] = loc + } + sloc = append(sloc, loc) + } + p.Sample = append(p.Sample, &Sample{ + Value: value, + Location: sloc, + }) + } + if !s.Scan() { + break + } + } + if err := s.Err(); err != nil { + return nil, err + } + + if err := parseAdditionalSections(s, p); err != nil { + return nil, err + } + + return p, nil +} + +// parseContentionSample parses a single row from a contention profile +// into a new Sample. +func parseContentionSample(line string, period, cpuHz int64) (value []int64, addrs []uint64, err error) { + sampleData := contentionSampleRE.FindStringSubmatch(line) + if sampleData == nil { + return nil, nil, errUnrecognized + } + + v1, err := strconv.ParseInt(sampleData[1], 10, 64) + if err != nil { + return nil, nil, fmt.Errorf("malformed sample: %s: %v", line, err) + } + v2, err := strconv.ParseInt(sampleData[2], 10, 64) + if err != nil { + return nil, nil, fmt.Errorf("malformed sample: %s: %v", line, err) + } + + // Unsample values if period and cpuHz are available. + // - Delays are scaled to cycles and then to nanoseconds. + // - Contentions are scaled to cycles. 
+ if period > 0 { + if cpuHz > 0 { + cpuGHz := float64(cpuHz) / 1e9 + v1 = int64(float64(v1) * float64(period) / cpuGHz) + } + v2 = v2 * period + } + + value = []int64{v2, v1} + addrs, err = parseHexAddresses(sampleData[3]) + if err != nil { + return nil, nil, fmt.Errorf("malformed sample: %s: %v", line, err) + } + + return value, addrs, nil +} + +// parseThread parses a Threadz profile and returns a new Profile. +func parseThread(b []byte) (*Profile, error) { + s := bufio.NewScanner(bytes.NewBuffer(b)) + // Skip past comments and empty lines seeking a real header. + for s.Scan() && isSpaceOrComment(s.Text()) { + } + + line := s.Text() + if m := threadzStartRE.FindStringSubmatch(line); m != nil { + // Advance over initial comments until first stack trace. + for s.Scan() { + if line = s.Text(); isMemoryMapSentinel(line) || strings.HasPrefix(line, "-") { + break + } + } + } else if t := threadStartRE.FindStringSubmatch(line); len(t) != 4 { + return nil, errUnrecognized + } + + p := &Profile{ + SampleType: []*ValueType{{Type: "thread", Unit: "count"}}, + PeriodType: &ValueType{Type: "thread", Unit: "count"}, + Period: 1, + } + + locs := make(map[uint64]*Location) + // Recognize each thread and populate profile samples. + for !isMemoryMapSentinel(line) { + if strings.HasPrefix(line, "---- no stack trace for") { + line = "" + break + } + if t := threadStartRE.FindStringSubmatch(line); len(t) != 4 { + return nil, errUnrecognized + } + + var addrs []uint64 + var err error + line, addrs, err = parseThreadSample(s) + if err != nil { + return nil, err + } + if len(addrs) == 0 { + // We got a --same as previous threads--. Bump counters. + if len(p.Sample) > 0 { + s := p.Sample[len(p.Sample)-1] + s.Value[0]++ + } + continue + } + + var sloc []*Location + for i, addr := range addrs { + // Addresses from stack traces point to the next instruction after + // each call. Adjust by -1 to land somewhere on the actual call + // (except for the leaf, which is not a call). + if i > 0 { + addr-- + } + loc := locs[addr] + if locs[addr] == nil { + loc = &Location{ + Address: addr, + } + p.Location = append(p.Location, loc) + locs[addr] = loc + } + sloc = append(sloc, loc) + } + + p.Sample = append(p.Sample, &Sample{ + Value: []int64{1}, + Location: sloc, + }) + } + + if err := parseAdditionalSections(s, p); err != nil { + return nil, err + } + + cleanupDuplicateLocations(p) + return p, nil +} + +// parseThreadSample parses a symbolized or unsymbolized stack trace. +// Returns the first line after the traceback, the sample (or nil if +// it hits a 'same-as-previous' marker) and an error. +func parseThreadSample(s *bufio.Scanner) (nextl string, addrs []uint64, err error) { + var line string + sameAsPrevious := false + for s.Scan() { + line = strings.TrimSpace(s.Text()) + if line == "" { + continue + } + + if strings.HasPrefix(line, "---") { + break + } + if strings.Contains(line, "same as previous thread") { + sameAsPrevious = true + continue + } + + curAddrs, err := parseHexAddresses(line) + if err != nil { + return "", nil, fmt.Errorf("malformed sample: %s: %v", line, err) + } + addrs = append(addrs, curAddrs...) + } + if err := s.Err(); err != nil { + return "", nil, err + } + if sameAsPrevious { + return line, nil, nil + } + return line, addrs, nil +} + +// parseAdditionalSections parses any additional sections in the +// profile, ignoring any unrecognized sections. 
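The unsampling in parseContentionSample above comes down to two multiplications; a small sketch with made-up inputs (assuming the recorded delay is in cycles, per the cycles/second attribute):

// Mirrors the scaling above with made-up inputs.
func unsampleContention() (delayNs, contentions int64) {
    var (
        v1     int64 = 1000          // recorded delay, in cycles
        v2     int64 = 3             // recorded contention count
        period int64 = 10            // "sampling period" attribute
        cpuHz  int64 = 2_500_000_000 // "cycles/second" attribute
    )
    cpuGHz := float64(cpuHz) / 1e9
    delayNs = int64(float64(v1) * float64(period) / cpuGHz) // 1000*10/2.5 = 4000
    contentions = v2 * period                               // 30
    return delayNs, contentions
}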
+func parseAdditionalSections(s *bufio.Scanner, p *Profile) error { + for !isMemoryMapSentinel(s.Text()) && s.Scan() { + } + if err := s.Err(); err != nil { + return err + } + return p.ParseMemoryMapFromScanner(s) +} + +// ParseProcMaps parses a memory map in the format of /proc/self/maps. +// ParseMemoryMap should be called after setting on a profile to +// associate locations to the corresponding mapping based on their +// address. +func ParseProcMaps(rd io.Reader) ([]*Mapping, error) { + s := bufio.NewScanner(rd) + return parseProcMapsFromScanner(s) +} + +func parseProcMapsFromScanner(s *bufio.Scanner) ([]*Mapping, error) { + var mapping []*Mapping + + var attrs []string + const delimiter = "=" + r := strings.NewReplacer() + for s.Scan() { + line := r.Replace(removeLoggingInfo(s.Text())) + m, err := parseMappingEntry(line) + if err != nil { + if err == errUnrecognized { + // Recognize assignments of the form: attr=value, and replace + // $attr with value on subsequent mappings. + if attr := strings.SplitN(line, delimiter, 2); len(attr) == 2 { + attrs = append(attrs, "$"+strings.TrimSpace(attr[0]), strings.TrimSpace(attr[1])) + r = strings.NewReplacer(attrs...) + } + // Ignore any unrecognized entries + continue + } + return nil, err + } + if m == nil { + continue + } + mapping = append(mapping, m) + } + if err := s.Err(); err != nil { + return nil, err + } + return mapping, nil +} + +// removeLoggingInfo detects and removes log prefix entries generated +// by the glog package. If no logging prefix is detected, the string +// is returned unmodified. +func removeLoggingInfo(line string) string { + if match := logInfoRE.FindStringIndex(line); match != nil { + return line[match[1]:] + } + return line +} + +// ParseMemoryMap parses a memory map in the format of +// /proc/self/maps, and overrides the mappings in the current profile. +// It renumbers the samples and locations in the profile correspondingly. +func (p *Profile) ParseMemoryMap(rd io.Reader) error { + return p.ParseMemoryMapFromScanner(bufio.NewScanner(rd)) +} + +// ParseMemoryMapFromScanner parses a memory map in the format of +// /proc/self/maps or a variety of legacy format, and overrides the +// mappings in the current profile. It renumbers the samples and +// locations in the profile correspondingly. +func (p *Profile) ParseMemoryMapFromScanner(s *bufio.Scanner) error { + mapping, err := parseProcMapsFromScanner(s) + if err != nil { + return err + } + p.Mapping = append(p.Mapping, mapping...) + p.massageMappings() + p.remapLocationIDs() + p.remapFunctionIDs() + p.remapMappingIDs() + return nil +} + +func parseMappingEntry(l string) (*Mapping, error) { + var start, end, perm, file, offset, buildID string + if me := procMapsRE.FindStringSubmatch(l); len(me) == 6 { + start, end, perm, offset, file = me[1], me[2], me[3], me[4], me[5] + } else if me := briefMapsRE.FindStringSubmatch(l); len(me) == 7 { + start, end, perm, file, offset, buildID = me[1], me[2], me[3], me[4], me[5], me[6] + } else { + return nil, errUnrecognized + } + + var err error + mapping := &Mapping{ + File: file, + BuildID: buildID, + } + if perm != "" && !strings.Contains(perm, "x") { + // Skip non-executable entries. 
+ return nil, nil + } + if mapping.Start, err = strconv.ParseUint(start, 16, 64); err != nil { + return nil, errUnrecognized + } + if mapping.Limit, err = strconv.ParseUint(end, 16, 64); err != nil { + return nil, errUnrecognized + } + if offset != "" { + if mapping.Offset, err = strconv.ParseUint(offset, 16, 64); err != nil { + return nil, errUnrecognized + } + } + return mapping, nil +} + +var memoryMapSentinels = []string{ + "--- Memory map: ---", + "MAPPED_LIBRARIES:", +} + +// isMemoryMapSentinel returns true if the string contains one of the +// known sentinels for memory map information. +func isMemoryMapSentinel(line string) bool { + for _, s := range memoryMapSentinels { + if strings.Contains(line, s) { + return true + } + } + return false +} + +func (p *Profile) addLegacyFrameInfo() { + switch { + case isProfileType(p, heapzSampleTypes): + p.DropFrames, p.KeepFrames = allocRxStr, allocSkipRxStr + case isProfileType(p, contentionzSampleTypes): + p.DropFrames, p.KeepFrames = lockRxStr, "" + default: + p.DropFrames, p.KeepFrames = cpuProfilerRxStr, "" + } +} + +var heapzSampleTypes = [][]string{ + {"allocations", "size"}, // early Go pprof profiles + {"objects", "space"}, + {"inuse_objects", "inuse_space"}, + {"alloc_objects", "alloc_space"}, + {"alloc_objects", "alloc_space", "inuse_objects", "inuse_space"}, // Go pprof legacy profiles +} +var contentionzSampleTypes = [][]string{ + {"contentions", "delay"}, +} + +func isProfileType(p *Profile, types [][]string) bool { + st := p.SampleType +nextType: + for _, t := range types { + if len(st) != len(t) { + continue + } + + for i := range st { + if st[i].Type != t[i] { + continue nextType + } + } + return true + } + return false +} + +var allocRxStr = strings.Join([]string{ + // POSIX entry points. + `calloc`, + `cfree`, + `malloc`, + `free`, + `memalign`, + `do_memalign`, + `(__)?posix_memalign`, + `pvalloc`, + `valloc`, + `realloc`, + + // TC malloc. + `tcmalloc::.*`, + `tc_calloc`, + `tc_cfree`, + `tc_malloc`, + `tc_free`, + `tc_memalign`, + `tc_posix_memalign`, + `tc_pvalloc`, + `tc_valloc`, + `tc_realloc`, + `tc_new`, + `tc_delete`, + `tc_newarray`, + `tc_deletearray`, + `tc_new_nothrow`, + `tc_newarray_nothrow`, + + // Memory-allocation routines on OS X. + `malloc_zone_malloc`, + `malloc_zone_calloc`, + `malloc_zone_valloc`, + `malloc_zone_realloc`, + `malloc_zone_memalign`, + `malloc_zone_free`, + + // Go runtime + `runtime\..*`, + + // Other misc. memory allocation routines + `BaseArena::.*`, + `(::)?do_malloc_no_errno`, + `(::)?do_malloc_pages`, + `(::)?do_malloc`, + `DoSampledAllocation`, + `MallocedMemBlock::MallocedMemBlock`, + `_M_allocate`, + `__builtin_(vec_)?delete`, + `__builtin_(vec_)?new`, + `__gnu_cxx::new_allocator::allocate`, + `__libc_malloc`, + `__malloc_alloc_template::allocate`, + `allocate`, + `cpp_alloc`, + `operator new(\[\])?`, + `simple_alloc::allocate`, +}, `|`) + +var allocSkipRxStr = strings.Join([]string{ + // Preserve Go runtime frames that appear in the middle/bottom of + // the stack. 
+ `runtime\.panic`, + `runtime\.reflectcall`, + `runtime\.call[0-9]*`, +}, `|`) + +var cpuProfilerRxStr = strings.Join([]string{ + `ProfileData::Add`, + `ProfileData::prof_handler`, + `CpuProfiler::prof_handler`, + `__pthread_sighandler`, + `__restore`, +}, `|`) + +var lockRxStr = strings.Join([]string{ + `RecordLockProfileData`, + `(base::)?RecordLockProfileData.*`, + `(base::)?SubmitMutexProfileData.*`, + `(base::)?SubmitSpinLockProfileData.*`, + `(base::Mutex::)?AwaitCommon.*`, + `(base::Mutex::)?Unlock.*`, + `(base::Mutex::)?UnlockSlow.*`, + `(base::Mutex::)?ReaderUnlock.*`, + `(base::MutexLock::)?~MutexLock.*`, + `(Mutex::)?AwaitCommon.*`, + `(Mutex::)?Unlock.*`, + `(Mutex::)?UnlockSlow.*`, + `(Mutex::)?ReaderUnlock.*`, + `(MutexLock::)?~MutexLock.*`, + `(SpinLock::)?Unlock.*`, + `(SpinLock::)?SlowUnlock.*`, + `(SpinLockHolder::)?~SpinLockHolder.*`, +}, `|`) diff --git a/vendor/github.com/google/pprof/profile/merge.go b/vendor/github.com/google/pprof/profile/merge.go new file mode 100644 index 00000000000..9978e7330e6 --- /dev/null +++ b/vendor/github.com/google/pprof/profile/merge.go @@ -0,0 +1,481 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package profile + +import ( + "fmt" + "sort" + "strconv" + "strings" +) + +// Compact performs garbage collection on a profile to remove any +// unreferenced fields. This is useful to reduce the size of a profile +// after samples or locations have been removed. +func (p *Profile) Compact() *Profile { + p, _ = Merge([]*Profile{p}) + return p +} + +// Merge merges all the profiles in profs into a single Profile. +// Returns a new profile independent of the input profiles. The merged +// profile is compacted to eliminate unused samples, locations, +// functions and mappings. Profiles must have identical profile sample +// and period types or the merge will fail. profile.Period of the +// resulting profile will be the maximum of all profiles, and +// profile.TimeNanos will be the earliest nonzero one. Merges are +// associative with the caveat of the first profile having some +// specialization in how headers are combined. There may be other +// subtleties now or in the future regarding associativity. 
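A usage sketch for the Merge implementation that follows (assuming the github.com/google/pprof/profile import and two already-parsed profiles): the inputs must share sample and period types, and the result is compacted, which is also how Compact above is implemented.

func mergeExample(p1, p2 *profile.Profile) (*profile.Profile, error) {
    // Both inputs must have identical sample and period types, e.g.
    // two CPU profiles of the same binary; otherwise Merge errors.
    merged, err := profile.Merge([]*profile.Profile{p1, p2})
    if err != nil {
        return nil, err
    }
    // merged is independent of p1 and p2 and already compacted.
    return merged, nil
}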
+func Merge(srcs []*Profile) (*Profile, error) { + if len(srcs) == 0 { + return nil, fmt.Errorf("no profiles to merge") + } + p, err := combineHeaders(srcs) + if err != nil { + return nil, err + } + + pm := &profileMerger{ + p: p, + samples: make(map[sampleKey]*Sample, len(srcs[0].Sample)), + locations: make(map[locationKey]*Location, len(srcs[0].Location)), + functions: make(map[functionKey]*Function, len(srcs[0].Function)), + mappings: make(map[mappingKey]*Mapping, len(srcs[0].Mapping)), + } + + for _, src := range srcs { + // Clear the profile-specific hash tables + pm.locationsByID = make(map[uint64]*Location, len(src.Location)) + pm.functionsByID = make(map[uint64]*Function, len(src.Function)) + pm.mappingsByID = make(map[uint64]mapInfo, len(src.Mapping)) + + if len(pm.mappings) == 0 && len(src.Mapping) > 0 { + // The Mapping list has the property that the first mapping + // represents the main binary. Take the first Mapping we see, + // otherwise the operations below will add mappings in an + // arbitrary order. + pm.mapMapping(src.Mapping[0]) + } + + for _, s := range src.Sample { + if !isZeroSample(s) { + pm.mapSample(s) + } + } + } + + for _, s := range p.Sample { + if isZeroSample(s) { + // If there are any zero samples, re-merge the profile to GC + // them. + return Merge([]*Profile{p}) + } + } + + return p, nil +} + +// Normalize normalizes the source profile by multiplying each value in profile by the +// ratio of the sum of the base profile's values of that sample type to the sum of the +// source profile's value of that sample type. +func (p *Profile) Normalize(pb *Profile) error { + + if err := p.compatible(pb); err != nil { + return err + } + + baseVals := make([]int64, len(p.SampleType)) + for _, s := range pb.Sample { + for i, v := range s.Value { + baseVals[i] += v + } + } + + srcVals := make([]int64, len(p.SampleType)) + for _, s := range p.Sample { + for i, v := range s.Value { + srcVals[i] += v + } + } + + normScale := make([]float64, len(baseVals)) + for i := range baseVals { + if srcVals[i] == 0 { + normScale[i] = 0.0 + } else { + normScale[i] = float64(baseVals[i]) / float64(srcVals[i]) + } + } + p.ScaleN(normScale) + return nil +} + +func isZeroSample(s *Sample) bool { + for _, v := range s.Value { + if v != 0 { + return false + } + } + return true +} + +type profileMerger struct { + p *Profile + + // Memoization tables within a profile. + locationsByID map[uint64]*Location + functionsByID map[uint64]*Function + mappingsByID map[uint64]mapInfo + + // Memoization tables for profile entities. + samples map[sampleKey]*Sample + locations map[locationKey]*Location + functions map[functionKey]*Function + mappings map[mappingKey]*Mapping +} + +type mapInfo struct { + m *Mapping + offset int64 +} + +func (pm *profileMerger) mapSample(src *Sample) *Sample { + s := &Sample{ + Location: make([]*Location, len(src.Location)), + Value: make([]int64, len(src.Value)), + Label: make(map[string][]string, len(src.Label)), + NumLabel: make(map[string][]int64, len(src.NumLabel)), + NumUnit: make(map[string][]string, len(src.NumLabel)), + } + for i, l := range src.Location { + s.Location[i] = pm.mapLocation(l) + } + for k, v := range src.Label { + vv := make([]string, len(v)) + copy(vv, v) + s.Label[k] = vv + } + for k, v := range src.NumLabel { + u := src.NumUnit[k] + vv := make([]int64, len(v)) + uu := make([]string, len(u)) + copy(vv, v) + copy(uu, u) + s.NumLabel[k] = vv + s.NumUnit[k] = uu + } + // Check memoization table. 
Must be done on the remapped location to + // account for the remapped mapping. Add current values to the + // existing sample. + k := s.key() + if ss, ok := pm.samples[k]; ok { + for i, v := range src.Value { + ss.Value[i] += v + } + return ss + } + copy(s.Value, src.Value) + pm.samples[k] = s + pm.p.Sample = append(pm.p.Sample, s) + return s +} + +// key generates sampleKey to be used as a key for maps. +func (sample *Sample) key() sampleKey { + ids := make([]string, len(sample.Location)) + for i, l := range sample.Location { + ids[i] = strconv.FormatUint(l.ID, 16) + } + + labels := make([]string, 0, len(sample.Label)) + for k, v := range sample.Label { + labels = append(labels, fmt.Sprintf("%q%q", k, v)) + } + sort.Strings(labels) + + numlabels := make([]string, 0, len(sample.NumLabel)) + for k, v := range sample.NumLabel { + numlabels = append(numlabels, fmt.Sprintf("%q%x%x", k, v, sample.NumUnit[k])) + } + sort.Strings(numlabels) + + return sampleKey{ + strings.Join(ids, "|"), + strings.Join(labels, ""), + strings.Join(numlabels, ""), + } +} + +type sampleKey struct { + locations string + labels string + numlabels string +} + +func (pm *profileMerger) mapLocation(src *Location) *Location { + if src == nil { + return nil + } + + if l, ok := pm.locationsByID[src.ID]; ok { + return l + } + + mi := pm.mapMapping(src.Mapping) + l := &Location{ + ID: uint64(len(pm.p.Location) + 1), + Mapping: mi.m, + Address: uint64(int64(src.Address) + mi.offset), + Line: make([]Line, len(src.Line)), + IsFolded: src.IsFolded, + } + for i, ln := range src.Line { + l.Line[i] = pm.mapLine(ln) + } + // Check memoization table. Must be done on the remapped location to + // account for the remapped mapping ID. + k := l.key() + if ll, ok := pm.locations[k]; ok { + pm.locationsByID[src.ID] = ll + return ll + } + pm.locationsByID[src.ID] = l + pm.locations[k] = l + pm.p.Location = append(pm.p.Location, l) + return l +} + +// key generates locationKey to be used as a key for maps. +func (l *Location) key() locationKey { + key := locationKey{ + addr: l.Address, + isFolded: l.IsFolded, + } + if l.Mapping != nil { + // Normalizes address to handle address space randomization. + key.addr -= l.Mapping.Start + key.mappingID = l.Mapping.ID + } + lines := make([]string, len(l.Line)*2) + for i, line := range l.Line { + if line.Function != nil { + lines[i*2] = strconv.FormatUint(line.Function.ID, 16) + } + lines[i*2+1] = strconv.FormatInt(line.Line, 16) + } + key.lines = strings.Join(lines, "|") + return key +} + +type locationKey struct { + addr, mappingID uint64 + lines string + isFolded bool +} + +func (pm *profileMerger) mapMapping(src *Mapping) mapInfo { + if src == nil { + return mapInfo{} + } + + if mi, ok := pm.mappingsByID[src.ID]; ok { + return mi + } + + // Check memoization tables. + mk := src.key() + if m, ok := pm.mappings[mk]; ok { + mi := mapInfo{m, int64(m.Start) - int64(src.Start)} + pm.mappingsByID[src.ID] = mi + return mi + } + m := &Mapping{ + ID: uint64(len(pm.p.Mapping) + 1), + Start: src.Start, + Limit: src.Limit, + Offset: src.Offset, + File: src.File, + BuildID: src.BuildID, + HasFunctions: src.HasFunctions, + HasFilenames: src.HasFilenames, + HasLineNumbers: src.HasLineNumbers, + HasInlineFrames: src.HasInlineFrames, + } + pm.p.Mapping = append(pm.p.Mapping, m) + + // Update memoization tables. + pm.mappings[mk] = m + mi := mapInfo{m, 0} + pm.mappingsByID[src.ID] = mi + return mi +} + +// key generates encoded strings of Mapping to be used as a key for +// maps. 
+func (m *Mapping) key() mappingKey { + // Normalize addresses to handle address space randomization. + // Round up to next 4K boundary to avoid minor discrepancies. + const mapsizeRounding = 0x1000 + + size := m.Limit - m.Start + size = size + mapsizeRounding - 1 + size = size - (size % mapsizeRounding) + key := mappingKey{ + size: size, + offset: m.Offset, + } + + switch { + case m.BuildID != "": + key.buildIDOrFile = m.BuildID + case m.File != "": + key.buildIDOrFile = m.File + default: + // A mapping containing neither build ID nor file name is a fake mapping. A + // key with empty buildIDOrFile is used for fake mappings so that they are + // treated as the same mapping during merging. + } + return key +} + +type mappingKey struct { + size, offset uint64 + buildIDOrFile string +} + +func (pm *profileMerger) mapLine(src Line) Line { + ln := Line{ + Function: pm.mapFunction(src.Function), + Line: src.Line, + } + return ln +} + +func (pm *profileMerger) mapFunction(src *Function) *Function { + if src == nil { + return nil + } + if f, ok := pm.functionsByID[src.ID]; ok { + return f + } + k := src.key() + if f, ok := pm.functions[k]; ok { + pm.functionsByID[src.ID] = f + return f + } + f := &Function{ + ID: uint64(len(pm.p.Function) + 1), + Name: src.Name, + SystemName: src.SystemName, + Filename: src.Filename, + StartLine: src.StartLine, + } + pm.functions[k] = f + pm.functionsByID[src.ID] = f + pm.p.Function = append(pm.p.Function, f) + return f +} + +// key generates a struct to be used as a key for maps. +func (f *Function) key() functionKey { + return functionKey{ + f.StartLine, + f.Name, + f.SystemName, + f.Filename, + } +} + +type functionKey struct { + startLine int64 + name, systemName, fileName string +} + +// combineHeaders checks that all profiles can be merged and returns +// their combined profile. +func combineHeaders(srcs []*Profile) (*Profile, error) { + for _, s := range srcs[1:] { + if err := srcs[0].compatible(s); err != nil { + return nil, err + } + } + + var timeNanos, durationNanos, period int64 + var comments []string + seenComments := map[string]bool{} + var defaultSampleType string + for _, s := range srcs { + if timeNanos == 0 || s.TimeNanos < timeNanos { + timeNanos = s.TimeNanos + } + durationNanos += s.DurationNanos + if period == 0 || period < s.Period { + period = s.Period + } + for _, c := range s.Comments { + if seen := seenComments[c]; !seen { + comments = append(comments, c) + seenComments[c] = true + } + } + if defaultSampleType == "" { + defaultSampleType = s.DefaultSampleType + } + } + + p := &Profile{ + SampleType: make([]*ValueType, len(srcs[0].SampleType)), + + DropFrames: srcs[0].DropFrames, + KeepFrames: srcs[0].KeepFrames, + + TimeNanos: timeNanos, + DurationNanos: durationNanos, + PeriodType: srcs[0].PeriodType, + Period: period, + + Comments: comments, + DefaultSampleType: defaultSampleType, + } + copy(p.SampleType, srcs[0].SampleType) + return p, nil +} + +// compatible determines if two profiles can be compared/merged. +// returns nil if the profiles are compatible; otherwise an error with +// details on the incompatibility. 
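A worked example (illustrative extents) of the 4 KiB rounding used by (*Mapping).key above: sizes that differ only within a page round to the same key, so slightly different mapping extents can still merge.

// Sketch of the size rounding in (*Mapping).key, with made-up extents.
func roundedMappingSize(start, limit uint64) uint64 {
    const mapsizeRounding = 0x1000
    size := limit - start
    size = size + mapsizeRounding - 1
    return size - (size % mapsizeRounding)
}

// roundedMappingSize(0x400000, 0x401234) == 0x2000
// roundedMappingSize(0x400000, 0x401fff) == 0x2000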
+func (p *Profile) compatible(pb *Profile) error { + if !equalValueType(p.PeriodType, pb.PeriodType) { + return fmt.Errorf("incompatible period types %v and %v", p.PeriodType, pb.PeriodType) + } + + if len(p.SampleType) != len(pb.SampleType) { + return fmt.Errorf("incompatible sample types %v and %v", p.SampleType, pb.SampleType) + } + + for i := range p.SampleType { + if !equalValueType(p.SampleType[i], pb.SampleType[i]) { + return fmt.Errorf("incompatible sample types %v and %v", p.SampleType, pb.SampleType) + } + } + return nil +} + +// equalValueType returns true if the two value types are semantically +// equal. It ignores the internal fields used during encode/decode. +func equalValueType(st1, st2 *ValueType) bool { + return st1.Type == st2.Type && st1.Unit == st2.Unit +} diff --git a/vendor/github.com/google/pprof/profile/profile.go b/vendor/github.com/google/pprof/profile/profile.go new file mode 100644 index 00000000000..2590c8ddb42 --- /dev/null +++ b/vendor/github.com/google/pprof/profile/profile.go @@ -0,0 +1,805 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package profile provides a representation of profile.proto and +// methods to encode/decode profiles in this format. +package profile + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "math" + "path/filepath" + "regexp" + "sort" + "strings" + "sync" + "time" +) + +// Profile is an in-memory representation of profile.proto. +type Profile struct { + SampleType []*ValueType + DefaultSampleType string + Sample []*Sample + Mapping []*Mapping + Location []*Location + Function []*Function + Comments []string + + DropFrames string + KeepFrames string + + TimeNanos int64 + DurationNanos int64 + PeriodType *ValueType + Period int64 + + // The following fields are modified during encoding and copying, + // so are protected by a Mutex. 
+ encodeMu sync.Mutex + + commentX []int64 + dropFramesX int64 + keepFramesX int64 + stringTable []string + defaultSampleTypeX int64 +} + +// ValueType corresponds to Profile.ValueType +type ValueType struct { + Type string // cpu, wall, inuse_space, etc + Unit string // seconds, nanoseconds, bytes, etc + + typeX int64 + unitX int64 +} + +// Sample corresponds to Profile.Sample +type Sample struct { + Location []*Location + Value []int64 + Label map[string][]string + NumLabel map[string][]int64 + NumUnit map[string][]string + + locationIDX []uint64 + labelX []label +} + +// label corresponds to Profile.Label +type label struct { + keyX int64 + // Exactly one of the two following values must be set + strX int64 + numX int64 // Integer value for this label + // can be set if numX has value + unitX int64 +} + +// Mapping corresponds to Profile.Mapping +type Mapping struct { + ID uint64 + Start uint64 + Limit uint64 + Offset uint64 + File string + BuildID string + HasFunctions bool + HasFilenames bool + HasLineNumbers bool + HasInlineFrames bool + + fileX int64 + buildIDX int64 +} + +// Location corresponds to Profile.Location +type Location struct { + ID uint64 + Mapping *Mapping + Address uint64 + Line []Line + IsFolded bool + + mappingIDX uint64 +} + +// Line corresponds to Profile.Line +type Line struct { + Function *Function + Line int64 + + functionIDX uint64 +} + +// Function corresponds to Profile.Function +type Function struct { + ID uint64 + Name string + SystemName string + Filename string + StartLine int64 + + nameX int64 + systemNameX int64 + filenameX int64 +} + +// Parse parses a profile and checks for its validity. The input +// may be a gzip-compressed encoded protobuf or one of many legacy +// profile formats which may be unsupported in the future. +func Parse(r io.Reader) (*Profile, error) { + data, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + return ParseData(data) +} + +// ParseData parses a profile from a buffer and checks for its +// validity. +func ParseData(data []byte) (*Profile, error) { + var p *Profile + var err error + if len(data) >= 2 && data[0] == 0x1f && data[1] == 0x8b { + gz, err := gzip.NewReader(bytes.NewBuffer(data)) + if err == nil { + data, err = ioutil.ReadAll(gz) + } + if err != nil { + return nil, fmt.Errorf("decompressing profile: %v", err) + } + } + if p, err = ParseUncompressed(data); err != nil && err != errNoData && err != errConcatProfile { + p, err = parseLegacy(data) + } + + if err != nil { + return nil, fmt.Errorf("parsing profile: %v", err) + } + + if err := p.CheckValid(); err != nil { + return nil, fmt.Errorf("malformed profile: %v", err) + } + return p, nil +} + +var errUnrecognized = fmt.Errorf("unrecognized profile format") +var errMalformed = fmt.Errorf("malformed profile format") +var errNoData = fmt.Errorf("empty input file") +var errConcatProfile = fmt.Errorf("concatenated profiles detected") + +func parseLegacy(data []byte) (*Profile, error) { + parsers := []func([]byte) (*Profile, error){ + parseCPU, + parseHeap, + parseGoCount, // goroutine, threadcreate + parseThread, + parseContention, + parseJavaProfile, + } + + for _, parser := range parsers { + p, err := parser(data) + if err == nil { + p.addLegacyFrameInfo() + return p, nil + } + if err != errUnrecognized { + return nil, err + } + } + return nil, errUnrecognized +} + +// ParseUncompressed parses an uncompressed protobuf into a profile. 
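A round-trip sketch (assuming the bytes and github.com/google/pprof/profile imports): because ParseData above checks for the 0x1f 0x8b gzip magic, the gzip-compressed output of Write parses back without any special handling.

func roundTrip(p *profile.Profile) (*profile.Profile, error) {
    var buf bytes.Buffer
    if err := p.Write(&buf); err != nil { // gzip-compressed profile.proto
        return nil, err
    }
    return profile.ParseData(buf.Bytes()) // gzip is detected and undone
}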
+func ParseUncompressed(data []byte) (*Profile, error) { + if len(data) == 0 { + return nil, errNoData + } + p := &Profile{} + if err := unmarshal(data, p); err != nil { + return nil, err + } + + if err := p.postDecode(); err != nil { + return nil, err + } + + return p, nil +} + +var libRx = regexp.MustCompile(`([.]so$|[.]so[._][0-9]+)`) + +// massageMappings applies heuristic-based changes to the profile +// mappings to account for quirks of some environments. +func (p *Profile) massageMappings() { + // Merge adjacent regions with matching names, checking that the offsets match + if len(p.Mapping) > 1 { + mappings := []*Mapping{p.Mapping[0]} + for _, m := range p.Mapping[1:] { + lm := mappings[len(mappings)-1] + if adjacent(lm, m) { + lm.Limit = m.Limit + if m.File != "" { + lm.File = m.File + } + if m.BuildID != "" { + lm.BuildID = m.BuildID + } + p.updateLocationMapping(m, lm) + continue + } + mappings = append(mappings, m) + } + p.Mapping = mappings + } + + // Use heuristics to identify main binary and move it to the top of the list of mappings + for i, m := range p.Mapping { + file := strings.TrimSpace(strings.Replace(m.File, "(deleted)", "", -1)) + if len(file) == 0 { + continue + } + if len(libRx.FindStringSubmatch(file)) > 0 { + continue + } + if file[0] == '[' { + continue + } + // Swap what we guess is main to position 0. + p.Mapping[0], p.Mapping[i] = p.Mapping[i], p.Mapping[0] + break + } + + // Keep the mapping IDs neatly sorted + for i, m := range p.Mapping { + m.ID = uint64(i + 1) + } +} + +// adjacent returns whether two mapping entries represent the same +// mapping that has been split into two. Check that their addresses are adjacent, +// and if the offsets match, if they are available. +func adjacent(m1, m2 *Mapping) bool { + if m1.File != "" && m2.File != "" { + if m1.File != m2.File { + return false + } + } + if m1.BuildID != "" && m2.BuildID != "" { + if m1.BuildID != m2.BuildID { + return false + } + } + if m1.Limit != m2.Start { + return false + } + if m1.Offset != 0 && m2.Offset != 0 { + offset := m1.Offset + (m1.Limit - m1.Start) + if offset != m2.Offset { + return false + } + } + return true +} + +func (p *Profile) updateLocationMapping(from, to *Mapping) { + for _, l := range p.Location { + if l.Mapping == from { + l.Mapping = to + } + } +} + +func serialize(p *Profile) []byte { + p.encodeMu.Lock() + p.preEncode() + b := marshal(p) + p.encodeMu.Unlock() + return b +} + +// Write writes the profile as a gzip-compressed marshaled protobuf. +func (p *Profile) Write(w io.Writer) error { + zw := gzip.NewWriter(w) + defer zw.Close() + _, err := zw.Write(serialize(p)) + return err +} + +// WriteUncompressed writes the profile as a marshaled protobuf. +func (p *Profile) WriteUncompressed(w io.Writer) error { + _, err := w.Write(serialize(p)) + return err +} + +// CheckValid tests whether the profile is valid. Checks include, but are +// not limited to: +// - len(Profile.Sample[n].value) == len(Profile.value_unit) +// - Sample.id has a corresponding Profile.Location +func (p *Profile) CheckValid() error { + // Check that sample values are consistent + sampleLen := len(p.SampleType) + if sampleLen == 0 && len(p.Sample) != 0 { + return fmt.Errorf("missing sample type information") + } + for _, s := range p.Sample { + if s == nil { + return fmt.Errorf("profile has nil sample") + } + if len(s.Value) != sampleLen { + return fmt.Errorf("mismatch: sample has %d values vs. 
%d types", len(s.Value), len(p.SampleType)) + } + for _, l := range s.Location { + if l == nil { + return fmt.Errorf("sample has nil location") + } + } + } + + // Check that all mappings/locations/functions are in the tables + // Check that there are no duplicate ids + mappings := make(map[uint64]*Mapping, len(p.Mapping)) + for _, m := range p.Mapping { + if m == nil { + return fmt.Errorf("profile has nil mapping") + } + if m.ID == 0 { + return fmt.Errorf("found mapping with reserved ID=0") + } + if mappings[m.ID] != nil { + return fmt.Errorf("multiple mappings with same id: %d", m.ID) + } + mappings[m.ID] = m + } + functions := make(map[uint64]*Function, len(p.Function)) + for _, f := range p.Function { + if f == nil { + return fmt.Errorf("profile has nil function") + } + if f.ID == 0 { + return fmt.Errorf("found function with reserved ID=0") + } + if functions[f.ID] != nil { + return fmt.Errorf("multiple functions with same id: %d", f.ID) + } + functions[f.ID] = f + } + locations := make(map[uint64]*Location, len(p.Location)) + for _, l := range p.Location { + if l == nil { + return fmt.Errorf("profile has nil location") + } + if l.ID == 0 { + return fmt.Errorf("found location with reserved id=0") + } + if locations[l.ID] != nil { + return fmt.Errorf("multiple locations with same id: %d", l.ID) + } + locations[l.ID] = l + if m := l.Mapping; m != nil { + if m.ID == 0 || mappings[m.ID] != m { + return fmt.Errorf("inconsistent mapping %p: %d", m, m.ID) + } + } + for _, ln := range l.Line { + f := ln.Function + if f == nil { + return fmt.Errorf("location id: %d has a line with nil function", l.ID) + } + if f.ID == 0 || functions[f.ID] != f { + return fmt.Errorf("inconsistent function %p: %d", f, f.ID) + } + } + } + return nil +} + +// Aggregate merges the locations in the profile into equivalence +// classes preserving the request attributes. It also updates the +// samples to point to the merged locations. +func (p *Profile) Aggregate(inlineFrame, function, filename, linenumber, address bool) error { + for _, m := range p.Mapping { + m.HasInlineFrames = m.HasInlineFrames && inlineFrame + m.HasFunctions = m.HasFunctions && function + m.HasFilenames = m.HasFilenames && filename + m.HasLineNumbers = m.HasLineNumbers && linenumber + } + + // Aggregate functions + if !function || !filename { + for _, f := range p.Function { + if !function { + f.Name = "" + f.SystemName = "" + } + if !filename { + f.Filename = "" + } + } + } + + // Aggregate locations + if !inlineFrame || !address || !linenumber { + for _, l := range p.Location { + if !inlineFrame && len(l.Line) > 1 { + l.Line = l.Line[len(l.Line)-1:] + } + if !linenumber { + for i := range l.Line { + l.Line[i].Line = 0 + } + } + if !address { + l.Address = 0 + } + } + } + + return p.CheckValid() +} + +// NumLabelUnits returns a map of numeric label keys to the units +// associated with those keys and a map of those keys to any units +// that were encountered but not used. +// Unit for a given key is the first encountered unit for that key. If multiple +// units are encountered for values paired with a particular key, then the first +// unit encountered is used and all other units are returned in sorted order +// in map of ignored units. +// If no units are encountered for a particular key, the unit is then inferred +// based on the key. 
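A small sketch of the unit inference described above (assuming the github.com/google/pprof/profile import): a numeric label with no explicit unit gets a unit inferred from its key.

func inferUnitsExample() map[string]string {
    p := &profile.Profile{
        Sample: []*profile.Sample{{
            // Numeric label with no corresponding NumUnit entry.
            NumLabel: map[string][]int64{"request": {1024}},
        }},
    }
    units, _ := p.NumLabelUnits()
    // units["request"] == "bytes", inferred from the key name; an
    // unrecognized unitless key would map to the key itself.
    return units
}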
+func (p *Profile) NumLabelUnits() (map[string]string, map[string][]string) { + numLabelUnits := map[string]string{} + ignoredUnits := map[string]map[string]bool{} + encounteredKeys := map[string]bool{} + + // Determine units based on numeric tags for each sample. + for _, s := range p.Sample { + for k := range s.NumLabel { + encounteredKeys[k] = true + for _, unit := range s.NumUnit[k] { + if unit == "" { + continue + } + if wantUnit, ok := numLabelUnits[k]; !ok { + numLabelUnits[k] = unit + } else if wantUnit != unit { + if v, ok := ignoredUnits[k]; ok { + v[unit] = true + } else { + ignoredUnits[k] = map[string]bool{unit: true} + } + } + } + } + } + // Infer units for keys without any units associated with + // numeric tag values. + for key := range encounteredKeys { + unit := numLabelUnits[key] + if unit == "" { + switch key { + case "alignment", "request": + numLabelUnits[key] = "bytes" + default: + numLabelUnits[key] = key + } + } + } + + // Copy ignored units into more readable format + unitsIgnored := make(map[string][]string, len(ignoredUnits)) + for key, values := range ignoredUnits { + units := make([]string, len(values)) + i := 0 + for unit := range values { + units[i] = unit + i++ + } + sort.Strings(units) + unitsIgnored[key] = units + } + + return numLabelUnits, unitsIgnored +} + +// String dumps a text representation of a profile. Intended mainly +// for debugging purposes. +func (p *Profile) String() string { + ss := make([]string, 0, len(p.Comments)+len(p.Sample)+len(p.Mapping)+len(p.Location)) + for _, c := range p.Comments { + ss = append(ss, "Comment: "+c) + } + if pt := p.PeriodType; pt != nil { + ss = append(ss, fmt.Sprintf("PeriodType: %s %s", pt.Type, pt.Unit)) + } + ss = append(ss, fmt.Sprintf("Period: %d", p.Period)) + if p.TimeNanos != 0 { + ss = append(ss, fmt.Sprintf("Time: %v", time.Unix(0, p.TimeNanos))) + } + if p.DurationNanos != 0 { + ss = append(ss, fmt.Sprintf("Duration: %.4v", time.Duration(p.DurationNanos))) + } + + ss = append(ss, "Samples:") + var sh1 string + for _, s := range p.SampleType { + dflt := "" + if s.Type == p.DefaultSampleType { + dflt = "[dflt]" + } + sh1 = sh1 + fmt.Sprintf("%s/%s%s ", s.Type, s.Unit, dflt) + } + ss = append(ss, strings.TrimSpace(sh1)) + for _, s := range p.Sample { + ss = append(ss, s.string()) + } + + ss = append(ss, "Locations") + for _, l := range p.Location { + ss = append(ss, l.string()) + } + + ss = append(ss, "Mappings") + for _, m := range p.Mapping { + ss = append(ss, m.string()) + } + + return strings.Join(ss, "\n") + "\n" +} + +// string dumps a text representation of a mapping. Intended mainly +// for debugging purposes. +func (m *Mapping) string() string { + bits := "" + if m.HasFunctions { + bits = bits + "[FN]" + } + if m.HasFilenames { + bits = bits + "[FL]" + } + if m.HasLineNumbers { + bits = bits + "[LN]" + } + if m.HasInlineFrames { + bits = bits + "[IN]" + } + return fmt.Sprintf("%d: %#x/%#x/%#x %s %s %s", + m.ID, + m.Start, m.Limit, m.Offset, + m.File, + m.BuildID, + bits) +} + +// string dumps a text representation of a location. Intended mainly +// for debugging purposes. +func (l *Location) string() string { + ss := []string{} + locStr := fmt.Sprintf("%6d: %#x ", l.ID, l.Address) + if m := l.Mapping; m != nil { + locStr = locStr + fmt.Sprintf("M=%d ", m.ID) + } + if l.IsFolded { + locStr = locStr + "[F] " + } + if len(l.Line) == 0 { + ss = append(ss, locStr) + } + for li := range l.Line { + lnStr := "??" 
+ if fn := l.Line[li].Function; fn != nil { + lnStr = fmt.Sprintf("%s %s:%d s=%d", + fn.Name, + fn.Filename, + l.Line[li].Line, + fn.StartLine) + if fn.Name != fn.SystemName { + lnStr = lnStr + "(" + fn.SystemName + ")" + } + } + ss = append(ss, locStr+lnStr) + // Do not print location details past the first line + locStr = " " + } + return strings.Join(ss, "\n") +} + +// string dumps a text representation of a sample. Intended mainly +// for debugging purposes. +func (s *Sample) string() string { + ss := []string{} + var sv string + for _, v := range s.Value { + sv = fmt.Sprintf("%s %10d", sv, v) + } + sv = sv + ": " + for _, l := range s.Location { + sv = sv + fmt.Sprintf("%d ", l.ID) + } + ss = append(ss, sv) + const labelHeader = " " + if len(s.Label) > 0 { + ss = append(ss, labelHeader+labelsToString(s.Label)) + } + if len(s.NumLabel) > 0 { + ss = append(ss, labelHeader+numLabelsToString(s.NumLabel, s.NumUnit)) + } + return strings.Join(ss, "\n") +} + +// labelsToString returns a string representation of a +// map representing labels. +func labelsToString(labels map[string][]string) string { + ls := []string{} + for k, v := range labels { + ls = append(ls, fmt.Sprintf("%s:%v", k, v)) + } + sort.Strings(ls) + return strings.Join(ls, " ") +} + +// numLabelsToString returns a string representation of a map +// representing numeric labels. +func numLabelsToString(numLabels map[string][]int64, numUnits map[string][]string) string { + ls := []string{} + for k, v := range numLabels { + units := numUnits[k] + var labelString string + if len(units) == len(v) { + values := make([]string, len(v)) + for i, vv := range v { + values[i] = fmt.Sprintf("%d %s", vv, units[i]) + } + labelString = fmt.Sprintf("%s:%v", k, values) + } else { + labelString = fmt.Sprintf("%s:%v", k, v) + } + ls = append(ls, labelString) + } + sort.Strings(ls) + return strings.Join(ls, " ") +} + +// SetLabel sets the specified key to the specified value for all samples in the +// profile. +func (p *Profile) SetLabel(key string, value []string) { + for _, sample := range p.Sample { + if sample.Label == nil { + sample.Label = map[string][]string{key: value} + } else { + sample.Label[key] = value + } + } +} + +// RemoveLabel removes all labels associated with the specified key for all +// samples in the profile. +func (p *Profile) RemoveLabel(key string) { + for _, sample := range p.Sample { + delete(sample.Label, key) + } +} + +// HasLabel returns true if a sample has a label with indicated key and value. +func (s *Sample) HasLabel(key, value string) bool { + for _, v := range s.Label[key] { + if v == value { + return true + } + } + return false +} + +// DiffBaseSample returns true if a sample belongs to the diff base and false +// otherwise. +func (s *Sample) DiffBaseSample() bool { + return s.HasLabel("pprof::base", "true") +} + +// Scale multiplies all sample values in a profile by a constant and keeps +// only samples that have at least one non-zero value. +func (p *Profile) Scale(ratio float64) { + if ratio == 1 { + return + } + ratios := make([]float64, len(p.SampleType)) + for i := range p.SampleType { + ratios[i] = ratio + } + p.ScaleN(ratios) +} + +// ScaleN multiplies each sample values in a sample by a different amount +// and keeps only samples that have at least one non-zero value. 
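A brief usage sketch of Scale/ScaleN as added here, with hypothetical sample values; note how a sample whose values all round to zero is dropped, as the comment above describes (ScaleN itself follows below).

```go
package main

import (
	"fmt"

	"github.com/google/pprof/profile"
)

func main() {
	p := &profile.Profile{
		SampleType: []*profile.ValueType{{Type: "samples", Unit: "count"}},
		Sample: []*profile.Sample{
			{Value: []int64{10}}, // becomes 1 after scaling
			{Value: []int64{1}},  // rounds to 0 and is dropped
		},
	}
	p.Scale(0.1)
	fmt.Println(len(p.Sample), p.Sample[0].Value) // 1 [1]
}
```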
+func (p *Profile) ScaleN(ratios []float64) error { + if len(p.SampleType) != len(ratios) { + return fmt.Errorf("mismatched scale ratios, got %d, want %d", len(ratios), len(p.SampleType)) + } + allOnes := true + for _, r := range ratios { + if r != 1 { + allOnes = false + break + } + } + if allOnes { + return nil + } + fillIdx := 0 + for _, s := range p.Sample { + keepSample := false + for i, v := range s.Value { + if ratios[i] != 1 { + val := int64(math.Round(float64(v) * ratios[i])) + s.Value[i] = val + keepSample = keepSample || val != 0 + } + } + if keepSample { + p.Sample[fillIdx] = s + fillIdx++ + } + } + p.Sample = p.Sample[:fillIdx] + return nil +} + +// HasFunctions determines if all locations in this profile have +// symbolized function information. +func (p *Profile) HasFunctions() bool { + for _, l := range p.Location { + if l.Mapping != nil && !l.Mapping.HasFunctions { + return false + } + } + return true +} + +// HasFileLines determines if all locations in this profile have +// symbolized file and line number information. +func (p *Profile) HasFileLines() bool { + for _, l := range p.Location { + if l.Mapping != nil && (!l.Mapping.HasFilenames || !l.Mapping.HasLineNumbers) { + return false + } + } + return true +} + +// Unsymbolizable returns true if a mapping points to a binary for which +// locations can't be symbolized in principle, at least now. Examples are +// "[vdso]", [vsyscall]" and some others, see the code. +func (m *Mapping) Unsymbolizable() bool { + name := filepath.Base(m.File) + return strings.HasPrefix(name, "[") || strings.HasPrefix(name, "linux-vdso") || strings.HasPrefix(m.File, "/dev/dri/") +} + +// Copy makes a fully independent copy of a profile. +func (p *Profile) Copy() *Profile { + pp := &Profile{} + if err := unmarshal(serialize(p), pp); err != nil { + panic(err) + } + if err := pp.postDecode(); err != nil { + panic(err) + } + + return pp +} diff --git a/vendor/github.com/google/pprof/profile/proto.go b/vendor/github.com/google/pprof/profile/proto.go new file mode 100644 index 00000000000..539ad3ab33f --- /dev/null +++ b/vendor/github.com/google/pprof/profile/proto.go @@ -0,0 +1,370 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is a simple protocol buffer encoder and decoder. +// The format is described at +// https://developers.google.com/protocol-buffers/docs/encoding +// +// A protocol message must implement the message interface: +// decoder() []decoder +// encode(*buffer) +// +// The decode method returns a slice indexed by field number that gives the +// function to decode that field. +// The encode method encodes its receiver into the given buffer. +// +// The two methods are simple enough to be implemented by hand rather than +// by using a protocol compiler. +// +// See profile.go for examples of messages implementing this interface. +// +// There is no support for groups, message sets, or "has" bits. 
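To make the interface described above concrete, here is an illustrative sketch written as if it lived inside this package (the `buffer`, `decoder`, and `message` types are unexported); the `counter` type is hypothetical and not part of the vendored file.

```go
// Hypothetical example, not part of this file: a message with a single
// uint64 field encoded as field number 1.
type counter struct{ n uint64 }

// decoder returns a slice indexed by field number; index 0 is unused.
func (c *counter) decoder() []decoder {
	return []decoder{
		nil,
		func(b *buffer, m message) error { return decodeUint64(b, &m.(*counter).n) },
	}
}

// encode writes field 1 as a varint, omitting it when zero.
func (c *counter) encode(b *buffer) {
	encodeUint64Opt(b, 1, c.n)
}

// Round trip: data := marshal(&counter{n: 42}); var c counter; _ = unmarshal(data, &c)
```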
+ +package profile + +import ( + "errors" + "fmt" +) + +type buffer struct { + field int // field tag + typ int // proto wire type code for field + u64 uint64 + data []byte + tmp [16]byte +} + +type decoder func(*buffer, message) error + +type message interface { + decoder() []decoder + encode(*buffer) +} + +func marshal(m message) []byte { + var b buffer + m.encode(&b) + return b.data +} + +func encodeVarint(b *buffer, x uint64) { + for x >= 128 { + b.data = append(b.data, byte(x)|0x80) + x >>= 7 + } + b.data = append(b.data, byte(x)) +} + +func encodeLength(b *buffer, tag int, len int) { + encodeVarint(b, uint64(tag)<<3|2) + encodeVarint(b, uint64(len)) +} + +func encodeUint64(b *buffer, tag int, x uint64) { + // append varint to b.data + encodeVarint(b, uint64(tag)<<3) + encodeVarint(b, x) +} + +func encodeUint64s(b *buffer, tag int, x []uint64) { + if len(x) > 2 { + // Use packed encoding + n1 := len(b.data) + for _, u := range x { + encodeVarint(b, u) + } + n2 := len(b.data) + encodeLength(b, tag, n2-n1) + n3 := len(b.data) + copy(b.tmp[:], b.data[n2:n3]) + copy(b.data[n1+(n3-n2):], b.data[n1:n2]) + copy(b.data[n1:], b.tmp[:n3-n2]) + return + } + for _, u := range x { + encodeUint64(b, tag, u) + } +} + +func encodeUint64Opt(b *buffer, tag int, x uint64) { + if x == 0 { + return + } + encodeUint64(b, tag, x) +} + +func encodeInt64(b *buffer, tag int, x int64) { + u := uint64(x) + encodeUint64(b, tag, u) +} + +func encodeInt64s(b *buffer, tag int, x []int64) { + if len(x) > 2 { + // Use packed encoding + n1 := len(b.data) + for _, u := range x { + encodeVarint(b, uint64(u)) + } + n2 := len(b.data) + encodeLength(b, tag, n2-n1) + n3 := len(b.data) + copy(b.tmp[:], b.data[n2:n3]) + copy(b.data[n1+(n3-n2):], b.data[n1:n2]) + copy(b.data[n1:], b.tmp[:n3-n2]) + return + } + for _, u := range x { + encodeInt64(b, tag, u) + } +} + +func encodeInt64Opt(b *buffer, tag int, x int64) { + if x == 0 { + return + } + encodeInt64(b, tag, x) +} + +func encodeString(b *buffer, tag int, x string) { + encodeLength(b, tag, len(x)) + b.data = append(b.data, x...) 
+} + +func encodeStrings(b *buffer, tag int, x []string) { + for _, s := range x { + encodeString(b, tag, s) + } +} + +func encodeBool(b *buffer, tag int, x bool) { + if x { + encodeUint64(b, tag, 1) + } else { + encodeUint64(b, tag, 0) + } +} + +func encodeBoolOpt(b *buffer, tag int, x bool) { + if x { + encodeBool(b, tag, x) + } +} + +func encodeMessage(b *buffer, tag int, m message) { + n1 := len(b.data) + m.encode(b) + n2 := len(b.data) + encodeLength(b, tag, n2-n1) + n3 := len(b.data) + copy(b.tmp[:], b.data[n2:n3]) + copy(b.data[n1+(n3-n2):], b.data[n1:n2]) + copy(b.data[n1:], b.tmp[:n3-n2]) +} + +func unmarshal(data []byte, m message) (err error) { + b := buffer{data: data, typ: 2} + return decodeMessage(&b, m) +} + +func le64(p []byte) uint64 { + return uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 | uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56 +} + +func le32(p []byte) uint32 { + return uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24 +} + +func decodeVarint(data []byte) (uint64, []byte, error) { + var u uint64 + for i := 0; ; i++ { + if i >= 10 || i >= len(data) { + return 0, nil, errors.New("bad varint") + } + u |= uint64(data[i]&0x7F) << uint(7*i) + if data[i]&0x80 == 0 { + return u, data[i+1:], nil + } + } +} + +func decodeField(b *buffer, data []byte) ([]byte, error) { + x, data, err := decodeVarint(data) + if err != nil { + return nil, err + } + b.field = int(x >> 3) + b.typ = int(x & 7) + b.data = nil + b.u64 = 0 + switch b.typ { + case 0: + b.u64, data, err = decodeVarint(data) + if err != nil { + return nil, err + } + case 1: + if len(data) < 8 { + return nil, errors.New("not enough data") + } + b.u64 = le64(data[:8]) + data = data[8:] + case 2: + var n uint64 + n, data, err = decodeVarint(data) + if err != nil { + return nil, err + } + if n > uint64(len(data)) { + return nil, errors.New("too much data") + } + b.data = data[:n] + data = data[n:] + case 5: + if len(data) < 4 { + return nil, errors.New("not enough data") + } + b.u64 = uint64(le32(data[:4])) + data = data[4:] + default: + return nil, fmt.Errorf("unknown wire type: %d", b.typ) + } + + return data, nil +} + +func checkType(b *buffer, typ int) error { + if b.typ != typ { + return errors.New("type mismatch") + } + return nil +} + +func decodeMessage(b *buffer, m message) error { + if err := checkType(b, 2); err != nil { + return err + } + dec := m.decoder() + data := b.data + for len(data) > 0 { + // pull varint field# + type + var err error + data, err = decodeField(b, data) + if err != nil { + return err + } + if b.field >= len(dec) || dec[b.field] == nil { + continue + } + if err := dec[b.field](b, m); err != nil { + return err + } + } + return nil +} + +func decodeInt64(b *buffer, x *int64) error { + if err := checkType(b, 0); err != nil { + return err + } + *x = int64(b.u64) + return nil +} + +func decodeInt64s(b *buffer, x *[]int64) error { + if b.typ == 2 { + // Packed encoding + data := b.data + tmp := make([]int64, 0, len(data)) // Maximally sized + for len(data) > 0 { + var u uint64 + var err error + + if u, data, err = decodeVarint(data); err != nil { + return err + } + tmp = append(tmp, int64(u)) + } + *x = append(*x, tmp...) 
+ return nil + } + var i int64 + if err := decodeInt64(b, &i); err != nil { + return err + } + *x = append(*x, i) + return nil +} + +func decodeUint64(b *buffer, x *uint64) error { + if err := checkType(b, 0); err != nil { + return err + } + *x = b.u64 + return nil +} + +func decodeUint64s(b *buffer, x *[]uint64) error { + if b.typ == 2 { + data := b.data + // Packed encoding + tmp := make([]uint64, 0, len(data)) // Maximally sized + for len(data) > 0 { + var u uint64 + var err error + + if u, data, err = decodeVarint(data); err != nil { + return err + } + tmp = append(tmp, u) + } + *x = append(*x, tmp...) + return nil + } + var u uint64 + if err := decodeUint64(b, &u); err != nil { + return err + } + *x = append(*x, u) + return nil +} + +func decodeString(b *buffer, x *string) error { + if err := checkType(b, 2); err != nil { + return err + } + *x = string(b.data) + return nil +} + +func decodeStrings(b *buffer, x *[]string) error { + var s string + if err := decodeString(b, &s); err != nil { + return err + } + *x = append(*x, s) + return nil +} + +func decodeBool(b *buffer, x *bool) error { + if err := checkType(b, 0); err != nil { + return err + } + if int64(b.u64) == 0 { + *x = false + } else { + *x = true + } + return nil +} diff --git a/vendor/github.com/google/pprof/profile/prune.go b/vendor/github.com/google/pprof/profile/prune.go new file mode 100644 index 00000000000..02d21a81846 --- /dev/null +++ b/vendor/github.com/google/pprof/profile/prune.go @@ -0,0 +1,178 @@ +// Copyright 2014 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Implements methods to remove frames from profiles. + +package profile + +import ( + "fmt" + "regexp" + "strings" +) + +var ( + reservedNames = []string{"(anonymous namespace)", "operator()"} + bracketRx = func() *regexp.Regexp { + var quotedNames []string + for _, name := range append(reservedNames, "(") { + quotedNames = append(quotedNames, regexp.QuoteMeta(name)) + } + return regexp.MustCompile(strings.Join(quotedNames, "|")) + }() +) + +// simplifyFunc does some primitive simplification of function names. +func simplifyFunc(f string) string { + // Account for leading '.' on the PPC ELF v1 ABI. + funcName := strings.TrimPrefix(f, ".") + // Account for unsimplified names -- try to remove the argument list by trimming + // starting from the first '(', but skipping reserved names that have '('. + for _, ind := range bracketRx.FindAllStringSubmatchIndex(funcName, -1) { + foundReserved := false + for _, res := range reservedNames { + if funcName[ind[0]:ind[1]] == res { + foundReserved = true + break + } + } + if !foundReserved { + funcName = funcName[:ind[0]] + break + } + } + return funcName +} + +// Prune removes all nodes beneath a node matching dropRx, and not +// matching keepRx. If the root node of a Sample matches, the sample +// will have an empty stack. 
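A small end-to-end sketch of the pruning described above, using the public Profile/Location/Function/Line types from this package; the function names and the `runtime\..*` pattern are hypothetical.

```go
package main

import (
	"fmt"
	"regexp"

	"github.com/google/pprof/profile"
)

func main() {
	fnMain := &profile.Function{ID: 1, Name: "main.main"}
	fnGC := &profile.Function{ID: 2, Name: "runtime.gcBgMarkWorker"}
	locMain := &profile.Location{ID: 1, Line: []profile.Line{{Function: fnMain}}}
	locGC := &profile.Location{ID: 2, Line: []profile.Line{{Function: fnGC}}}

	p := &profile.Profile{
		SampleType: []*profile.ValueType{{Type: "samples", Unit: "count"}},
		Function:   []*profile.Function{fnMain, fnGC},
		Location:   []*profile.Location{locMain, locGC},
		Sample: []*profile.Sample{
			// Location[0] is the leaf frame: gcBgMarkWorker called from main.main.
			{Location: []*profile.Location{locGC, locMain}, Value: []int64{1}},
		},
	}

	p.Prune(regexp.MustCompile(`^runtime\..*$`), nil)
	// The runtime frame matched dropRx, so it (and anything beneath it) is gone.
	fmt.Println(p.Sample[0].Location[0].Line[0].Function.Name) // main.main
}
```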
+func (p *Profile) Prune(dropRx, keepRx *regexp.Regexp) { + prune := make(map[uint64]bool) + pruneBeneath := make(map[uint64]bool) + + for _, loc := range p.Location { + var i int + for i = len(loc.Line) - 1; i >= 0; i-- { + if fn := loc.Line[i].Function; fn != nil && fn.Name != "" { + funcName := simplifyFunc(fn.Name) + if dropRx.MatchString(funcName) { + if keepRx == nil || !keepRx.MatchString(funcName) { + break + } + } + } + } + + if i >= 0 { + // Found matching entry to prune. + pruneBeneath[loc.ID] = true + + // Remove the matching location. + if i == len(loc.Line)-1 { + // Matched the top entry: prune the whole location. + prune[loc.ID] = true + } else { + loc.Line = loc.Line[i+1:] + } + } + } + + // Prune locs from each Sample + for _, sample := range p.Sample { + // Scan from the root to the leaves to find the prune location. + // Do not prune frames before the first user frame, to avoid + // pruning everything. + foundUser := false + for i := len(sample.Location) - 1; i >= 0; i-- { + id := sample.Location[i].ID + if !prune[id] && !pruneBeneath[id] { + foundUser = true + continue + } + if !foundUser { + continue + } + if prune[id] { + sample.Location = sample.Location[i+1:] + break + } + if pruneBeneath[id] { + sample.Location = sample.Location[i:] + break + } + } + } +} + +// RemoveUninteresting prunes and elides profiles using built-in +// tables of uninteresting function names. +func (p *Profile) RemoveUninteresting() error { + var keep, drop *regexp.Regexp + var err error + + if p.DropFrames != "" { + if drop, err = regexp.Compile("^(" + p.DropFrames + ")$"); err != nil { + return fmt.Errorf("failed to compile regexp %s: %v", p.DropFrames, err) + } + if p.KeepFrames != "" { + if keep, err = regexp.Compile("^(" + p.KeepFrames + ")$"); err != nil { + return fmt.Errorf("failed to compile regexp %s: %v", p.KeepFrames, err) + } + } + p.Prune(drop, keep) + } + return nil +} + +// PruneFrom removes all nodes beneath the lowest node matching dropRx, not including itself. +// +// Please see the example below to understand this method as well as +// the difference from Prune method. +// +// A sample contains Location of [A,B,C,B,D] where D is the top frame and there's no inline. +// +// PruneFrom(A) returns [A,B,C,B,D] because there's no node beneath A. +// Prune(A, nil) returns [B,C,B,D] by removing A itself. +// +// PruneFrom(B) returns [B,C,B,D] by removing all nodes beneath the first B when scanning from the bottom. +// Prune(B, nil) returns [D] because a matching node is found by scanning from the root. +func (p *Profile) PruneFrom(dropRx *regexp.Regexp) { + pruneBeneath := make(map[uint64]bool) + + for _, loc := range p.Location { + for i := 0; i < len(loc.Line); i++ { + if fn := loc.Line[i].Function; fn != nil && fn.Name != "" { + funcName := simplifyFunc(fn.Name) + if dropRx.MatchString(funcName) { + // Found matching entry to prune. + pruneBeneath[loc.ID] = true + loc.Line = loc.Line[i:] + break + } + } + } + } + + // Prune locs from each Sample + for _, sample := range p.Sample { + // Scan from the bottom leaf to the root to find the prune location. 
+ for i, loc := range sample.Location { + if pruneBeneath[loc.ID] { + sample.Location = sample.Location[i:] + break + } + } + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/README.md b/vendor/github.com/lucas-clemente/quic-go/README.md deleted file mode 100644 index fc25e0b48ee..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/README.md +++ /dev/null @@ -1,61 +0,0 @@ -# A QUIC implementation in pure Go - - - -[![PkgGoDev](https://pkg.go.dev/badge/github.com/lucas-clemente/quic-go)](https://pkg.go.dev/github.com/lucas-clemente/quic-go) -[![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/) - -quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go, including the [Unreliable Datagram Extension, RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221). It has support for HTTP/3 [RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114). - -In addition the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem. - -## Guides - -*We currently support Go 1.16.x, Go 1.17.x, and Go 1.18.x.* - -Running tests: - - go test ./... - -### QUIC without HTTP/3 - -Take a look at [this echo example](example/echo/echo.go). - -## Usage - -### As a server - -See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go: - -```go -http.Handle("/", http.FileServer(http.Dir(wwwDir))) -http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil) -``` - -### As a client - -See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`. 
- -```go -http.Client{ - Transport: &http3.RoundTripper{}, -} -``` - -## Projects using quic-go - -| Project | Description | Stars | -|------------------------------------------------------|--------------------------------------------------------------------------------------------------------|-------| -| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | -| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | -| [go-ipfs](https://github.com/ipfs/go-ipfs) | IPFS implementation in go | ![GitHub Repo stars](https://img.shields.io/github/stars/ipfs/go-ipfs?style=flat-square) | -| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | -| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | -| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | -| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | -| [OONI Probe](https://github.com/ooni/probe-cli) | The Open Observatory of Network Interference (OONI) aims to empower decentralized efforts in documenting Internet censorship around the world. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | - - -## Contributing - -We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/lucas-clemente/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment. diff --git a/vendor/github.com/lucas-clemente/quic-go/client.go b/vendor/github.com/lucas-clemente/quic-go/client.go deleted file mode 100644 index 76b5906130a..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/client.go +++ /dev/null @@ -1,338 +0,0 @@ -package quic - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "net" - "strings" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/logging" -) - -type client struct { - sconn sendConn - // If the client is created with DialAddr, we create a packet conn. - // If it is started with Dial, we take a packet conn as a parameter. 
- createdPacketConn bool - - use0RTT bool - - packetHandlers packetHandlerManager - - tlsConf *tls.Config - config *Config - - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID - - initialPacketNumber protocol.PacketNumber - hasNegotiatedVersion bool - version protocol.VersionNumber - - handshakeChan chan struct{} - - conn quicConn - - tracer logging.ConnectionTracer - tracingID uint64 - logger utils.Logger -} - -var ( - // make it possible to mock connection ID for initial generation in the tests - generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial -) - -// DialAddr establishes a new QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC connection is closed. -// The hostname for SNI is taken from the given address. -// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. -func DialAddr( - addr string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return DialAddrContext(context.Background(), addr, tlsConf, config) -} - -// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC connection is closed. -// The hostname for SNI is taken from the given address. -// The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. -func DialAddrEarly( - addr string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - return DialAddrEarlyContext(context.Background(), addr, tlsConf, config) -} - -// DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context. -// See DialAddrEarly for details -func DialAddrEarlyContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - conn, err := dialAddrContext(ctx, addr, tlsConf, config, true) - if err != nil { - return nil, err - } - utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection") - return conn, nil -} - -// DialAddrContext establishes a new QUIC connection to a server using the provided context. -// See DialAddr for details. -func DialAddrContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return dialAddrContext(ctx, addr, tlsConf, config, false) -} - -func dialAddrContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, - use0RTT bool, -) (quicConn, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true) -} - -// Dial establishes a new QUIC connection to a server using a net.PacketConn. If -// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn -// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP -// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write -// packets. The same PacketConn can be used for multiple calls to Dial and -// Listen, QUIC connection IDs are used for demultiplexing the different -// connections. The host parameter is used for SNI. The tls.Config must define -// an application protocol (using NextProtos). 
-func Dial( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false) -} - -// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. -// The same PacketConn can be used for multiple calls to Dial and Listen, -// QUIC connection IDs are used for demultiplexing the different connections. -// The host parameter is used for SNI. -// The tls.Config must define an application protocol (using NextProtos). -func DialEarly( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config) -} - -// DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. -// See DialEarly for details. -func DialEarlyContext( - ctx context.Context, - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false) -} - -// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. -// See Dial for details. -func DialContext( - ctx context.Context, - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false) -} - -func dialContext( - ctx context.Context, - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, - use0RTT bool, - createdPacketConn bool, -) (quicConn, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) - if err != nil { - return nil, err - } - c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn) - if err != nil { - return nil, err - } - c.packetHandlers = packetHandlers - - c.tracingID = nextConnTracingID() - if c.config.Tracer != nil { - c.tracer = c.config.Tracer.TracerForConnection( - context.WithValue(ctx, ConnectionTracingKey, c.tracingID), - protocol.PerspectiveClient, - c.destConnID, - ) - } - if c.tracer != nil { - c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID) - } - if err := c.dial(ctx); err != nil { - return nil, err - } - return c.conn, nil -} - -func newClient( - pconn net.PacketConn, - remoteAddr net.Addr, - config *Config, - tlsConf *tls.Config, - host string, - use0RTT bool, - createdPacketConn bool, -) (*client, error) { - if tlsConf == nil { - tlsConf = &tls.Config{} - } else { - tlsConf = tlsConf.Clone() - } - if tlsConf.ServerName == "" { - sni := host - if strings.IndexByte(sni, ':') != -1 { - var err error - sni, _, err = net.SplitHostPort(sni) - if err != nil { - return nil, err - } - } - - tlsConf.ServerName = sni - } - - // check that all versions are actually supported - if config != nil { - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s 
is not a valid QUIC version", v) - } - } - } - - srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID() - if err != nil { - return nil, err - } - destConnID, err := generateConnectionIDForInitial() - if err != nil { - return nil, err - } - c := &client{ - srcConnID: srcConnID, - destConnID: destConnID, - sconn: newSendPconn(pconn, remoteAddr), - createdPacketConn: createdPacketConn, - use0RTT: use0RTT, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - logger: utils.DefaultLogger.WithPrefix("client"), - } - return c, nil -} - -func (c *client) dial(ctx context.Context) error { - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - - c.conn = newClientConnection( - c.sconn, - c.packetHandlers, - c.destConnID, - c.srcConnID, - c.config, - c.tlsConf, - c.initialPacketNumber, - c.use0RTT, - c.hasNegotiatedVersion, - c.tracer, - c.tracingID, - c.logger, - c.version, - ) - c.packetHandlers.Add(c.srcConnID, c.conn) - - errorChan := make(chan error, 1) - go func() { - err := c.conn.run() // returns as soon as the connection is closed - - if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { - c.packetHandlers.Destroy() - } - errorChan <- err - }() - - // only set when we're using 0-RTT - // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. - var earlyConnChan <-chan struct{} - if c.use0RTT { - earlyConnChan = c.conn.earlyConnReady() - } - - select { - case <-ctx.Done(): - c.conn.shutdown() - return ctx.Err() - case err := <-errorChan: - var recreateErr *errCloseForRecreating - if errors.As(err, &recreateErr) { - c.initialPacketNumber = recreateErr.nextPacketNumber - c.version = recreateErr.nextVersion - c.hasNegotiatedVersion = true - return c.dial(ctx) - } - return err - case <-earlyConnChan: - // ready to send 0-RTT data - return nil - case <-c.conn.HandshakeComplete().Done(): - // handshake successfully completed - return nil - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/closed_conn.go b/vendor/github.com/lucas-clemente/quic-go/closed_conn.go deleted file mode 100644 index 35c2d7390a5..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/closed_conn.go +++ /dev/null @@ -1,112 +0,0 @@ -package quic - -import ( - "sync" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" -) - -// A closedLocalConn is a connection that we closed locally. -// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, -// with an exponential backoff. -type closedLocalConn struct { - conn sendConn - connClosePacket []byte - - closeOnce sync.Once - closeChan chan struct{} // is closed when the connection is closed or destroyed - - receivedPackets chan *receivedPacket - counter uint64 // number of packets received - - perspective protocol.Perspective - - logger utils.Logger -} - -var _ packetHandler = &closedLocalConn{} - -// newClosedLocalConn creates a new closedLocalConn and runs it. 
-func newClosedLocalConn( - conn sendConn, - connClosePacket []byte, - perspective protocol.Perspective, - logger utils.Logger, -) packetHandler { - s := &closedLocalConn{ - conn: conn, - connClosePacket: connClosePacket, - perspective: perspective, - logger: logger, - closeChan: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, 64), - } - go s.run() - return s -} - -func (s *closedLocalConn) run() { - for { - select { - case p := <-s.receivedPackets: - s.handlePacketImpl(p) - case <-s.closeChan: - return - } - } -} - -func (s *closedLocalConn) handlePacket(p *receivedPacket) { - select { - case s.receivedPackets <- p: - default: - } -} - -func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) { - s.counter++ - // exponential backoff - // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving - for n := s.counter; n > 1; n = n / 2 { - if n%2 != 0 { - return - } - } - s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.counter) - if err := s.conn.Write(s.connClosePacket); err != nil { - s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err) - } -} - -func (s *closedLocalConn) shutdown() { - s.destroy(nil) -} - -func (s *closedLocalConn) destroy(error) { - s.closeOnce.Do(func() { - close(s.closeChan) - }) -} - -func (s *closedLocalConn) getPerspective() protocol.Perspective { - return s.perspective -} - -// A closedRemoteConn is a connection that was closed remotely. -// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. -// We can just ignore those packets. -type closedRemoteConn struct { - perspective protocol.Perspective -} - -var _ packetHandler = &closedRemoteConn{} - -func newClosedRemoteConn(pers protocol.Perspective) packetHandler { - return &closedRemoteConn{perspective: pers} -} - -func (s *closedRemoteConn) handlePacket(*receivedPacket) {} -func (s *closedRemoteConn) shutdown() {} -func (s *closedRemoteConn) destroy(error) {} -func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } diff --git a/vendor/github.com/lucas-clemente/quic-go/datagram_queue.go b/vendor/github.com/lucas-clemente/quic-go/datagram_queue.go deleted file mode 100644 index e23f79b0ee2..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/datagram_queue.go +++ /dev/null @@ -1,94 +0,0 @@ -package quic - -import ( - "fmt" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -type datagramQueue struct { - sendQueue chan *wire.DatagramFrame - rcvQueue chan []byte - - closeErr error - closed chan struct{} - - hasData func() - - dequeued chan error - - logger utils.Logger -} - -func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { - return &datagramQueue{ - hasData: hasData, - sendQueue: make(chan *wire.DatagramFrame, 1), - rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen), - dequeued: make(chan error), - closed: make(chan struct{}), - logger: logger, - } -} - -// AddAndWait queues a new DATAGRAM frame for sending. -// It blocks until the frame has been dequeued. -func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { - select { - case h.sendQueue <- f: - h.hasData() - case <-h.closed: - return h.closeErr - } - - select { - case err := <-h.dequeued: - return err - case <-h.closed: - return h.closeErr - } -} - -// Get dequeues a DATAGRAM frame for sending. 
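The AddAndWait/Get pair above coordinates through the `dequeued` channel: the sender blocks until its frame is either accepted for the next packet or rejected for exceeding the datagram size limit. A standalone sketch of that handshake, with a hypothetical 1200-byte limit standing in for `maxDatagramSize` (this is an illustration of the pattern, not quic-go code):

```go
package main

import "fmt"

func main() {
	sendQueue := make(chan []byte, 1)
	dequeued := make(chan error)

	// Consumer, playing the role of Get: accept the frame or report that it
	// does not fit in the next packet.
	go func() {
		frame := <-sendQueue
		if len(frame) > 1200 { // hypothetical limit
			dequeued <- fmt.Errorf("datagram size %d exceeds current limit of %d", len(frame), 1200)
			return
		}
		dequeued <- nil
	}()

	// Producer, playing the role of AddAndWait: queue the frame, then block
	// until the consumer reports the outcome.
	sendQueue <- make([]byte, 2000)
	fmt.Println(<-dequeued) // datagram size 2000 exceeds current limit of 1200
}
```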
-func (h *datagramQueue) Get(maxDatagramSize protocol.ByteCount, version protocol.VersionNumber) *wire.DatagramFrame { - select { - case f := <-h.sendQueue: - datagramSize := f.Length(version) - if datagramSize > maxDatagramSize { - h.dequeued <- fmt.Errorf("datagram size %d exceed current limit of %d", datagramSize, maxDatagramSize) - return nil - } - h.dequeued <- nil - return f - default: - return nil - } -} - -// HandleDatagramFrame handles a received DATAGRAM frame. -func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { - data := make([]byte, len(f.Data)) - copy(data, f.Data) - select { - case h.rcvQueue <- data: - default: - h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) - } -} - -// Receive gets a received DATAGRAM frame. -func (h *datagramQueue) Receive() ([]byte, error) { - select { - case data := <-h.rcvQueue: - return data, nil - case <-h.closed: - return nil, h.closeErr - } -} - -func (h *datagramQueue) CloseWithError(e error) { - h.closeErr = e - close(h.closed) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ackhandler.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ackhandler.go deleted file mode 100644 index 26291321ca2..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ackhandler.go +++ /dev/null @@ -1,21 +0,0 @@ -package ackhandler - -import ( - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/logging" -) - -// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler -func NewAckHandler( - initialPacketNumber protocol.PacketNumber, - initialMaxDatagramSize protocol.ByteCount, - rttStats *utils.RTTStats, - pers protocol.Perspective, - tracer logging.ConnectionTracer, - logger utils.Logger, - version protocol.VersionNumber, -) (SentPacketHandler, ReceivedPacketHandler) { - sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, pers, tracer, logger) - return sph, newReceivedPacketHandler(sph, rttStats, logger, version) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/frame.go deleted file mode 100644 index aed6038d970..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/frame.go +++ /dev/null @@ -1,9 +0,0 @@ -package ackhandler - -import "github.com/lucas-clemente/quic-go/internal/wire" - -type Frame struct { - wire.Frame // nil if the frame has already been acknowledged in another packet - OnLost func(wire.Frame) - OnAcked func(wire.Frame) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go deleted file mode 100644 index 32235f81ab1..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/gen.go +++ /dev/null @@ -1,3 +0,0 @@ -package ackhandler - -//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/mockgen.go deleted file mode 100644 index e957d253a1c..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/mockgen.go +++ /dev/null @@ -1,3 +0,0 @@ -package ackhandler - -//go:generate sh -c "../../mockgen_private.sh ackhandler 
mock_sent_packet_tracker_test.go github.com/lucas-clemente/quic-go/internal/ackhandler sentPacketTracker" diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go deleted file mode 100644 index bb74f4ef933..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package ackhandler - -// Linked list implementation from the Go standard library. - -// PacketElement is an element of a linked list. -type PacketElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *PacketElement - - // The list to which this element belongs. - list *PacketList - - // The value stored with this element. - Value Packet -} - -// Next returns the next list element or nil. -func (e *PacketElement) Next() *PacketElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *PacketElement) Prev() *PacketElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// PacketList is a linked list of Packets. -type PacketList struct { - root PacketElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *PacketList) Init() *PacketList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewPacketList returns an initialized list. -func NewPacketList() *PacketList { return new(PacketList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *PacketList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *PacketList) Front() *PacketElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *PacketList) Back() *PacketElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *PacketList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *PacketList) insert(e, at *PacketElement) *PacketElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement { - return l.insert(&PacketElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. 
-func (l *PacketList) remove(e *PacketElement) *PacketElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *PacketList) Remove(e *PacketElement) Packet { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *PacketList) PushFront(v Packet) *PacketElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *PacketList) PushBack(v Packet) *PacketElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketList) MoveToFront(e *PacketElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketList) MoveToBack(e *PacketElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketList) MoveBefore(e, mark *PacketElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketList) MoveAfter(e, mark *PacketElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. 
-func (l *PacketList) PushBackList(other *PacketList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *PacketList) PushFrontList(other *PacketList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go deleted file mode 100644 index c7a8d13ee03..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go +++ /dev/null @@ -1,3 +0,0 @@ -package handshake - -//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner" diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/retry.go b/vendor/github.com/lucas-clemente/quic-go/internal/handshake/retry.go deleted file mode 100644 index b7cb20c151e..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/retry.go +++ /dev/null @@ -1,62 +0,0 @@ -package handshake - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "fmt" - "sync" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -var ( - oldRetryAEAD cipher.AEAD // used for QUIC draft versions up to 34 - retryAEAD cipher.AEAD // used for QUIC draft-34 -) - -func init() { - oldRetryAEAD = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1}) - retryAEAD = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) -} - -func initAEAD(key [16]byte) cipher.AEAD { - aes, err := aes.NewCipher(key[:]) - if err != nil { - panic(err) - } - aead, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } - return aead -} - -var ( - retryBuf bytes.Buffer - retryMutex sync.Mutex - oldRetryNonce = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c} - retryNonce = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} -) - -// GetRetryIntegrityTag calculates the integrity tag on a Retry packet -func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte { - retryMutex.Lock() - retryBuf.WriteByte(uint8(origDestConnID.Len())) - retryBuf.Write(origDestConnID.Bytes()) - retryBuf.Write(retry) - - var tag [16]byte - var sealed []byte - if version != protocol.Version1 { - sealed = oldRetryAEAD.Seal(tag[:0], oldRetryNonce[:], nil, retryBuf.Bytes()) - } else { - sealed = retryAEAD.Seal(tag[:0], retryNonce[:], nil, retryBuf.Bytes()) - } - if len(sealed) != 16 { - panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed))) - } - retryBuf.Reset() - retryMutex.Unlock() - return &tag -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go deleted file mode 100644 index 7ae7d9dcfa1..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go +++ /dev/null @@ -1,81 +0,0 @@ -package protocol - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" -) - -// A ConnectionID in QUIC -type ConnectionID []byte - 
-const maxConnectionIDLen = 20 - -// GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID(len int) (ConnectionID, error) { - b := make([]byte, len) - if _, err := rand.Read(b); err != nil { - return nil, err - } - return ConnectionID(b), nil -} - -// GenerateConnectionIDForInitial generates a connection ID for the Initial packet. -// It uses a length randomly chosen between 8 and 20 bytes. -func GenerateConnectionIDForInitial() (ConnectionID, error) { - r := make([]byte, 1) - if _, err := rand.Read(r); err != nil { - return nil, err - } - len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) - return GenerateConnectionID(len) -} - -// ReadConnectionID reads a connection ID of length len from the given io.Reader. -// It returns io.EOF if there are not enough bytes to read. -func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { - if len == 0 { - return nil, nil - } - c := make(ConnectionID, len) - _, err := io.ReadFull(r, c) - if err == io.ErrUnexpectedEOF { - return nil, io.EOF - } - return c, err -} - -// Equal says if two connection IDs are equal -func (c ConnectionID) Equal(other ConnectionID) bool { - return bytes.Equal(c, other) -} - -// Len returns the length of the connection ID in bytes -func (c ConnectionID) Len() int { - return len(c) -} - -// Bytes returns the byte representation -func (c ConnectionID) Bytes() []byte { - return []byte(c) -} - -func (c ConnectionID) String() string { - if c.Len() == 0 { - return "(empty)" - } - return fmt.Sprintf("%x", c.Bytes()) -} - -type DefaultConnectionIDGenerator struct { - ConnLen int -} - -func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) { - return GenerateConnectionID(d.ConnLen) -} - -func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int { - return d.ConnLen -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go116.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go116.go deleted file mode 100644 index e3024624c24..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go116.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build go1.16 && !go1.17 -// +build go1.16,!go1.17 - -package qtls - -import ( - "crypto" - "crypto/cipher" - "crypto/tls" - "net" - "unsafe" - - "github.com/marten-seemann/qtls-go1-16" -) - -type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains inforamtion about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. 
- EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication -) - -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-16.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go118.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go118.go deleted file mode 100644 index 5de030c78a1..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go118.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build go1.18 && !go1.19 -// +build go1.18,!go1.19 - -package qtls - -import ( - "crypto" - "crypto/cipher" - "crypto/tls" - "net" - "unsafe" - - "github.com/marten-seemann/qtls-go1-18" -) - -type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains inforamtion about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. 
- EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication -) - -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-18.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go120.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go120.go deleted file mode 100644 index f8d59d8bb09..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go120.go +++ /dev/null @@ -1,6 +0,0 @@ -//go:build go1.20 -// +build go1.20 - -package qtls - -var _ int = "The version of quic-go you're using can't be built on Go 1.20 yet. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go_oldversion.go b/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go_oldversion.go deleted file mode 100644 index 384d719c6e8..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go_oldversion.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build (go1.9 || go1.10 || go1.11 || go1.12 || go1.13 || go1.14 || go1.15) && !go1.16 -// +build go1.9 go1.10 go1.11 go1.12 go1.13 go1.14 go1.15 -// +build !go1.16 - -package qtls - -var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." 
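The two guard files removed just above (go120.go and go_oldversion.go) rely on a deliberate type error to turn an unsupported Go toolchain into a readable compile failure: under a matching build constraint, assigning a string to an int can never type-check, so the string literal doubles as the error message. A minimal sketch of the same pattern, with a made-up build tag and message rather than the ones from the patch; it would sit alongside a package's other files and stays excluded from normal builds:

//go:build go1.99

// Package guard sketches the compile-time version guard used by the deleted
// qtls shims. When the build constraint matches, this declaration fails to
// compile ("cannot use ... (untyped string constant) as int value"), and the
// string shows up in the build output as the explanation.
package guard

var _ int = "this package cannot be built with this Go version; see the project's wiki for supported versions"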
diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/atomic_bool.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/atomic_bool.go deleted file mode 100644 index cf4642504e0..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/atomic_bool.go +++ /dev/null @@ -1,22 +0,0 @@ -package utils - -import "sync/atomic" - -// An AtomicBool is an atomic bool -type AtomicBool struct { - v int32 -} - -// Set sets the value -func (a *AtomicBool) Set(value bool) { - var n int32 - if value { - n = 1 - } - atomic.StoreInt32(&a.v, n) -} - -// Get gets the value -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.v) != 0 -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go deleted file mode 100644 index 8a63e95891e..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/gen.go +++ /dev/null @@ -1,5 +0,0 @@ -package utils - -//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out byteinterval_linkedlist.go gen Item=ByteInterval -//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out packetinterval_linkedlist.go gen Item=PacketInterval -//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out newconnectionid_linkedlist.go gen Item=NewConnectionID diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go deleted file mode 100644 index ee1f85f624f..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go +++ /dev/null @@ -1,170 +0,0 @@ -package utils - -import ( - "math" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// InfDuration is a duration of infinite length -const InfDuration = time.Duration(math.MaxInt64) - -// Max returns the maximum of two Ints -func Max(a, b int) int { - if a < b { - return b - } - return a -} - -// MaxUint32 returns the maximum of two uint32 -func MaxUint32(a, b uint32) uint32 { - if a < b { - return b - } - return a -} - -// MaxUint64 returns the maximum of two uint64 -func MaxUint64(a, b uint64) uint64 { - if a < b { - return b - } - return a -} - -// MinUint64 returns the maximum of two uint64 -func MinUint64(a, b uint64) uint64 { - if a < b { - return a - } - return b -} - -// Min returns the minimum of two Ints -func Min(a, b int) int { - if a < b { - return a - } - return b -} - -// MinUint32 returns the maximum of two uint32 -func MinUint32(a, b uint32) uint32 { - if a < b { - return a - } - return b -} - -// MinInt64 returns the minimum of two int64 -func MinInt64(a, b int64) int64 { - if a < b { - return a - } - return b -} - -// MaxInt64 returns the minimum of two int64 -func MaxInt64(a, b int64) int64 { - if a > b { - return a - } - return b -} - -// MinByteCount returns the minimum of two ByteCounts -func MinByteCount(a, b protocol.ByteCount) protocol.ByteCount { - if a < b { - return a - } - return b -} - -// MaxByteCount returns the maximum of two ByteCounts -func MaxByteCount(a, b protocol.ByteCount) protocol.ByteCount { - if a < b { - return b - } - return a -} - -// MaxDuration returns the max duration -func MaxDuration(a, b time.Duration) time.Duration { - if a > b { - return a - } - return b -} - -// MinDuration returns the minimum duration -func MinDuration(a, b time.Duration) time.Duration { - if a > b { - return b - } - return a -} - -// MinNonZeroDuration return the minimum duration that's not zero. 
-func MinNonZeroDuration(a, b time.Duration) time.Duration { - if a == 0 { - return b - } - if b == 0 { - return a - } - return MinDuration(a, b) -} - -// AbsDuration returns the absolute value of a time duration -func AbsDuration(d time.Duration) time.Duration { - if d >= 0 { - return d - } - return -d -} - -// MinTime returns the earlier time -func MinTime(a, b time.Time) time.Time { - if a.After(b) { - return b - } - return a -} - -// MinNonZeroTime returns the earlist time that is not time.Time{} -// If both a and b are time.Time{}, it returns time.Time{} -func MinNonZeroTime(a, b time.Time) time.Time { - if a.IsZero() { - return b - } - if b.IsZero() { - return a - } - return MinTime(a, b) -} - -// MaxTime returns the later time -func MaxTime(a, b time.Time) time.Time { - if a.After(b) { - return a - } - return b -} - -// MaxPacketNumber returns the max packet number -func MaxPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { - if a > b { - return a - } - return b -} - -// MinPacketNumber returns the min packet number -func MinPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber { - if a < b { - return a - } - return b -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/new_connection_id.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/new_connection_id.go deleted file mode 100644 index 694ee7aaf2c..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/new_connection_id.go +++ /dev/null @@ -1,12 +0,0 @@ -package utils - -import ( - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// NewConnectionID is a new connection ID -type NewConnectionID struct { - SequenceNumber uint64 - ConnectionID protocol.ConnectionID - StatelessResetToken protocol.StatelessResetToken -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/newconnectionid_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/newconnectionid_linkedlist.go deleted file mode 100644 index d59562e536e..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/newconnectionid_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package utils - -// Linked list implementation from the Go standard library. - -// NewConnectionIDElement is an element of a linked list. -type NewConnectionIDElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *NewConnectionIDElement - - // The list to which this element belongs. - list *NewConnectionIDList - - // The value stored with this element. - Value NewConnectionID -} - -// Next returns the next list element or nil. -func (e *NewConnectionIDElement) Next() *NewConnectionIDElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *NewConnectionIDElement) Prev() *NewConnectionIDElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// NewConnectionIDList is a linked list of NewConnectionIDs. 
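Most of the deleted minmax.go helpers are plain two-argument comparisons; the only ones with extra semantics are the NonZero variants, which treat a zero duration or zero time.Time as "unset" and fall back to the other argument. A self-contained illustration of that behaviour, using a local helper name rather than the vendored utils identifier:

package main

import (
	"fmt"
	"time"
)

// minNonZeroDuration mirrors the semantics of the deleted utils helper:
// a zero duration means "not set", so the other value wins; when both are
// set, the ordinary minimum is returned.
func minNonZeroDuration(a, b time.Duration) time.Duration {
	if a == 0 {
		return b
	}
	if b == 0 {
		return a
	}
	if a < b {
		return a
	}
	return b
}

func main() {
	fmt.Println(minNonZeroDuration(0, 5*time.Second))             // 5s: the zero value is ignored
	fmt.Println(minNonZeroDuration(2*time.Second, 5*time.Second)) // 2s: plain minimum
}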
-type NewConnectionIDList struct { - root NewConnectionIDElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *NewConnectionIDList) Init() *NewConnectionIDList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewNewConnectionIDList returns an initialized list. -func NewNewConnectionIDList() *NewConnectionIDList { return new(NewConnectionIDList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *NewConnectionIDList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *NewConnectionIDList) Front() *NewConnectionIDElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *NewConnectionIDList) Back() *NewConnectionIDElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *NewConnectionIDList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *NewConnectionIDList) insert(e, at *NewConnectionIDElement) *NewConnectionIDElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *NewConnectionIDList) insertValue(v NewConnectionID, at *NewConnectionIDElement) *NewConnectionIDElement { - return l.insert(&NewConnectionIDElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *NewConnectionIDList) remove(e *NewConnectionIDElement) *NewConnectionIDElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. -func (l *NewConnectionIDList) Remove(e *NewConnectionIDElement) NewConnectionID { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *NewConnectionIDList) PushFront(v NewConnectionID) *NewConnectionIDElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *NewConnectionIDList) PushBack(v NewConnectionID) *NewConnectionIDElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *NewConnectionIDList) InsertBefore(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. 
-// The mark must not be nil. -func (l *NewConnectionIDList) InsertAfter(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *NewConnectionIDList) MoveToFront(e *NewConnectionIDElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *NewConnectionIDList) MoveToBack(e *NewConnectionIDElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *NewConnectionIDList) MoveBefore(e, mark *NewConnectionIDElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *NewConnectionIDList) MoveAfter(e, mark *NewConnectionIDElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *NewConnectionIDList) PushBackList(other *NewConnectionIDList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. -func (l *NewConnectionIDList) PushFrontList(other *NewConnectionIDList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go deleted file mode 100644 index 62cc8b9cb31..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packet_interval.go +++ /dev/null @@ -1,9 +0,0 @@ -package utils - -import "github.com/lucas-clemente/quic-go/internal/protocol" - -// PacketInterval is an interval from one PacketNumber to the other -type PacketInterval struct { - Start protocol.PacketNumber - End protocol.PacketNumber -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go deleted file mode 100644 index b461e85a96e..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/packetinterval_linkedlist.go +++ /dev/null @@ -1,217 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. 
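The *_linkedlist.go files deleted in this hunk are all the same code: the standard library's container/list, run through genny so that each list stores a concrete element type instead of interface{}. For comparison, the equivalent queue written directly against container/list; the newConnectionID struct below is a trimmed-down stand-in for the vendored type, which also carries a stateless reset token:

package main

import (
	"container/list"
	"fmt"
)

// newConnectionID is a simplified stand-in for the vendored NewConnectionID
// struct (the real one also has a StatelessResetToken field).
type newConnectionID struct {
	SequenceNumber uint64
	ConnectionID   []byte
}

func main() {
	l := list.New()
	l.PushBack(newConnectionID{SequenceNumber: 1, ConnectionID: []byte{0xde, 0xad}})
	l.PushBack(newConnectionID{SequenceNumber: 2, ConnectionID: []byte{0xbe, 0xef}})

	for e := l.Front(); e != nil; e = e.Next() {
		// container/list stores `any`, so every read needs a type assertion;
		// the genny-generated typed lists exist precisely to avoid this.
		v := e.Value.(newConnectionID)
		fmt.Printf("seq=%d connID=%x\n", v.SequenceNumber, v.ConnectionID)
	}
}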
-// see https://github.com/cheekybits/genny - -package utils - -// Linked list implementation from the Go standard library. - -// PacketIntervalElement is an element of a linked list. -type PacketIntervalElement struct { - // Next and previous pointers in the doubly-linked list of elements. - // To simplify the implementation, internally a list l is implemented - // as a ring, such that &l.root is both the next element of the last - // list element (l.Back()) and the previous element of the first list - // element (l.Front()). - next, prev *PacketIntervalElement - - // The list to which this element belongs. - list *PacketIntervalList - - // The value stored with this element. - Value PacketInterval -} - -// Next returns the next list element or nil. -func (e *PacketIntervalElement) Next() *PacketIntervalElement { - if p := e.next; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// Prev returns the previous list element or nil. -func (e *PacketIntervalElement) Prev() *PacketIntervalElement { - if p := e.prev; e.list != nil && p != &e.list.root { - return p - } - return nil -} - -// PacketIntervalList is a linked list of PacketIntervals. -type PacketIntervalList struct { - root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element -} - -// Init initializes or clears list l. -func (l *PacketIntervalList) Init() *PacketIntervalList { - l.root.next = &l.root - l.root.prev = &l.root - l.len = 0 - return l -} - -// NewPacketIntervalList returns an initialized list. -func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList).Init() } - -// Len returns the number of elements of list l. -// The complexity is O(1). -func (l *PacketIntervalList) Len() int { return l.len } - -// Front returns the first element of list l or nil if the list is empty. -func (l *PacketIntervalList) Front() *PacketIntervalElement { - if l.len == 0 { - return nil - } - return l.root.next -} - -// Back returns the last element of list l or nil if the list is empty. -func (l *PacketIntervalList) Back() *PacketIntervalElement { - if l.len == 0 { - return nil - } - return l.root.prev -} - -// lazyInit lazily initializes a zero List value. -func (l *PacketIntervalList) lazyInit() { - if l.root.next == nil { - l.Init() - } -} - -// insert inserts e after at, increments l.len, and returns e. -func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketIntervalElement { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e - e.list = l - l.len++ - return e -} - -// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement { - return l.insert(&PacketIntervalElement{Value: v}, at) -} - -// remove removes e from its list, decrements l.len, and returns e. -func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalElement { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - e.list = nil - l.len-- - return e -} - -// Remove removes e from l if e is an element of list l. -// It returns the element value e.Value. -// The element must not be nil. 
-func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval { - if e.list == l { - // if e.list == l, l must have been initialized when e was inserted - // in l or l == nil (e is a zero Element) and l.remove will crash - l.remove(e) - } - return e.Value -} - -// PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *PacketIntervalList) PushFront(v PacketInterval) *PacketIntervalElement { - l.lazyInit() - return l.insertValue(v, &l.root) -} - -// PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement { - l.lazyInit() - return l.insertValue(v, l.root.prev) -} - -// InsertBefore inserts a new element e with value v immediately before mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark.prev) -} - -// InsertAfter inserts a new element e with value v immediately after mark and returns e. -// If mark is not an element of l, the list is not modified. -// The mark must not be nil. -func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement { - if mark.list != l { - return nil - } - // see comment in List.Remove about initialization of l - return l.insertValue(v, mark) -} - -// MoveToFront moves element e to the front of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) { - if e.list != l || l.root.next == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) -} - -// MoveToBack moves element e to the back of list l. -// If e is not an element of l, the list is not modified. -// The element must not be nil. -func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) { - if e.list != l || l.root.prev == e { - return - } - // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) -} - -// MoveBefore moves element e to its new position before mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark.prev) -} - -// MoveAfter moves element e to its new position after mark. -// If e or mark is not an element of l, or e == mark, the list is not modified. -// The element and mark must not be nil. -func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) { - if e.list != l || e == mark || mark.list != l { - return - } - l.insert(l.remove(e), mark) -} - -// PushBackList inserts a copy of an other list at the back of list l. -// The lists l and other may be the same. They must not be nil. -func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { - l.insertValue(e.Value, l.root.prev) - } -} - -// PushFrontList inserts a copy of an other list at the front of list l. -// The lists l and other may be the same. They must not be nil. 
-func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) { - l.lazyInit() - for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { - l.insertValue(e.Value, &l.root) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go b/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go deleted file mode 100644 index ec16d251b89..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/streamframe_interval.go +++ /dev/null @@ -1,9 +0,0 @@ -package utils - -import "github.com/lucas-clemente/quic-go/internal/protocol" - -// ByteInterval is an interval from one ByteCount to the other -type ByteInterval struct { - Start protocol.ByteCount - End protocol.ByteCount -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go deleted file mode 100644 index 9d9edab25c4..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go +++ /dev/null @@ -1,249 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - "io" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/quicvarint" -) - -// ErrInvalidReservedBits is returned when the reserved bits are incorrect. -// When this error is returned, parsing continues, and an ExtendedHeader is returned. -// This is necessary because we need to decrypt the packet in that case, -// in order to avoid a timing side-channel. -var ErrInvalidReservedBits = errors.New("invalid reserved bits") - -// ExtendedHeader is the header of a QUIC packet. -type ExtendedHeader struct { - Header - - typeByte byte - - KeyPhase protocol.KeyPhaseBit - - PacketNumberLen protocol.PacketNumberLen - PacketNumber protocol.PacketNumber - - parsedLen protocol.ByteCount -} - -func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { - startLen := b.Len() - // read the (now unencrypted) first byte - var err error - h.typeByte, err = b.ReadByte() - if err != nil { - return false, err - } - if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { - return false, err - } - var reservedBitsValid bool - if h.IsLongHeader { - reservedBitsValid, err = h.parseLongHeader(b, v) - } else { - reservedBitsValid, err = h.parseShortHeader(b, v) - } - if err != nil { - return false, err - } - h.parsedLen = protocol.ByteCount(startLen - b.Len()) - return reservedBitsValid, err -} - -func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { - if err := h.readPacketNumber(b); err != nil { - return false, err - } - if h.typeByte&0xc != 0 { - return false, nil - } - return true, nil -} - -func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { - h.KeyPhase = protocol.KeyPhaseZero - if h.typeByte&0x4 > 0 { - h.KeyPhase = protocol.KeyPhaseOne - } - - if err := h.readPacketNumber(b); err != nil { - return false, err - } - if h.typeByte&0x18 != 0 { - return false, nil - } - return true, nil -} - -func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { - h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - n, err := b.ReadByte() - if err != nil { - return err - } - 
h.PacketNumber = protocol.PacketNumber(n) - case protocol.PacketNumberLen2: - n, err := utils.BigEndian.ReadUint16(b) - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - case protocol.PacketNumberLen3: - n, err := utils.BigEndian.ReadUint24(b) - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - case protocol.PacketNumberLen4: - n, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return err - } - h.PacketNumber = protocol.PacketNumber(n) - default: - return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) - } - return nil -} - -// Write writes the Header. -func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error { - if h.DestConnectionID.Len() > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) - } - if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) - } - if h.IsLongHeader { - return h.writeLongHeader(b, ver) - } - return h.writeShortHeader(b, ver) -} - -func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error { - var packetType uint8 - if version == protocol.Version2 { - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeInitial: - packetType = 0b01 - case protocol.PacketType0RTT: - packetType = 0b10 - case protocol.PacketTypeHandshake: - packetType = 0b11 - case protocol.PacketTypeRetry: - packetType = 0b00 - } - } else { - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeInitial: - packetType = 0b00 - case protocol.PacketType0RTT: - packetType = 0b01 - case protocol.PacketTypeHandshake: - packetType = 0b10 - case protocol.PacketTypeRetry: - packetType = 0b11 - } - } - firstByte := 0xc0 | packetType<<4 - if h.Type != protocol.PacketTypeRetry { - // Retry packets don't have a packet number - firstByte |= uint8(h.PacketNumberLen - 1) - } - - b.WriteByte(firstByte) - utils.BigEndian.WriteUint32(b, uint32(h.Version)) - b.WriteByte(uint8(h.DestConnectionID.Len())) - b.Write(h.DestConnectionID.Bytes()) - b.WriteByte(uint8(h.SrcConnectionID.Len())) - b.Write(h.SrcConnectionID.Bytes()) - - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeRetry: - b.Write(h.Token) - return nil - case protocol.PacketTypeInitial: - quicvarint.Write(b, uint64(len(h.Token))) - b.Write(h.Token) - } - quicvarint.WriteWithLen(b, uint64(h.Length), 2) - return h.writePacketNumber(b) -} - -func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { - typeByte := 0x40 | uint8(h.PacketNumberLen-1) - if h.KeyPhase == protocol.KeyPhaseOne { - typeByte |= byte(1 << 2) - } - - b.WriteByte(typeByte) - b.Write(h.DestConnectionID.Bytes()) - return h.writePacketNumber(b) -} - -func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(h.PacketNumber)) - case protocol.PacketNumberLen2: - utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) - case protocol.PacketNumberLen3: - utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) - case protocol.PacketNumberLen4: - utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) - default: - return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) - } - return nil -} - -// ParsedLen returns the number of bytes that were consumed when parsing the header -func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { - return 
h.parsedLen -} - -// GetLength determines the length of the Header. -func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount { - if h.IsLongHeader { - length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ - if h.Type == protocol.PacketTypeInitial { - length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) - } - return length - } - - length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len()) - length += protocol.ByteCount(h.PacketNumberLen) - return length -} - -// Log logs the Header -func (h *ExtendedHeader) Log(logger utils.Logger) { - if h.IsLongHeader { - var token string - if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { - if len(h.Token) == 0 { - token = "Token: (empty), " - } else { - token = fmt.Sprintf("Token: %#x, ", h.Token) - } - if h.Type == protocol.PacketTypeRetry { - logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) - return - } - } - logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) - } else { - logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go deleted file mode 100644 index f3a51ecb506..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go +++ /dev/null @@ -1,143 +0,0 @@ -package wire - -import ( - "bytes" - "errors" - "fmt" - "reflect" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" -) - -type frameParser struct { - ackDelayExponent uint8 - - supportsDatagrams bool - - version protocol.VersionNumber -} - -// NewFrameParser creates a new frame parser. -func NewFrameParser(supportsDatagrams bool, v protocol.VersionNumber) FrameParser { - return &frameParser{ - supportsDatagrams: supportsDatagrams, - version: v, - } -} - -// ParseNext parses the next frame. -// It skips PADDING frames. 
-func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) { - for r.Len() != 0 { - typeByte, _ := r.ReadByte() - if typeByte == 0x0 { // PADDING frame - continue - } - r.UnreadByte() - - f, err := p.parseFrame(r, typeByte, encLevel) - if err != nil { - return nil, &qerr.TransportError{ - FrameType: uint64(typeByte), - ErrorCode: qerr.FrameEncodingError, - ErrorMessage: err.Error(), - } - } - return f, nil - } - return nil, nil -} - -func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protocol.EncryptionLevel) (Frame, error) { - var frame Frame - var err error - if typeByte&0xf8 == 0x8 { - frame, err = parseStreamFrame(r, p.version) - } else { - switch typeByte { - case 0x1: - frame, err = parsePingFrame(r, p.version) - case 0x2, 0x3: - ackDelayExponent := p.ackDelayExponent - if encLevel != protocol.Encryption1RTT { - ackDelayExponent = protocol.DefaultAckDelayExponent - } - frame, err = parseAckFrame(r, ackDelayExponent, p.version) - case 0x4: - frame, err = parseResetStreamFrame(r, p.version) - case 0x5: - frame, err = parseStopSendingFrame(r, p.version) - case 0x6: - frame, err = parseCryptoFrame(r, p.version) - case 0x7: - frame, err = parseNewTokenFrame(r, p.version) - case 0x10: - frame, err = parseMaxDataFrame(r, p.version) - case 0x11: - frame, err = parseMaxStreamDataFrame(r, p.version) - case 0x12, 0x13: - frame, err = parseMaxStreamsFrame(r, p.version) - case 0x14: - frame, err = parseDataBlockedFrame(r, p.version) - case 0x15: - frame, err = parseStreamDataBlockedFrame(r, p.version) - case 0x16, 0x17: - frame, err = parseStreamsBlockedFrame(r, p.version) - case 0x18: - frame, err = parseNewConnectionIDFrame(r, p.version) - case 0x19: - frame, err = parseRetireConnectionIDFrame(r, p.version) - case 0x1a: - frame, err = parsePathChallengeFrame(r, p.version) - case 0x1b: - frame, err = parsePathResponseFrame(r, p.version) - case 0x1c, 0x1d: - frame, err = parseConnectionCloseFrame(r, p.version) - case 0x1e: - frame, err = parseHandshakeDoneFrame(r, p.version) - case 0x30, 0x31: - if p.supportsDatagrams { - frame, err = parseDatagramFrame(r, p.version) - break - } - fallthrough - default: - err = errors.New("unknown frame type") - } - } - if err != nil { - return nil, err - } - if !p.isAllowedAtEncLevel(frame, encLevel) { - return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) - } - return frame, nil -} - -func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { - switch encLevel { - case protocol.EncryptionInitial, protocol.EncryptionHandshake: - switch f.(type) { - case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame: - return true - default: - return false - } - case protocol.Encryption0RTT: - switch f.(type) { - case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: - return false - default: - return true - } - case protocol.Encryption1RTT: - return true - default: - panic("unknown encryption level") - } -} - -func (p *frameParser) SetAckDelayExponent(exp uint8) { - p.ackDelayExponent = exp -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/handshake_done_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/handshake_done_frame.go deleted file mode 100644 index 158d659f098..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/handshake_done_frame.go +++ /dev/null @@ -1,28 +0,0 @@ -package wire - 
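The frame parser deleted above walks the payload one type byte at a time: 0x00 is a PADDING frame and is simply skipped, STREAM frames occupy the 0x08–0x0f range (hence the typeByte&0xf8 == 0x8 check), and everything else dispatches through the switch. A stripped-down sketch of that outer loop, limited to body-less frame types (PING and HANDSHAKE_DONE) so no real wire decoding is needed:

package main

import (
	"bytes"
	"fmt"
)

// nextFrameType skips PADDING (0x00) bytes and returns the next frame's type
// byte, or ok=false once the reader is exhausted. It mirrors only the shape
// of the deleted ParseNext/parseFrame pair, not the actual frame decoding.
func nextFrameType(r *bytes.Reader) (typeByte byte, ok bool) {
	for r.Len() > 0 {
		b, _ := r.ReadByte()
		if b == 0x00 { // PADDING: a single zero byte, skipped without dispatch
			continue
		}
		return b, true
	}
	return 0, false
}

func main() {
	// Two PADDING bytes, a PING (0x01) and a HANDSHAKE_DONE (0x1e); both of
	// the latter consist of nothing but their type byte.
	r := bytes.NewReader([]byte{0x00, 0x00, 0x01, 0x1e})
	for {
		t, ok := nextFrameType(r)
		if !ok {
			break
		}
		switch t {
		case 0x01:
			fmt.Printf("0x%02x: PING frame\n", t)
		case 0x1e:
			fmt.Printf("0x%02x: HANDSHAKE_DONE frame\n", t)
		default:
			fmt.Printf("0x%02x: some other frame type\n", t)
		}
	}
}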
-import ( - "bytes" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// A HandshakeDoneFrame is a HANDSHAKE_DONE frame -type HandshakeDoneFrame struct{} - -// ParseHandshakeDoneFrame parses a HandshakeDone frame -func parseHandshakeDoneFrame(r *bytes.Reader, _ protocol.VersionNumber) (*HandshakeDoneFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - return &HandshakeDoneFrame{}, nil -} - -func (f *HandshakeDoneFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x1e) - return nil -} - -// Length of a written frame -func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { - return 1 -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go b/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go deleted file mode 100644 index dc029e45f61..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ping_frame.go +++ /dev/null @@ -1,27 +0,0 @@ -package wire - -import ( - "bytes" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// A PingFrame is a PING frame -type PingFrame struct{} - -func parsePingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PingFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - return &PingFrame{}, nil -} - -func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x1) - return nil -} - -// Length of a written frame -func (f *PingFrame) Length(version protocol.VersionNumber) protocol.ByteCount { - return 1 -} diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/logging/mockgen.go deleted file mode 100644 index c5aa8a16918..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/logging/mockgen.go +++ /dev/null @@ -1,4 +0,0 @@ -package logging - -//go:generate sh -c "mockgen -package logging -self_package github.com/lucas-clemente/quic-go/logging -destination mock_connection_tracer_test.go github.com/lucas-clemente/quic-go/logging ConnectionTracer" -//go:generate sh -c "mockgen -package logging -self_package github.com/lucas-clemente/quic-go/logging -destination mock_tracer_test.go github.com/lucas-clemente/quic-go/logging Tracer" diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen.go b/vendor/github.com/lucas-clemente/quic-go/mockgen.go deleted file mode 100644 index 22c2c0e74ee..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/mockgen.go +++ /dev/null @@ -1,27 +0,0 @@ -package quic - -//go:generate sh -c "./mockgen_private.sh quic mock_send_conn_test.go github.com/lucas-clemente/quic-go sendConn" -//go:generate sh -c "./mockgen_private.sh quic mock_sender_test.go github.com/lucas-clemente/quic-go sender" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI" -//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream" -//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI" -//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter" 
-//go:generate sh -c "./mockgen_private.sh quic mock_crypto_data_handler_test.go github.com/lucas-clemente/quic-go cryptoDataHandler" -//go:generate sh -c "./mockgen_private.sh quic mock_frame_source_test.go github.com/lucas-clemente/quic-go frameSource" -//go:generate sh -c "./mockgen_private.sh quic mock_ack_frame_source_test.go github.com/lucas-clemente/quic-go ackFrameSource" -//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager" -//go:generate sh -c "./mockgen_private.sh quic mock_sealing_manager_test.go github.com/lucas-clemente/quic-go sealingManager" -//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker" -//go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer" -//go:generate sh -c "./mockgen_private.sh quic mock_mtu_discoverer_test.go github.com/lucas-clemente/quic-go mtuDiscoverer" -//go:generate sh -c "./mockgen_private.sh quic mock_conn_runner_test.go github.com/lucas-clemente/quic-go connRunner" -//go:generate sh -c "./mockgen_private.sh quic mock_quic_conn_test.go github.com/lucas-clemente/quic-go quicConn" -//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler" -//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/lucas-clemente/quic-go unknownPacketHandler" -//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager" -//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer" -//go:generate sh -c "./mockgen_private.sh quic mock_batch_conn_test.go github.com/lucas-clemente/quic-go batchConn" -//go:generate sh -c "mockgen -package quic -self_package github.com/lucas-clemente/quic-go -destination mock_token_store_test.go github.com/lucas-clemente/quic-go TokenStore" -//go:generate sh -c "mockgen -package quic -self_package github.com/lucas-clemente/quic-go -destination mock_packetconn_test.go net PacketConn" diff --git a/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh b/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh deleted file mode 100644 index 92829d770e0..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash - -DEST=$2 -PACKAGE=$3 -TMPFILE="mockgen_tmp.go" -# uppercase the name of the interface -ORIG_INTERFACE_NAME=$4 -INTERFACE_NAME="$(tr '[:lower:]' '[:upper:]' <<< ${ORIG_INTERFACE_NAME:0:1})${ORIG_INTERFACE_NAME:1}" - -# Gather all files that contain interface definitions. -# These interfaces might be used as embedded interfaces, -# so we need to pass them to mockgen as aux_files. -AUX=() -for f in *.go; do - if [[ -z ${f##*_test.go} ]]; then - # skip test files - continue; - fi - if $(egrep -qe "type (.*) interface" $f); then - AUX+=("github.com/lucas-clemente/quic-go=$f") - fi -done - -# Find the file that defines the interface we're mocking. -for f in *.go; do - if [[ -z ${f##*_test.go} ]]; then - # skip test files - continue; - fi - INTERFACE=$(sed -n "/^type $ORIG_INTERFACE_NAME interface/,/^}/p" $f) - if [[ -n "$INTERFACE" ]]; then - SRC=$f - break - fi -done - -if [[ -z "$INTERFACE" ]]; then - echo "Interface $ORIG_INTERFACE_NAME not found." 
- exit 1 -fi - -AUX_FILES=$(IFS=, ; echo "${AUX[*]}") - -## create a public alias for the interface, so that mockgen can process it -echo -e "package $1\n" > $TMPFILE -echo "$INTERFACE" | sed "s/$ORIG_INTERFACE_NAME/$INTERFACE_NAME/" >> $TMPFILE -mockgen -package $1 -self_package $3 -destination $DEST -source=$TMPFILE -aux_files $AUX_FILES -sed "s/$TMPFILE/$SRC/" "$DEST" > "$DEST.new" && mv "$DEST.new" "$DEST" -rm "$TMPFILE" diff --git a/vendor/github.com/lucas-clemente/quic-go/multiplexer.go b/vendor/github.com/lucas-clemente/quic-go/multiplexer.go deleted file mode 100644 index 2271b551722..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/multiplexer.go +++ /dev/null @@ -1,107 +0,0 @@ -package quic - -import ( - "bytes" - "fmt" - "net" - "sync" - - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/logging" -) - -var ( - connMuxerOnce sync.Once - connMuxer multiplexer -) - -type indexableConn interface { - LocalAddr() net.Addr -} - -type multiplexer interface { - AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) - RemoveConn(indexableConn) error -} - -type connManager struct { - connIDLen int - statelessResetKey []byte - tracer logging.Tracer - manager packetHandlerManager -} - -// The connMultiplexer listens on multiple net.PacketConns and dispatches -// incoming packets to the connection handler. -type connMultiplexer struct { - mutex sync.Mutex - - conns map[string] /* LocalAddr().String() */ connManager - newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests - - logger utils.Logger -} - -var _ multiplexer = &connMultiplexer{} - -func getMultiplexer() multiplexer { - connMuxerOnce.Do(func() { - connMuxer = &connMultiplexer{ - conns: make(map[string]connManager), - logger: utils.DefaultLogger.WithPrefix("muxer"), - newPacketHandlerManager: newPacketHandlerMap, - } - }) - return connMuxer -} - -func (m *connMultiplexer) AddConn( - c net.PacketConn, - connIDLen int, - statelessResetKey []byte, - tracer logging.Tracer, -) (packetHandlerManager, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - addr := c.LocalAddr() - connIndex := addr.Network() + " " + addr.String() - p, ok := m.conns[connIndex] - if !ok { - manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) - if err != nil { - return nil, err - } - p = connManager{ - connIDLen: connIDLen, - statelessResetKey: statelessResetKey, - manager: manager, - tracer: tracer, - } - m.conns[connIndex] = p - } else { - if p.connIDLen != connIDLen { - return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) - } - if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { - return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") - } - if tracer != p.tracer { - return nil, fmt.Errorf("cannot use different tracers on the same packet conn") - } - } - return p.manager, nil -} - -func (m *connMultiplexer) RemoveConn(c indexableConn) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() - if _, ok := m.conns[connIndex]; !ok { - return fmt.Errorf("cannote remove connection, connection is unknown") - } - - delete(m.conns, connIndex) - return nil -} diff --git 
a/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go b/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go deleted file mode 100644 index 2d55a95ef86..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go +++ /dev/null @@ -1,489 +0,0 @@ -package quic - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "errors" - "fmt" - "hash" - "io" - "log" - "net" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/logging" -) - -type zeroRTTQueue struct { - queue []*receivedPacket - retireTimer *time.Timer -} - -var _ packetHandler = &zeroRTTQueue{} - -func (h *zeroRTTQueue) handlePacket(p *receivedPacket) { - if len(h.queue) < protocol.Max0RTTQueueLen { - h.queue = append(h.queue, p) - } -} -func (h *zeroRTTQueue) shutdown() {} -func (h *zeroRTTQueue) destroy(error) {} -func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } -func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) { - for _, p := range h.queue { - sess.handlePacket(p) - } -} - -func (h *zeroRTTQueue) Clear() { - for _, p := range h.queue { - p.buffer.Release() - } -} - -// rawConn is a connection that allow reading of a receivedPacket. -type rawConn interface { - ReadPacket() (*receivedPacket, error) - WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) - LocalAddr() net.Addr - io.Closer -} - -type packetHandlerMapEntry struct { - packetHandler packetHandler - is0RTTQueue bool -} - -// The packetHandlerMap stores packetHandlers, identified by connection ID. -// It is used: -// * by the server to store connections -// * when multiplexing outgoing connections to store clients -type packetHandlerMap struct { - mutex sync.Mutex - - conn rawConn - connIDLen int - - handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry - resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler - server unknownPacketHandler - numZeroRTTEntries int - - listening chan struct{} // is closed when listen returns - closed bool - - deleteRetiredConnsAfter time.Duration - zeroRTTQueueDuration time.Duration - - statelessResetEnabled bool - statelessResetMutex sync.Mutex - statelessResetHasher hash.Hash - - tracer logging.Tracer - logger utils.Logger -} - -var _ packetHandlerManager = &packetHandlerMap{} - -func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { - conn, ok := c.(interface{ SetReadBuffer(int) error }) - if !ok { - return errors.New("connection doesn't allow setting of receive buffer size. 
Not a *net.UDPConn?") - } - size, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if size >= protocol.DesiredReceiveBufferSize { - logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) - return nil - } - if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { - return fmt.Errorf("failed to increase receive buffer size: %w", err) - } - newSize, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if newSize == size { - return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - if newSize < protocol.DesiredReceiveBufferSize { - return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) - return nil -} - -// only print warnings about the UDP receive buffer size once -var receiveBufferWarningOnce sync.Once - -func newPacketHandlerMap( - c net.PacketConn, - connIDLen int, - statelessResetKey []byte, - tracer logging.Tracer, - logger utils.Logger, -) (packetHandlerManager, error) { - if err := setReceiveBuffer(c, logger); err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { - receiveBufferWarningOnce.Do(func() { - if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { - return - } - log.Printf("%s. See https://github.com/lucas-clemente/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) - }) - } - } - conn, err := wrapConn(c) - if err != nil { - return nil, err - } - m := &packetHandlerMap{ - conn: conn, - connIDLen: connIDLen, - listening: make(chan struct{}), - handlers: make(map[string]packetHandlerMapEntry), - resetTokens: make(map[protocol.StatelessResetToken]packetHandler), - deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, - zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, - statelessResetEnabled: len(statelessResetKey) > 0, - statelessResetHasher: hmac.New(sha256.New, statelessResetKey), - tracer: tracer, - logger: logger, - } - go m.listen() - - if logger.Debug() { - go m.logUsage() - } - return m, nil -} - -func (h *packetHandlerMap) logUsage() { - ticker := time.NewTicker(2 * time.Second) - var printedZero bool - for { - select { - case <-h.listening: - return - case <-ticker.C: - } - - h.mutex.Lock() - numHandlers := len(h.handlers) - numTokens := len(h.resetTokens) - h.mutex.Unlock() - // If the number tracked handlers and tokens is zero, only print it a single time. 
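setReceiveBuffer above is defensive because SetReadBuffer can report success while the kernel clamps the value (for example to net.core.rmem_max on Linux); the deleted code therefore re-reads the size via platform-specific syscalls and only logs a warning on a shortfall. A minimal standalone version of the request side; the 2 MiB target is an arbitrary number for the sketch, not quic-go's DesiredReceiveBufferSize constant:

package main

import (
	"log"
	"net"
)

func main() {
	// Bind a throwaway UDP socket on localhost just to have one to tune.
	conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	const target = 2 << 20 // 2 MiB, an illustrative target size

	// SetReadBuffer may succeed even when the OS grants less than requested,
	// which is why the deleted packet_handler_map code inspects the resulting
	// size separately instead of trusting this call alone.
	if err := conn.SetReadBuffer(target); err != nil {
		log.Printf("could not grow the UDP receive buffer: %v", err)
		return
	}
	log.Printf("requested a %d byte receive buffer (the kernel may clamp this)", target)
}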
- hasZero := numHandlers == 0 && numTokens == 0 - if !hasZero || (hasZero && !printedZero) { - h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens) - printedZero = false - if hasZero { - printedZero = true - } - } - } -} - -func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { - h.mutex.Lock() - defer h.mutex.Unlock() - - if _, ok := h.handlers[string(id)]; ok { - h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) - return false - } - h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} - h.logger.Debugf("Adding connection ID %s.", id) - return true -} - -func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { - h.mutex.Lock() - defer h.mutex.Unlock() - - var q *zeroRTTQueue - if entry, ok := h.handlers[string(clientDestConnID)]; ok { - if !entry.is0RTTQueue { - h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) - return false - } - q = entry.packetHandler.(*zeroRTTQueue) - q.retireTimer.Stop() - h.numZeroRTTEntries-- - if h.numZeroRTTEntries < 0 { - panic("number of 0-RTT queues < 0") - } - } - sess := fn() - if q != nil { - q.EnqueueAll(sess) - } - h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess} - h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess} - h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) - return true -} - -func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { - h.mutex.Lock() - delete(h.handlers, string(id)) - h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s.", id) -} - -func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { - h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter) - time.AfterFunc(h.deleteRetiredConnsAfter, func() { - h.mutex.Lock() - delete(h.handlers, string(id)) - h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s after it has been retired.", id) - }) -} - -func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { - h.mutex.Lock() - h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} - h.mutex.Unlock() - h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id) - - time.AfterFunc(h.deleteRetiredConnsAfter, func() { - h.mutex.Lock() - handler.shutdown() - delete(h.handlers, string(id)) - h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id) - }) -} - -func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { - h.mutex.Lock() - h.resetTokens[token] = handler - h.mutex.Unlock() -} - -func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) { - h.mutex.Lock() - delete(h.resetTokens, token) - h.mutex.Unlock() -} - -func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { - h.mutex.Lock() - h.server = s - h.mutex.Unlock() -} - -func (h *packetHandlerMap) CloseServer() { - h.mutex.Lock() - if h.server == nil { - h.mutex.Unlock() - return - } - h.server = nil - var wg sync.WaitGroup - for _, entry := range h.handlers { - if entry.packetHandler.getPerspective() == protocol.PerspectiveServer { - wg.Add(1) - go func(handler packetHandler) { - // blocks until the CONNECTION_CLOSE has been sent and the run-loop has 
stopped - handler.shutdown() - wg.Done() - }(entry.packetHandler) - } - } - h.mutex.Unlock() - wg.Wait() -} - -// Destroy closes the underlying connection and waits until listen() has returned. -// It does not close active connections. -func (h *packetHandlerMap) Destroy() error { - if err := h.conn.Close(); err != nil { - return err - } - <-h.listening // wait until listening returns - return nil -} - -func (h *packetHandlerMap) close(e error) error { - h.mutex.Lock() - if h.closed { - h.mutex.Unlock() - return nil - } - - var wg sync.WaitGroup - for _, entry := range h.handlers { - wg.Add(1) - go func(handler packetHandler) { - handler.destroy(e) - wg.Done() - }(entry.packetHandler) - } - - if h.server != nil { - h.server.setCloseError(e) - } - h.closed = true - h.mutex.Unlock() - wg.Wait() - return getMultiplexer().RemoveConn(h.conn) -} - -func (h *packetHandlerMap) listen() { - defer close(h.listening) - for { - p, err := h.conn.ReadPacket() - //nolint:staticcheck // SA1019 ignore this! - // TODO: This code is used to ignore wsa errors on Windows. - // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. - // See https://github.com/lucas-clemente/quic-go/issues/1737 for details. - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { - h.logger.Debugf("Temporary error reading from conn: %w", err) - continue - } - if err != nil { - h.close(err) - return - } - h.handlePacket(p) - } -} - -func (h *packetHandlerMap) handlePacket(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, h.connIDLen) - if err != nil { - h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) - if h.tracer != nil { - h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - } - p.buffer.MaybeRelease() - return - } - - h.mutex.Lock() - defer h.mutex.Unlock() - - if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { - return - } - - if entry, ok := h.handlers[string(connID)]; ok { - if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue - if wire.Is0RTTPacket(p.data) { - entry.packetHandler.handlePacket(p) - return - } - } else { // existing connection - entry.packetHandler.handlePacket(p) - return - } - } - if p.data[0]&0x80 == 0 { - go h.maybeSendStatelessReset(p, connID) - return - } - if h.server == nil { // no server set - h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) - return - } - if wire.Is0RTTPacket(p.data) { - if h.numZeroRTTEntries >= protocol.Max0RTTQueues { - return - } - h.numZeroRTTEntries++ - queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} - h.handlers[string(connID)] = packetHandlerMapEntry{ - packetHandler: queue, - is0RTTQueue: true, - } - queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { - h.mutex.Lock() - defer h.mutex.Unlock() - // The entry might have been replaced by an actual connection. - // Only delete it if it's still a 0-RTT queue. 
- if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue { - delete(h.handlers, string(connID)) - h.numZeroRTTEntries-- - if h.numZeroRTTEntries < 0 { - panic("number of 0-RTT queues < 0") - } - entry.packetHandler.(*zeroRTTQueue).Clear() - if h.logger.Debug() { - h.logger.Debugf("Removing 0-RTT queue for %s.", connID) - } - } - }) - queue.handlePacket(p) - return - } - h.server.handlePacket(p) -} - -func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { - // stateless resets are always short header packets - if data[0]&0x80 != 0 { - return false - } - if len(data) < 17 /* type byte + 16 bytes for the reset token */ { - return false - } - - var token protocol.StatelessResetToken - copy(token[:], data[len(data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) - go sess.destroy(&StatelessResetError{Token: token}) - return true - } - return false -} - -func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { - var token protocol.StatelessResetToken - if !h.statelessResetEnabled { - // Return a random stateless reset token. - // This token will be sent in the server's transport parameters. - // By using a random token, an off-path attacker won't be able to disrupt the connection. - rand.Read(token[:]) - return token - } - h.statelessResetMutex.Lock() - h.statelessResetHasher.Write(connID.Bytes()) - copy(token[:], h.statelessResetHasher.Sum(nil)) - h.statelessResetHasher.Reset() - h.statelessResetMutex.Unlock() - return token -} - -func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { - defer p.buffer.Release() - if !h.statelessResetEnabled { - return - } - // Don't send a stateless reset in response to very small packets. - // This includes packets that could be stateless resets. - if len(p.data) <= protocol.MinStatelessResetSize { - return - } - token := h.GetStatelessResetToken(connID) - h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) - data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) - rand.Read(data) - data[0] = (data[0] & 0x7f) | 0x40 - data = append(data, token[:]...) 
- if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { - h.logger.Debugf("Error sending Stateless Reset: %s", err) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go b/vendor/github.com/lucas-clemente/quic-go/packet_packer.go deleted file mode 100644 index efae6102527..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/packet_packer.go +++ /dev/null @@ -1,894 +0,0 @@ -package quic - -import ( - "bytes" - "errors" - "fmt" - "net" - "time" - - "github.com/lucas-clemente/quic-go/internal/ackhandler" - "github.com/lucas-clemente/quic-go/internal/handshake" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -type packer interface { - PackCoalescedPacket() (*coalescedPacket, error) - PackPacket() (*packedPacket, error) - MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) - MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) - PackConnectionClose(*qerr.TransportError) (*coalescedPacket, error) - PackApplicationClose(*qerr.ApplicationError) (*coalescedPacket, error) - - SetMaxPacketSize(protocol.ByteCount) - PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) - - HandleTransportParameters(*wire.TransportParameters) - SetToken([]byte) -} - -type sealer interface { - handshake.LongHeaderSealer -} - -type payload struct { - frames []ackhandler.Frame - ack *wire.AckFrame - length protocol.ByteCount -} - -type packedPacket struct { - buffer *packetBuffer - *packetContents -} - -type packetContents struct { - header *wire.ExtendedHeader - ack *wire.AckFrame - frames []ackhandler.Frame - - length protocol.ByteCount - - isMTUProbePacket bool -} - -type coalescedPacket struct { - buffer *packetBuffer - packets []*packetContents -} - -func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel { - if !p.header.IsLongHeader { - return protocol.Encryption1RTT - } - //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). 
- switch p.header.Type { - case protocol.PacketTypeInitial: - return protocol.EncryptionInitial - case protocol.PacketTypeHandshake: - return protocol.EncryptionHandshake - case protocol.PacketType0RTT: - return protocol.Encryption0RTT - default: - panic("can't determine encryption level") - } -} - -func (p *packetContents) IsAckEliciting() bool { - return ackhandler.HasAckElicitingFrames(p.frames) -} - -func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { - largestAcked := protocol.InvalidPacketNumber - if p.ack != nil { - largestAcked = p.ack.LargestAcked() - } - encLevel := p.EncryptionLevel() - for i := range p.frames { - if p.frames[i].OnLost != nil { - continue - } - switch encLevel { - case protocol.EncryptionInitial: - p.frames[i].OnLost = q.AddInitial - case protocol.EncryptionHandshake: - p.frames[i].OnLost = q.AddHandshake - case protocol.Encryption0RTT, protocol.Encryption1RTT: - p.frames[i].OnLost = q.AddAppData - } - } - return &ackhandler.Packet{ - PacketNumber: p.header.PacketNumber, - LargestAcked: largestAcked, - Frames: p.frames, - Length: p.length, - EncryptionLevel: encLevel, - SendTime: now, - IsPathMTUProbePacket: p.isMTUProbePacket, - } -} - -func getMaxPacketSize(addr net.Addr) protocol.ByteCount { - maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) - // If this is not a UDP address, we don't know anything about the MTU. - // Use the minimum size of an Initial packet as the max packet size. - if udpAddr, ok := addr.(*net.UDPAddr); ok { - if utils.IsIPv4(udpAddr.IP) { - maxSize = protocol.InitialPacketSizeIPv4 - } else { - maxSize = protocol.InitialPacketSizeIPv6 - } - } - return maxSize -} - -type packetNumberManager interface { - PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) - PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber -} - -type sealingManager interface { - GetInitialSealer() (handshake.LongHeaderSealer, error) - GetHandshakeSealer() (handshake.LongHeaderSealer, error) - Get0RTTSealer() (handshake.LongHeaderSealer, error) - Get1RTTSealer() (handshake.ShortHeaderSealer, error) -} - -type frameSource interface { - HasData() bool - AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) - AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) -} - -type ackFrameSource interface { - GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame -} - -type packetPacker struct { - srcConnID protocol.ConnectionID - getDestConnID func() protocol.ConnectionID - - perspective protocol.Perspective - version protocol.VersionNumber - cryptoSetup sealingManager - - initialStream cryptoStream - handshakeStream cryptoStream - - token []byte - - pnManager packetNumberManager - framer frameSource - acks ackFrameSource - datagramQueue *datagramQueue - retransmissionQueue *retransmissionQueue - - maxPacketSize protocol.ByteCount - numNonAckElicitingAcks int -} - -var _ packer = &packetPacker{} - -func newPacketPacker( - srcConnID protocol.ConnectionID, - getDestConnID func() protocol.ConnectionID, - initialStream cryptoStream, - handshakeStream cryptoStream, - packetNumberManager packetNumberManager, - retransmissionQueue *retransmissionQueue, - remoteAddr net.Addr, // only used for determining the max packet size - cryptoSetup sealingManager, - framer frameSource, - acks ackFrameSource, - datagramQueue *datagramQueue, - perspective 
protocol.Perspective, - version protocol.VersionNumber, -) *packetPacker { - return &packetPacker{ - cryptoSetup: cryptoSetup, - getDestConnID: getDestConnID, - srcConnID: srcConnID, - initialStream: initialStream, - handshakeStream: handshakeStream, - retransmissionQueue: retransmissionQueue, - datagramQueue: datagramQueue, - perspective: perspective, - version: version, - framer: framer, - acks: acks, - pnManager: packetNumberManager, - maxPacketSize: getMaxPacketSize(remoteAddr), - } -} - -// PackConnectionClose packs a packet that closes the connection with a transport error. -func (p *packetPacker) PackConnectionClose(e *qerr.TransportError) (*coalescedPacket, error) { - var reason string - // don't send details of crypto errors - if !e.ErrorCode.IsCryptoError() { - reason = e.ErrorMessage - } - return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason) -} - -// PackApplicationClose packs a packet that closes the connection with an application error. -func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError) (*coalescedPacket, error) { - return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage) -} - -func (p *packetPacker) packConnectionClose( - isApplicationError bool, - errorCode uint64, - frameType uint64, - reason string, -) (*coalescedPacket, error) { - var sealers [4]sealer - var hdrs [4]*wire.ExtendedHeader - var payloads [4]*payload - var size protocol.ByteCount - var numPackets uint8 - encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT} - for i, encLevel := range encLevels { - if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT { - continue - } - ccf := &wire.ConnectionCloseFrame{ - IsApplicationError: isApplicationError, - ErrorCode: errorCode, - FrameType: frameType, - ReasonPhrase: reason, - } - // don't send application errors in Initial or Handshake packets - if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) { - ccf.IsApplicationError = false - ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode) - ccf.ReasonPhrase = "" - } - payload := &payload{ - frames: []ackhandler.Frame{{Frame: ccf}}, - length: ccf.Length(p.version), - } - - var sealer sealer - var err error - var keyPhase protocol.KeyPhaseBit // only set for 1-RTT - switch encLevel { - case protocol.EncryptionInitial: - sealer, err = p.cryptoSetup.GetInitialSealer() - case protocol.EncryptionHandshake: - sealer, err = p.cryptoSetup.GetHandshakeSealer() - case protocol.Encryption0RTT: - sealer, err = p.cryptoSetup.Get0RTTSealer() - case protocol.Encryption1RTT: - var s handshake.ShortHeaderSealer - s, err = p.cryptoSetup.Get1RTTSealer() - if err == nil { - keyPhase = s.KeyPhase() - } - sealer = s - } - if err == handshake.ErrKeysNotYetAvailable || err == handshake.ErrKeysDropped { - continue - } - if err != nil { - return nil, err - } - sealers[i] = sealer - var hdr *wire.ExtendedHeader - if encLevel == protocol.Encryption1RTT { - hdr = p.getShortHeader(keyPhase) - } else { - hdr = p.getLongHeader(encLevel) - } - hdrs[i] = hdr - payloads[i] = payload - size += p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) - numPackets++ - } - contents := make([]*packetContents, 0, numPackets) - buffer := getPacketBuffer() - for i, encLevel := range encLevels { - if sealers[i] == nil { - continue - } - var paddingLen protocol.ByteCount - if encLevel == 
protocol.EncryptionInitial { - paddingLen = p.initialPaddingLen(payloads[i].frames, size) - } - c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false) - if err != nil { - return nil, err - } - contents = append(contents, c) - } - return &coalescedPacket{buffer: buffer, packets: contents}, nil -} - -// packetLength calculates the length of the serialized packet. -// It takes into account that packets that have a tiny payload need to be padded, -// such that len(payload) + packet number len >= 4 + AEAD overhead -func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount { - var paddingLen protocol.ByteCount - pnLen := protocol.ByteCount(hdr.PacketNumberLen) - if payload.length < 4-pnLen { - paddingLen = 4 - pnLen - payload.length - } - return hdr.GetLength(p.version) + payload.length + paddingLen -} - -func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { - var encLevel protocol.EncryptionLevel - var ack *wire.AckFrame - if !handshakeConfirmed { - ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true) - if ack != nil { - encLevel = protocol.EncryptionInitial - } else { - ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true) - if ack != nil { - encLevel = protocol.EncryptionHandshake - } - } - } - if ack == nil { - ack = p.acks.GetAckFrame(protocol.Encryption1RTT, true) - if ack == nil { - return nil, nil - } - encLevel = protocol.Encryption1RTT - } - payload := &payload{ - ack: ack, - length: ack.Length(p.version), - } - - sealer, hdr, err := p.getSealerAndHeader(encLevel) - if err != nil { - return nil, err - } - return p.writeSinglePacket(hdr, payload, encLevel, sealer) -} - -// size is the expected size of the packet, if no padding was applied. -func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { - // For the server, only ack-eliciting Initial packets need to be padded. - if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { - return 0 - } - if size >= p.maxPacketSize { - return 0 - } - return p.maxPacketSize - size -} - -// PackCoalescedPacket packs a new packet. -// It packs an Initial / Handshake if there is data to send in these packet number spaces. -// It should only be called before the handshake is confirmed. -func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { - maxPacketSize := p.maxPacketSize - if p.perspective == protocol.PerspectiveClient { - maxPacketSize = protocol.MinInitialPacketSize - } - var initialHdr, handshakeHdr, appDataHdr *wire.ExtendedHeader - var initialPayload, handshakePayload, appDataPayload *payload - var numPackets int - // Try packing an Initial packet. - initialSealer, err := p.cryptoSetup.GetInitialSealer() - if err != nil && err != handshake.ErrKeysDropped { - return nil, err - } - var size protocol.ByteCount - if initialSealer != nil { - initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial) - if initialPayload != nil { - size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) - numPackets++ - } - } - - // Add a Handshake packet. 
- var handshakeSealer sealer - if size < maxPacketSize-protocol.MinCoalescedPacketSize { - var err error - handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer() - if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { - return nil, err - } - if handshakeSealer != nil { - handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake) - if handshakePayload != nil { - s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) - size += s - numPackets++ - } - } - } - - // Add a 0-RTT / 1-RTT packet. - var appDataSealer sealer - appDataEncLevel := protocol.Encryption1RTT - if size < maxPacketSize-protocol.MinCoalescedPacketSize { - var err error - appDataSealer, appDataHdr, appDataPayload = p.maybeGetAppDataPacket(maxPacketSize-size, size) - if err != nil { - return nil, err - } - if appDataHdr != nil { - if appDataHdr.IsLongHeader { - appDataEncLevel = protocol.Encryption0RTT - } - if appDataPayload != nil { - size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead()) - numPackets++ - } - } - } - - if numPackets == 0 { - return nil, nil - } - - buffer := getPacketBuffer() - packet := &coalescedPacket{ - buffer: buffer, - packets: make([]*packetContents, 0, numPackets), - } - if initialPayload != nil { - padding := p.initialPaddingLen(initialPayload.frames, size) - cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false) - if err != nil { - return nil, err - } - packet.packets = append(packet.packets, cont) - } - if handshakePayload != nil { - cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false) - if err != nil { - return nil, err - } - packet.packets = append(packet.packets, cont) - } - if appDataPayload != nil { - cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false) - if err != nil { - return nil, err - } - packet.packets = append(packet.packets, cont) - } - return packet, nil -} - -// PackPacket packs a packet in the application data packet number space. -// It should be called after the handshake is confirmed. -func (p *packetPacker) PackPacket() (*packedPacket, error) { - sealer, hdr, payload := p.maybeGetAppDataPacket(p.maxPacketSize, 0) - if payload == nil { - return nil, nil - } - buffer := getPacketBuffer() - encLevel := protocol.Encryption1RTT - if hdr.IsLongHeader { - encLevel = protocol.Encryption0RTT - } - cont, err := p.appendPacket(buffer, hdr, payload, 0, encLevel, sealer, false) - if err != nil { - return nil, err - } - return &packedPacket{ - buffer: buffer, - packetContents: cont, - }, nil -} - -func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { - var s cryptoStream - var hasRetransmission bool - //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. 
- switch encLevel { - case protocol.EncryptionInitial: - s = p.initialStream - hasRetransmission = p.retransmissionQueue.HasInitialData() - case protocol.EncryptionHandshake: - s = p.handshakeStream - hasRetransmission = p.retransmissionQueue.HasHandshakeData() - } - - hasData := s.HasData() - var ack *wire.AckFrame - if encLevel == protocol.EncryptionInitial || currentSize == 0 { - ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) - } - if !hasData && !hasRetransmission && ack == nil { - // nothing to send - return nil, nil - } - - var payload payload - if ack != nil { - payload.ack = ack - payload.length = ack.Length(p.version) - maxPacketSize -= payload.length - } - hdr := p.getLongHeader(encLevel) - maxPacketSize -= hdr.GetLength(p.version) - if hasRetransmission { - for { - var f wire.Frame - //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s - switch encLevel { - case protocol.EncryptionInitial: - f = p.retransmissionQueue.GetInitialFrame(maxPacketSize) - case protocol.EncryptionHandshake: - f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize) - } - if f == nil { - break - } - payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) - frameLen := f.Length(p.version) - payload.length += frameLen - maxPacketSize -= frameLen - } - } else if s.HasData() { - cf := s.PopCryptoFrame(maxPacketSize) - payload.frames = []ackhandler.Frame{{Frame: cf}} - payload.length += cf.Length(p.version) - } - return hdr, &payload -} - -func (p *packetPacker) maybeGetAppDataPacket(maxPacketSize, currentSize protocol.ByteCount) (sealer, *wire.ExtendedHeader, *payload) { - var sealer sealer - var encLevel protocol.EncryptionLevel - var hdr *wire.ExtendedHeader - oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() - if err == nil { - encLevel = protocol.Encryption1RTT - sealer = oneRTTSealer - hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) - } else { - // 1-RTT sealer not yet available - if p.perspective != protocol.PerspectiveClient { - return nil, nil, nil - } - sealer, err = p.cryptoSetup.Get0RTTSealer() - if sealer == nil || err != nil { - return nil, nil, nil - } - encLevel = protocol.Encryption0RTT - hdr = p.getLongHeader(protocol.Encryption0RTT) - } - - maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, encLevel == protocol.Encryption1RTT && currentSize == 0) - return sealer, hdr, payload -} - -func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { - payload := p.composeNextPacket(maxPayloadSize, ackAllowed) - - // check if we have anything to send - if len(payload.frames) == 0 { - if payload.ack == nil { - return nil - } - // the packet only contains an ACK - if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { - ping := &wire.PingFrame{} - // don't retransmit the PING frame when it is lost - payload.frames = append(payload.frames, ackhandler.Frame{Frame: ping, OnLost: func(wire.Frame) {}}) - payload.length += ping.Length(p.version) - p.numNonAckElicitingAcks = 0 - } else { - p.numNonAckElicitingAcks++ - } - } else { - p.numNonAckElicitingAcks = 0 - } - return payload -} - -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload { - payload := &payload{frames: make([]ackhandler.Frame, 0, 1)} - - var hasDatagram bool - if p.datagramQueue != nil { - if datagram := p.datagramQueue.Get(maxFrameSize, p.version); datagram != 
nil { - payload.frames = append(payload.frames, ackhandler.Frame{ - Frame: datagram, - // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. - OnLost: func(wire.Frame) {}, - }) - payload.length += datagram.Length(p.version) - hasDatagram = true - } - } - - var ack *wire.AckFrame - hasData := p.framer.HasData() - hasRetransmission := p.retransmissionQueue.HasAppData() - // TODO: make sure ACKs are sent when a lot of DATAGRAMs are queued - if !hasDatagram && ackAllowed { - ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData) - if ack != nil { - payload.ack = ack - payload.length += ack.Length(p.version) - } - } - - if ack == nil && !hasData && !hasRetransmission { - return payload - } - - if hasRetransmission { - for { - remainingLen := maxFrameSize - payload.length - if remainingLen < protocol.MinStreamFrameSize { - break - } - f := p.retransmissionQueue.GetAppDataFrame(remainingLen) - if f == nil { - break - } - payload.frames = append(payload.frames, ackhandler.Frame{Frame: f}) - payload.length += f.Length(p.version) - } - } - - if hasData { - var lengthAdded protocol.ByteCount - payload.frames, lengthAdded = p.framer.AppendControlFrames(payload.frames, maxFrameSize-payload.length) - payload.length += lengthAdded - - payload.frames, lengthAdded = p.framer.AppendStreamFrames(payload.frames, maxFrameSize-payload.length) - payload.length += lengthAdded - } - return payload -} - -func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*packedPacket, error) { - var hdr *wire.ExtendedHeader - var payload *payload - var sealer sealer - //nolint:exhaustive // Probe packets are never sent for 0-RTT. - switch encLevel { - case protocol.EncryptionInitial: - var err error - sealer, err = p.cryptoSetup.GetInitialSealer() - if err != nil { - return nil, err - } - hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial) - case protocol.EncryptionHandshake: - var err error - sealer, err = p.cryptoSetup.GetHandshakeSealer() - if err != nil { - return nil, err - } - hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake) - case protocol.Encryption1RTT: - oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, err - } - sealer = oneRTTSealer - hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) - payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true) - default: - panic("unknown encryption level") - } - if payload == nil { - return nil, nil - } - size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) - var padding protocol.ByteCount - if encLevel == protocol.EncryptionInitial { - padding = p.initialPaddingLen(payload.frames, size) - } - buffer := getPacketBuffer() - cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false) - if err != nil { - return nil, err - } - return &packedPacket{ - buffer: buffer, - packetContents: cont, - }, nil -} - -func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) { - payload := &payload{ - frames: []ackhandler.Frame{ping}, - length: ping.Length(p.version), - } - buffer := getPacketBuffer() - sealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, err - } - hdr := p.getShortHeader(sealer.KeyPhase()) - padding := size - 
p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead()) - contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true) - if err != nil { - return nil, err - } - contents.isMTUProbePacket = true - return &packedPacket{ - buffer: buffer, - packetContents: contents, - }, nil -} - -func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { - switch encLevel { - case protocol.EncryptionInitial: - sealer, err := p.cryptoSetup.GetInitialSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.EncryptionInitial) - return sealer, hdr, nil - case protocol.Encryption0RTT: - sealer, err := p.cryptoSetup.Get0RTTSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.Encryption0RTT) - return sealer, hdr, nil - case protocol.EncryptionHandshake: - sealer, err := p.cryptoSetup.GetHandshakeSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.EncryptionHandshake) - return sealer, hdr, nil - case protocol.Encryption1RTT: - sealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getShortHeader(sealer.KeyPhase()) - return sealer, hdr, nil - default: - return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel) - } -} - -func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { - pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) - hdr := &wire.ExtendedHeader{} - hdr.PacketNumber = pn - hdr.PacketNumberLen = pnLen - hdr.DestConnectionID = p.getDestConnID() - hdr.KeyPhase = kp - return hdr -} - -func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { - pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) - hdr := &wire.ExtendedHeader{ - PacketNumber: pn, - PacketNumberLen: pnLen, - } - hdr.IsLongHeader = true - hdr.Version = p.version - hdr.SrcConnectionID = p.srcConnID - hdr.DestConnectionID = p.getDestConnID() - - //nolint:exhaustive // 1-RTT packets are not long header packets. - switch encLevel { - case protocol.EncryptionInitial: - hdr.Type = protocol.PacketTypeInitial - hdr.Token = p.token - case protocol.EncryptionHandshake: - hdr.Type = protocol.PacketTypeHandshake - case protocol.Encryption0RTT: - hdr.Type = protocol.PacketType0RTT - } - return hdr -} - -// writeSinglePacket packs a single packet. 
-func (p *packetPacker) writeSinglePacket( - hdr *wire.ExtendedHeader, - payload *payload, - encLevel protocol.EncryptionLevel, - sealer sealer, -) (*packedPacket, error) { - buffer := getPacketBuffer() - var paddingLen protocol.ByteCount - if encLevel == protocol.EncryptionInitial { - paddingLen = p.initialPaddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead())) - } - contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer, false) - if err != nil { - return nil, err - } - return &packedPacket{ - buffer: buffer, - packetContents: contents, - }, nil -} - -func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) { - var paddingLen protocol.ByteCount - pnLen := protocol.ByteCount(header.PacketNumberLen) - if payload.length < 4-pnLen { - paddingLen = 4 - pnLen - payload.length - } - paddingLen += padding - if header.IsLongHeader { - header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen - } - - hdrOffset := buffer.Len() - buf := bytes.NewBuffer(buffer.Data) - if err := header.Write(buf, p.version); err != nil { - return nil, err - } - payloadOffset := buf.Len() - - if payload.ack != nil { - if err := payload.ack.Write(buf, p.version); err != nil { - return nil, err - } - } - if paddingLen > 0 { - buf.Write(make([]byte, paddingLen)) - } - for _, frame := range payload.frames { - if err := frame.Write(buf, p.version); err != nil { - return nil, err - } - } - - if payloadSize := protocol.ByteCount(buf.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { - return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) - } - if !isMTUProbePacket { - if size := protocol.ByteCount(buf.Len() + sealer.Overhead()); size > p.maxPacketSize { - return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) - } - } - - raw := buffer.Data - // encrypt the packet - raw = raw[:buf.Len()] - _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset]) - raw = raw[0 : buf.Len()+sealer.Overhead()] - // apply header protection - pnOffset := payloadOffset - int(header.PacketNumberLen) - sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[hdrOffset], raw[pnOffset:payloadOffset]) - buffer.Data = raw - - num := p.pnManager.PopPacketNumber(encLevel) - if num != header.PacketNumber { - return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") - } - return &packetContents{ - header: header, - ack: payload.ack, - frames: payload.frames, - length: buffer.Len() - hdrOffset, - }, nil -} - -func (p *packetPacker) SetToken(token []byte) { - p.token = token -} - -// When a higher MTU is discovered, use it. -func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) { - p.maxPacketSize = s -} - -// If the peer sets a max_packet_size that's smaller than the size we're currently using, -// we need to reduce the size of packets we send. 
-func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) { - if params.MaxUDPPayloadSize != 0 { - p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxUDPPayloadSize) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/server.go b/vendor/github.com/lucas-clemente/quic-go/server.go deleted file mode 100644 index 7e686881d80..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/server.go +++ /dev/null @@ -1,670 +0,0 @@ -package quic - -import ( - "bytes" - "context" - "crypto/rand" - "crypto/tls" - "errors" - "fmt" - "net" - "sync" - "sync/atomic" - "time" - - "github.com/lucas-clemente/quic-go/internal/handshake" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/logging" -) - -// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. -var ErrServerClosed = errors.New("quic: Server closed") - -// packetHandler handles packets -type packetHandler interface { - handlePacket(*receivedPacket) - shutdown() - destroy(error) - getPerspective() protocol.Perspective -} - -type unknownPacketHandler interface { - handlePacket(*receivedPacket) - setCloseError(error) -} - -type packetHandlerManager interface { - AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool - Destroy() error - connRunner - SetServer(unknownPacketHandler) - CloseServer() -} - -type quicConn interface { - EarlyConnection - earlyConnReady() <-chan struct{} - handlePacket(*receivedPacket) - GetVersion() protocol.VersionNumber - getPerspective() protocol.Perspective - run() error - destroy(error) - shutdown() -} - -// A Listener of QUIC -type baseServer struct { - mutex sync.Mutex - - acceptEarlyConns bool - - tlsConf *tls.Config - config *Config - - conn rawConn - // If the server is started with ListenAddr, we create a packet conn. - // If it is started with Listen, we take a packet conn as a parameter. - createdPacketConn bool - - tokenGenerator *handshake.TokenGenerator - - connHandler packetHandlerManager - - receivedPackets chan *receivedPacket - - // set as a member, so they can be set in the tests - newConn func( - sendConn, - connRunner, - protocol.ConnectionID, /* original dest connection ID */ - *protocol.ConnectionID, /* retry src connection ID */ - protocol.ConnectionID, /* client dest connection ID */ - protocol.ConnectionID, /* destination connection ID */ - protocol.ConnectionID, /* source connection ID */ - protocol.StatelessResetToken, - *Config, - *tls.Config, - *handshake.TokenGenerator, - bool, /* enable 0-RTT */ - logging.ConnectionTracer, - uint64, - utils.Logger, - protocol.VersionNumber, - ) quicConn - - serverError error - errorChan chan struct{} - closed bool - running chan struct{} // closed as soon as run() returns - - connQueue chan quicConn - connQueueLen int32 // to be used as an atomic - - logger utils.Logger -} - -var ( - _ Listener = &baseServer{} - _ unknownPacketHandler = &baseServer{} -) - -type earlyServer struct{ *baseServer } - -var _ EarlyListener = &earlyServer{} - -func (s *earlyServer) Accept(ctx context.Context) (EarlyConnection, error) { - return s.baseServer.accept(ctx) -} - -// ListenAddr creates a QUIC server listening on a given address. -// The tls.Config must not be nil and must contain a certificate configuration. 
-// The quic.Config may be nil, in that case the default values will be used. -func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { - return listenAddr(addr, tlsConf, config, false) -} - -// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. -func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) { - s, err := listenAddr(addr, tlsConf, config, true) - if err != nil { - return nil, err - } - return &earlyServer{s}, nil -} - -func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return nil, err - } - serv, err := listen(conn, tlsConf, config, acceptEarly) - if err != nil { - return nil, err - } - serv.createdPacketConn = true - return serv, nil -} - -// Listen listens for QUIC connections on a given net.PacketConn. If the -// PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn -// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP -// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write -// packets. A single net.PacketConn only be used for a single call to Listen. -// The PacketConn can be used for simultaneous calls to Dial. QUIC connection -// IDs are used for demultiplexing the different connections. The tls.Config -// must not be nil and must contain a certificate configuration. The -// tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. Furthermore, -// it must define an application control (using NextProtos). The quic.Config may -// be nil, in that case the default values will be used. -func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { - return listen(conn, tlsConf, config, false) -} - -// ListenEarly works like Listen, but it returns connections before the handshake completes. 
-func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { - s, err := listen(conn, tlsConf, config, true) - if err != nil { - return nil, err - } - return &earlyServer{s}, nil -} - -func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateServerConfig(config) - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s is not a valid QUIC version", v) - } - } - - connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) - if err != nil { - return nil, err - } - tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) - if err != nil { - return nil, err - } - c, err := wrapConn(conn) - if err != nil { - return nil, err - } - s := &baseServer{ - conn: c, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), - newConn: newConnection, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, - } - go s.run() - connHandler.SetServer(s) - s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) - return s, nil -} - -func (s *baseServer) run() { - defer close(s.running) - for { - select { - case <-s.errorChan: - return - default: - } - select { - case <-s.errorChan: - return - case p := <-s.receivedPackets: - if bufferStillInUse := s.handlePacketImpl(p); !bufferStillInUse { - p.buffer.Release() - } - } - } -} - -var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool { - if token == nil { - return false - } - validity := protocol.TokenValidity - if token.IsRetryToken { - validity = protocol.RetryTokenValidity - } - if time.Now().After(token.SentTime.Add(validity)) { - return false - } - var sourceAddr string - if udpAddr, ok := clientAddr.(*net.UDPAddr); ok { - sourceAddr = udpAddr.IP.String() - } else { - sourceAddr = clientAddr.String() - } - return sourceAddr == token.RemoteAddr -} - -// Accept returns connections that already completed the handshake. -// It is only valid if acceptEarlyConns is false. -func (s *baseServer) Accept(ctx context.Context) (Connection, error) { - return s.accept(ctx) -} - -func (s *baseServer) accept(ctx context.Context) (quicConn, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case conn := <-s.connQueue: - atomic.AddInt32(&s.connQueueLen, -1) - return conn, nil - case <-s.errorChan: - return nil, s.serverError - } -} - -// Close the server -func (s *baseServer) Close() error { - s.mutex.Lock() - if s.closed { - s.mutex.Unlock() - return nil - } - if s.serverError == nil { - s.serverError = ErrServerClosed - } - // If the server was started with ListenAddr, we created the packet conn. - // We need to close it in order to make the go routine reading from that conn return. 
- createdPacketConn := s.createdPacketConn - s.closed = true - close(s.errorChan) - s.mutex.Unlock() - - <-s.running - s.connHandler.CloseServer() - if createdPacketConn { - return s.connHandler.Destroy() - } - return nil -} - -func (s *baseServer) setCloseError(e error) { - s.mutex.Lock() - defer s.mutex.Unlock() - if s.closed { - return - } - s.closed = true - s.serverError = e - close(s.errorChan) -} - -// Addr returns the server's network address -func (s *baseServer) Addr() net.Addr { - return s.conn.LocalAddr() -} - -func (s *baseServer) handlePacket(p *receivedPacket) { - select { - case s.receivedPackets <- p: - default: - s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) - } - } -} - -func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { - if wire.IsVersionNegotiationPacket(p.data) { - s.logger.Debugf("Dropping Version Negotiation packet.") - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - // If we're creating a new connection, the packet will be passed to the connection. - // The header will then be parsed again. - hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDGenerator.ConnectionIDLen()) - if err != nil && err != wire.ErrUnsupportedVersion { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - } - s.logger.Debugf("Error parsing packet: %s", err) - return false - } - // Short header packets should never end up here in the first place - if !hdr.IsLongHeader { - panic(fmt.Sprintf("misrouted packet: %#v", hdr)) - } - if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { - s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - // send a Version Negotiation Packet if the client is speaking a different protocol version - if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - if p.Size() < protocol.MinUnknownVersionPacketSize { - s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - if !s.config.DisableVersionNegotiationPackets { - go s.sendVersionNegotiationPacket(p, hdr) - } - return false - } - if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { - // Drop long header packets. - // There's little point in sending a Stateless Reset, since the client - // might not have received the token yet. 
- s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - - s.logger.Debugf("<- Received Initial packet.") - - if err := s.handleInitialImpl(p, hdr); err != nil { - s.logger.Errorf("Error occurred handling initial packet: %s", err) - } - // Don't put the packet buffer back. - // handleInitialImpl deals with the buffer. - return true -} - -func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { - if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { - p.buffer.Release() - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) - } - return errors.New("too short connection ID") - } - - var ( - token *Token - retrySrcConnID *protocol.ConnectionID - ) - origDestConnID := hdr.DestConnectionID - if len(hdr.Token) > 0 { - c, err := s.tokenGenerator.DecodeToken(hdr.Token) - if err == nil { - token = &Token{ - IsRetryToken: c.IsRetryToken, - RemoteAddr: c.RemoteAddr, - SentTime: c.SentTime, - } - if token.IsRetryToken { - origDestConnID = c.OriginalDestConnectionID - retrySrcConnID = &c.RetrySrcConnectionID - } - } - } - if !s.config.AcceptToken(p.remoteAddr, token) { - go func() { - defer p.buffer.Release() - if token != nil && token.IsRetryToken { - if err := s.maybeSendInvalidToken(p, hdr); err != nil { - s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) - } - return - } - if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { - s.logger.Debugf("Error sending Retry: %s", err) - } - }() - return nil - } - - if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize { - s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) - go func() { - defer p.buffer.Release() - if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { - s.logger.Debugf("Error rejecting connection: %s", err) - } - }() - return nil - } - - connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() - if err != nil { - return err - } - s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(connID)) - var conn quicConn - tracingID := nextConnTracingID() - if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { - var tracer logging.ConnectionTracer - if s.config.Tracer != nil { - // Use the same connection ID that is passed to the client's GetLogWriter callback. 
- connID := hdr.DestConnectionID - if origDestConnID.Len() > 0 { - connID = origDestConnID - } - tracer = s.config.Tracer.TracerForConnection( - context.WithValue(context.Background(), ConnectionTracingKey, tracingID), - protocol.PerspectiveServer, - connID, - ) - } - conn = s.newConn( - newSendConn(s.conn, p.remoteAddr, p.info), - s.connHandler, - origDestConnID, - retrySrcConnID, - hdr.DestConnectionID, - hdr.SrcConnectionID, - connID, - s.connHandler.GetStatelessResetToken(connID), - s.config, - s.tlsConf, - s.tokenGenerator, - s.acceptEarlyConns, - tracer, - tracingID, - s.logger, - hdr.Version, - ) - conn.handlePacket(p) - return conn - }); !added { - return nil - } - go conn.run() - go s.handleNewConn(conn) - if conn == nil { - p.buffer.Release() - return nil - } - return nil -} - -func (s *baseServer) handleNewConn(conn quicConn) { - connCtx := conn.Context() - if s.acceptEarlyConns { - // wait until the early connection is ready (or the handshake fails) - select { - case <-conn.earlyConnReady(): - case <-connCtx.Done(): - return - } - } else { - // wait until the handshake is complete (or fails) - select { - case <-conn.HandshakeComplete().Done(): - case <-connCtx.Done(): - return - } - } - - atomic.AddInt32(&s.connQueueLen, 1) - select { - case s.connQueue <- conn: - // blocks until the connection is accepted - case <-connCtx.Done(): - atomic.AddInt32(&s.connQueueLen, -1) - // don't pass connections that were already closed to Accept() - } -} - -func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { - // Log the Initial packet now. - // If no Retry is sent, the packet will be logged by the connection. - (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() - if err != nil { - return err - } - token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID, srcConnID) - if err != nil { - return err - } - replyHdr := &wire.ExtendedHeader{} - replyHdr.IsLongHeader = true - replyHdr.Type = protocol.PacketTypeRetry - replyHdr.Version = hdr.Version - replyHdr.SrcConnectionID = srcConnID - replyHdr.DestConnectionID = hdr.SrcConnectionID - replyHdr.Token = token - if s.logger.Debug() { - s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(srcConnID)) - s.logger.Debugf("-> Sending Retry") - replyHdr.Log(s.logger) - } - - packetBuffer := getPacketBuffer() - defer packetBuffer.Release() - buf := bytes.NewBuffer(packetBuffer.Data) - if err := replyHdr.Write(buf, hdr.Version); err != nil { - return err - } - // append the Retry integrity tag - tag := handshake.GetRetryIntegrityTag(buf.Bytes(), hdr.DestConnectionID, hdr.Version) - buf.Write(tag[:]) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(buf.Len()), nil) - } - _, err = s.conn.WritePacket(buf.Bytes(), remoteAddr, info.OOB()) - return err -} - -func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error { - // Only send INVALID_TOKEN if we can unprotect the packet. - // This makes sure that we won't send it for packets that were corrupted. 
- sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) - data := p.data[:hdr.ParsedLen()+hdr.Length] - extHdr, err := unpackHeader(opener, hdr, data, hdr.Version) - if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) - } - // don't return the error here. Just drop the packet. - return nil - } - hdrLen := extHdr.ParsedLen() - if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { - // don't return the error here. Just drop the packet. - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) - } - return nil - } - if s.logger.Debug() { - s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) - } - return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info) -} - -func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { - sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) - return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) -} - -// sendError sends the error as a response to the packet received with header hdr -func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { - packetBuffer := getPacketBuffer() - defer packetBuffer.Release() - buf := bytes.NewBuffer(packetBuffer.Data) - - ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)} - - replyHdr := &wire.ExtendedHeader{} - replyHdr.IsLongHeader = true - replyHdr.Type = protocol.PacketTypeInitial - replyHdr.Version = hdr.Version - replyHdr.SrcConnectionID = hdr.DestConnectionID - replyHdr.DestConnectionID = hdr.SrcConnectionID - replyHdr.PacketNumberLen = protocol.PacketNumberLen4 - replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead()) - if err := replyHdr.Write(buf, hdr.Version); err != nil { - return err - } - payloadOffset := buf.Len() - - if err := ccf.Write(buf, hdr.Version); err != nil { - return err - } - - raw := buf.Bytes() - _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], replyHdr.PacketNumber, raw[:payloadOffset]) - raw = raw[0 : buf.Len()+sealer.Overhead()] - - pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) - sealer.EncryptHeader( - raw[pnOffset+4:pnOffset+4+16], - &raw[0], - raw[pnOffset:payloadOffset], - ) - - replyHdr.Log(s.logger) - wire.LogFrame(s.logger, ccf, true) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(raw)), []logging.Frame{ccf}) - } - _, err := s.conn.WritePacket(raw, remoteAddr, info.OOB()) - return err -} - -func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { - s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) - data := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket( - p.remoteAddr, - &wire.Header{ - IsLongHeader: true, - DestConnectionID: hdr.SrcConnectionID, - SrcConnectionID: hdr.DestConnectionID, - }, - protocol.ByteCount(len(data)), - nil, - ) - } - if _, err := 
s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { - s.logger.Debugf("Error sending Version Negotiation: %s", err) - } -} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go deleted file mode 100644 index 26b562331c3..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map_generic_helper.go +++ /dev/null @@ -1,18 +0,0 @@ -package quic - -import ( - "github.com/cheekybits/genny/generic" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// In the auto-generated streams maps, we need to be able to close the streams. -// Therefore, extend the generic.Type with the stream close method. -// This definition must be in a file that Genny doesn't process. -type item interface { - generic.Type - updateSendWindow(protocol.ByteCount) - closeForShutdown(error) -} - -const streamTypeGeneric protocol.StreamType = protocol.StreamTypeUni diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go deleted file mode 100644 index 46c8c73a089..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_bidi.go +++ /dev/null @@ -1,192 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny - -package quic - -import ( - "context" - "sync" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -// When a stream is deleted before it was accepted, we can't delete it from the map immediately. -// We need to wait until the application accepts it, and delete it then. 
-type streamIEntry struct { - stream streamI - shouldDelete bool -} - -type incomingBidiStreamsMap struct { - mutex sync.RWMutex - newStreamChan chan struct{} - - streams map[protocol.StreamNum]streamIEntry - - nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened - maxStream protocol.StreamNum // the highest stream that the peer is allowed to open - maxNumStreams uint64 // maximum number of streams - - newStream func(protocol.StreamNum) streamI - queueMaxStreamID func(*wire.MaxStreamsFrame) - - closeErr error -} - -func newIncomingBidiStreamsMap( - newStream func(protocol.StreamNum) streamI, - maxStreams uint64, - queueControlFrame func(wire.Frame), -) *incomingBidiStreamsMap { - return &incomingBidiStreamsMap{ - newStreamChan: make(chan struct{}, 1), - streams: make(map[protocol.StreamNum]streamIEntry), - maxStream: protocol.StreamNum(maxStreams), - maxNumStreams: maxStreams, - newStream: newStream, - nextStreamToOpen: 1, - nextStreamToAccept: 1, - queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, - } -} - -func (m *incomingBidiStreamsMap) AcceptStream(ctx context.Context) (streamI, error) { - // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist - select { - case <-m.newStreamChan: - default: - } - - m.mutex.Lock() - - var num protocol.StreamNum - var entry streamIEntry - for { - num = m.nextStreamToAccept - if m.closeErr != nil { - m.mutex.Unlock() - return nil, m.closeErr - } - var ok bool - entry, ok = m.streams[num] - if ok { - break - } - m.mutex.Unlock() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-m.newStreamChan: - } - m.mutex.Lock() - } - m.nextStreamToAccept++ - // If this stream was completed before being accepted, we can delete it now. - if entry.shouldDelete { - if err := m.deleteStream(num); err != nil { - m.mutex.Unlock() - return nil, err - } - } - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingBidiStreamsMap) GetOrOpenStream(num protocol.StreamNum) (streamI, error) { - m.mutex.RLock() - if num > m.maxStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer tried to open stream %d (current limit: %d)", - nums: []protocol.StreamNum{num, m.maxStream}, - } - } - // if the num is smaller than the highest we accepted - // * this stream exists in the map, and we can return it, or - // * this stream was already closed, then we can return the nil - if num < m.nextStreamToOpen { - var s streamI - // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. 
- if entry, ok := m.streams[num]; ok && !entry.shouldDelete { - s = entry.stream - } - m.mutex.RUnlock() - return s, nil - } - m.mutex.RUnlock() - - m.mutex.Lock() - // no need to check the two error conditions from above again - // * maxStream can only increase, so if the id was valid before, it definitely is valid now - // * highestStream is only modified by this function - for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { - m.streams[newNum] = streamIEntry{stream: m.newStream(newNum)} - select { - case m.newStreamChan <- struct{}{}: - default: - } - } - m.nextStreamToOpen = num + 1 - entry := m.streams[num] - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - return m.deleteStream(num) -} - -func (m *incomingBidiStreamsMap) deleteStream(num protocol.StreamNum) error { - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown incoming stream %d", - nums: []protocol.StreamNum{num}, - } - } - - // Don't delete this stream yet, if it was not yet accepted. - // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. - if num >= m.nextStreamToAccept { - entry, ok := m.streams[num] - if ok && entry.shouldDelete { - return streamError{ - message: "tried to delete incoming stream %d multiple times", - nums: []protocol.StreamNum{num}, - } - } - entry.shouldDelete = true - m.streams[num] = entry // can't assign to struct in map, so we need to reassign - return nil - } - - delete(m.streams, num) - // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream - if m.maxNumStreams > uint64(len(m.streams)) { - maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 - // Never send a value larger than protocol.MaxStreamCount. - if maxStream <= protocol.MaxStreamCount { - m.maxStream = maxStream - m.queueMaxStreamID(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: m.maxStream, - }) - } - } - return nil -} - -func (m *incomingBidiStreamsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, entry := range m.streams { - entry.stream.closeForShutdown(err) - } - m.mutex.Unlock() - close(m.newStreamChan) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go deleted file mode 100644 index 4c7696a0897..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_generic.go +++ /dev/null @@ -1,190 +0,0 @@ -package quic - -import ( - "context" - "sync" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -// When a stream is deleted before it was accepted, we can't delete it from the map immediately. -// We need to wait until the application accepts it, and delete it then. 
-type itemEntry struct { - stream item - shouldDelete bool -} - -//go:generate genny -in $GOFILE -out streams_map_incoming_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" -//go:generate genny -in $GOFILE -out streams_map_incoming_uni.go gen "item=receiveStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" -type incomingItemsMap struct { - mutex sync.RWMutex - newStreamChan chan struct{} - - streams map[protocol.StreamNum]itemEntry - - nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() - nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened - maxStream protocol.StreamNum // the highest stream that the peer is allowed to open - maxNumStreams uint64 // maximum number of streams - - newStream func(protocol.StreamNum) item - queueMaxStreamID func(*wire.MaxStreamsFrame) - - closeErr error -} - -func newIncomingItemsMap( - newStream func(protocol.StreamNum) item, - maxStreams uint64, - queueControlFrame func(wire.Frame), -) *incomingItemsMap { - return &incomingItemsMap{ - newStreamChan: make(chan struct{}, 1), - streams: make(map[protocol.StreamNum]itemEntry), - maxStream: protocol.StreamNum(maxStreams), - maxNumStreams: maxStreams, - newStream: newStream, - nextStreamToOpen: 1, - nextStreamToAccept: 1, - queueMaxStreamID: func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }, - } -} - -func (m *incomingItemsMap) AcceptStream(ctx context.Context) (item, error) { - // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist - select { - case <-m.newStreamChan: - default: - } - - m.mutex.Lock() - - var num protocol.StreamNum - var entry itemEntry - for { - num = m.nextStreamToAccept - if m.closeErr != nil { - m.mutex.Unlock() - return nil, m.closeErr - } - var ok bool - entry, ok = m.streams[num] - if ok { - break - } - m.mutex.Unlock() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-m.newStreamChan: - } - m.mutex.Lock() - } - m.nextStreamToAccept++ - // If this stream was completed before being accepted, we can delete it now. - if entry.shouldDelete { - if err := m.deleteStream(num); err != nil { - m.mutex.Unlock() - return nil, err - } - } - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingItemsMap) GetOrOpenStream(num protocol.StreamNum) (item, error) { - m.mutex.RLock() - if num > m.maxStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer tried to open stream %d (current limit: %d)", - nums: []protocol.StreamNum{num, m.maxStream}, - } - } - // if the num is smaller than the highest we accepted - // * this stream exists in the map, and we can return it, or - // * this stream was already closed, then we can return the nil - if num < m.nextStreamToOpen { - var s item - // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. 
- if entry, ok := m.streams[num]; ok && !entry.shouldDelete { - s = entry.stream - } - m.mutex.RUnlock() - return s, nil - } - m.mutex.RUnlock() - - m.mutex.Lock() - // no need to check the two error conditions from above again - // * maxStream can only increase, so if the id was valid before, it definitely is valid now - // * highestStream is only modified by this function - for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { - m.streams[newNum] = itemEntry{stream: m.newStream(newNum)} - select { - case m.newStreamChan <- struct{}{}: - default: - } - } - m.nextStreamToOpen = num + 1 - entry := m.streams[num] - m.mutex.Unlock() - return entry.stream, nil -} - -func (m *incomingItemsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - return m.deleteStream(num) -} - -func (m *incomingItemsMap) deleteStream(num protocol.StreamNum) error { - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown incoming stream %d", - nums: []protocol.StreamNum{num}, - } - } - - // Don't delete this stream yet, if it was not yet accepted. - // Just save it to streamsToDelete map, to make sure it is deleted as soon as it gets accepted. - if num >= m.nextStreamToAccept { - entry, ok := m.streams[num] - if ok && entry.shouldDelete { - return streamError{ - message: "tried to delete incoming stream %d multiple times", - nums: []protocol.StreamNum{num}, - } - } - entry.shouldDelete = true - m.streams[num] = entry // can't assign to struct in map, so we need to reassign - return nil - } - - delete(m.streams, num) - // queue a MAX_STREAM_ID frame, giving the peer the option to open a new stream - if m.maxNumStreams > uint64(len(m.streams)) { - maxStream := m.nextStreamToOpen + protocol.StreamNum(m.maxNumStreams-uint64(len(m.streams))) - 1 - // Never send a value larger than protocol.MaxStreamCount. - if maxStream <= protocol.MaxStreamCount { - m.maxStream = maxStream - m.queueMaxStreamID(&wire.MaxStreamsFrame{ - Type: streamTypeGeneric, - MaxStreamNum: m.maxStream, - }) - } - } - return nil -} - -func (m *incomingItemsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, entry := range m.streams { - entry.stream.closeForShutdown(err) - } - m.mutex.Unlock() - close(m.newStreamChan) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go deleted file mode 100644 index 3f7ec166ad1..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_bidi.go +++ /dev/null @@ -1,226 +0,0 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. 
-// see https://github.com/cheekybits/genny - -package quic - -import ( - "context" - "sync" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -type outgoingBidiStreamsMap struct { - mutex sync.RWMutex - - streams map[protocol.StreamNum]streamI - - openQueue map[uint64]chan struct{} - lowestInQueue uint64 - highestInQueue uint64 - - nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamNum // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - - newStream func(protocol.StreamNum) streamI - queueStreamIDBlocked func(*wire.StreamsBlockedFrame) - - closeErr error -} - -func newOutgoingBidiStreamsMap( - newStream func(protocol.StreamNum) streamI, - queueControlFrame func(wire.Frame), -) *outgoingBidiStreamsMap { - return &outgoingBidiStreamsMap{ - streams: make(map[protocol.StreamNum]streamI), - openQueue: make(map[uint64]chan struct{}), - maxStream: protocol.InvalidStreamNum, - nextStream: 1, - newStream: newStream, - queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, - } -} - -func (m *outgoingBidiStreamsMap) OpenStream() (streamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - // if there are OpenStreamSync calls waiting, return an error here - if len(m.openQueue) > 0 || m.nextStream > m.maxStream { - m.maybeSendBlockedFrame() - return nil, streamOpenErr{errTooManyOpenStreams} - } - return m.openStream(), nil -} - -func (m *outgoingBidiStreamsMap) OpenStreamSync(ctx context.Context) (streamI, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - if err := ctx.Err(); err != nil { - return nil, err - } - - if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { - return m.openStream(), nil - } - - waitChan := make(chan struct{}, 1) - queuePos := m.highestInQueue - m.highestInQueue++ - if len(m.openQueue) == 0 { - m.lowestInQueue = queuePos - } - m.openQueue[queuePos] = waitChan - m.maybeSendBlockedFrame() - - for { - m.mutex.Unlock() - select { - case <-ctx.Done(): - m.mutex.Lock() - delete(m.openQueue, queuePos) - return nil, ctx.Err() - case <-waitChan: - } - m.mutex.Lock() - - if m.closeErr != nil { - return nil, m.closeErr - } - if m.nextStream > m.maxStream { - // no stream available. 
Continue waiting - continue - } - str := m.openStream() - delete(m.openQueue, queuePos) - m.lowestInQueue = queuePos + 1 - m.unblockOpenSync() - return str, nil - } -} - -func (m *outgoingBidiStreamsMap) openStream() streamI { - s := m.newStream(m.nextStream) - m.streams[m.nextStream] = s - m.nextStream++ - return s -} - -// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, -// if we haven't sent one for this offset yet -func (m *outgoingBidiStreamsMap) maybeSendBlockedFrame() { - if m.blockedSent { - return - } - - var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamNum { - streamNum = m.maxStream - } - m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ - Type: protocol.StreamTypeBidi, - StreamLimit: streamNum, - }) - m.blockedSent = true -} - -func (m *outgoingBidiStreamsMap) GetStream(num protocol.StreamNum) (streamI, error) { - m.mutex.RLock() - if num >= m.nextStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer attempted to open stream %d", - nums: []protocol.StreamNum{num}, - } - } - s := m.streams[num] - m.mutex.RUnlock() - return s, nil -} - -func (m *outgoingBidiStreamsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown outgoing stream %d", - nums: []protocol.StreamNum{num}, - } - } - delete(m.streams, num) - return nil -} - -func (m *outgoingBidiStreamsMap) SetMaxStream(num protocol.StreamNum) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if num <= m.maxStream { - return - } - m.maxStream = num - m.blockedSent = false - if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { - m.maybeSendBlockedFrame() - } - m.unblockOpenSync() -} - -// UpdateSendWindow is called when the peer's transport parameters are received. -// Only in the case of a 0-RTT handshake will we have open streams at this point. -// We might need to update the send window, in case the server increased it. -func (m *outgoingBidiStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { - m.mutex.Lock() - for _, str := range m.streams { - str.updateSendWindow(limit) - } - m.mutex.Unlock() -} - -// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream -func (m *outgoingBidiStreamsMap) unblockOpenSync() { - if len(m.openQueue) == 0 { - return - } - for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { - c, ok := m.openQueue[qp] - if !ok { // entry was deleted because the context was canceled - continue - } - // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. - // It's sufficient to only unblock OpenStreamSync once. 
- select { - case c <- struct{}{}: - default: - } - return - } -} - -func (m *outgoingBidiStreamsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, str := range m.streams { - str.closeForShutdown(err) - } - for _, c := range m.openQueue { - if c != nil { - close(c) - } - } - m.mutex.Unlock() -} diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go b/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go deleted file mode 100644 index dde75043c2d..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_generic.go +++ /dev/null @@ -1,224 +0,0 @@ -package quic - -import ( - "context" - "sync" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" -) - -//go:generate genny -in $GOFILE -out streams_map_outgoing_bidi.go gen "item=streamI Item=BidiStream streamTypeGeneric=protocol.StreamTypeBidi" -//go:generate genny -in $GOFILE -out streams_map_outgoing_uni.go gen "item=sendStreamI Item=UniStream streamTypeGeneric=protocol.StreamTypeUni" -type outgoingItemsMap struct { - mutex sync.RWMutex - - streams map[protocol.StreamNum]item - - openQueue map[uint64]chan struct{} - lowestInQueue uint64 - highestInQueue uint64 - - nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) - maxStream protocol.StreamNum // the maximum stream ID we're allowed to open - blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - - newStream func(protocol.StreamNum) item - queueStreamIDBlocked func(*wire.StreamsBlockedFrame) - - closeErr error -} - -func newOutgoingItemsMap( - newStream func(protocol.StreamNum) item, - queueControlFrame func(wire.Frame), -) *outgoingItemsMap { - return &outgoingItemsMap{ - streams: make(map[protocol.StreamNum]item), - openQueue: make(map[uint64]chan struct{}), - maxStream: protocol.InvalidStreamNum, - nextStream: 1, - newStream: newStream, - queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, - } -} - -func (m *outgoingItemsMap) OpenStream() (item, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - // if there are OpenStreamSync calls waiting, return an error here - if len(m.openQueue) > 0 || m.nextStream > m.maxStream { - m.maybeSendBlockedFrame() - return nil, streamOpenErr{errTooManyOpenStreams} - } - return m.openStream(), nil -} - -func (m *outgoingItemsMap) OpenStreamSync(ctx context.Context) (item, error) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.closeErr != nil { - return nil, m.closeErr - } - - if err := ctx.Err(); err != nil { - return nil, err - } - - if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { - return m.openStream(), nil - } - - waitChan := make(chan struct{}, 1) - queuePos := m.highestInQueue - m.highestInQueue++ - if len(m.openQueue) == 0 { - m.lowestInQueue = queuePos - } - m.openQueue[queuePos] = waitChan - m.maybeSendBlockedFrame() - - for { - m.mutex.Unlock() - select { - case <-ctx.Done(): - m.mutex.Lock() - delete(m.openQueue, queuePos) - return nil, ctx.Err() - case <-waitChan: - } - m.mutex.Lock() - - if m.closeErr != nil { - return nil, m.closeErr - } - if m.nextStream > m.maxStream { - // no stream available. 
Continue waiting - continue - } - str := m.openStream() - delete(m.openQueue, queuePos) - m.lowestInQueue = queuePos + 1 - m.unblockOpenSync() - return str, nil - } -} - -func (m *outgoingItemsMap) openStream() item { - s := m.newStream(m.nextStream) - m.streams[m.nextStream] = s - m.nextStream++ - return s -} - -// maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, -// if we haven't sent one for this offset yet -func (m *outgoingItemsMap) maybeSendBlockedFrame() { - if m.blockedSent { - return - } - - var streamNum protocol.StreamNum - if m.maxStream != protocol.InvalidStreamNum { - streamNum = m.maxStream - } - m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ - Type: streamTypeGeneric, - StreamLimit: streamNum, - }) - m.blockedSent = true -} - -func (m *outgoingItemsMap) GetStream(num protocol.StreamNum) (item, error) { - m.mutex.RLock() - if num >= m.nextStream { - m.mutex.RUnlock() - return nil, streamError{ - message: "peer attempted to open stream %d", - nums: []protocol.StreamNum{num}, - } - } - s := m.streams[num] - m.mutex.RUnlock() - return s, nil -} - -func (m *outgoingItemsMap) DeleteStream(num protocol.StreamNum) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - if _, ok := m.streams[num]; !ok { - return streamError{ - message: "tried to delete unknown outgoing stream %d", - nums: []protocol.StreamNum{num}, - } - } - delete(m.streams, num) - return nil -} - -func (m *outgoingItemsMap) SetMaxStream(num protocol.StreamNum) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if num <= m.maxStream { - return - } - m.maxStream = num - m.blockedSent = false - if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { - m.maybeSendBlockedFrame() - } - m.unblockOpenSync() -} - -// UpdateSendWindow is called when the peer's transport parameters are received. -// Only in the case of a 0-RTT handshake will we have open streams at this point. -// We might need to update the send window, in case the server increased it. -func (m *outgoingItemsMap) UpdateSendWindow(limit protocol.ByteCount) { - m.mutex.Lock() - for _, str := range m.streams { - str.updateSendWindow(limit) - } - m.mutex.Unlock() -} - -// unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream -func (m *outgoingItemsMap) unblockOpenSync() { - if len(m.openQueue) == 0 { - return - } - for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { - c, ok := m.openQueue[qp] - if !ok { // entry was deleted because the context was canceled - continue - } - // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. - // It's sufficient to only unblock OpenStreamSync once. 
- select { - case c <- struct{}{}: - default: - } - return - } -} - -func (m *outgoingItemsMap) CloseWithError(err error) { - m.mutex.Lock() - m.closeErr = err - for _, str := range m.streams { - str.closeForShutdown(err) - } - for _, c := range m.openQueue { - if c != nil { - close(c) - } - } - m.mutex.Unlock() -} diff --git a/vendor/github.com/lucas-clemente/quic-go/tools.go b/vendor/github.com/lucas-clemente/quic-go/tools.go deleted file mode 100644 index ee68fafbe39..00000000000 --- a/vendor/github.com/lucas-clemente/quic-go/tools.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build tools -// +build tools - -package quic - -import ( - _ "github.com/cheekybits/genny" - _ "github.com/onsi/ginkgo/ginkgo" -) diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/README.md b/vendor/github.com/marten-seemann/qtls-go1-16/README.md deleted file mode 100644 index 0d318abe4b2..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# qtls - -[![Godoc Reference](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](https://godoc.org/github.com/marten-seemann/qtls) -[![CircleCI Build Status](https://img.shields.io/circleci/project/github/marten-seemann/qtls.svg?style=flat-square&label=CircleCI+build)](https://circleci.com/gh/marten-seemann/qtls) - -This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go). diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/auth.go b/vendor/github.com/marten-seemann/qtls-go1-16/auth.go deleted file mode 100644 index 1ef675fd37c..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/auth.go +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rsa" - "errors" - "fmt" - "hash" - "io" -) - -// verifyHandshakeSignature verifies a signature against pre-hashed -// (if required) handshake contents. 
-func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { - switch sigType { - case signatureECDSA: - pubKey, ok := pubkey.(*ecdsa.PublicKey) - if !ok { - return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) - } - if !ecdsa.VerifyASN1(pubKey, signed, sig) { - return errors.New("ECDSA verification failure") - } - case signatureEd25519: - pubKey, ok := pubkey.(ed25519.PublicKey) - if !ok { - return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) - } - if !ed25519.Verify(pubKey, signed, sig) { - return errors.New("Ed25519 verification failure") - } - case signaturePKCS1v15: - pubKey, ok := pubkey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("expected an RSA public key, got %T", pubkey) - } - if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { - return err - } - case signatureRSAPSS: - pubKey, ok := pubkey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("expected an RSA public key, got %T", pubkey) - } - signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} - if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { - return err - } - default: - return errors.New("internal error: unknown signature type") - } - return nil -} - -const ( - serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" - clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" -) - -var signaturePadding = []byte{ - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, -} - -// signedMessage returns the pre-hashed (if necessary) message to be signed by -// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. -func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { - if sigHash == directSigning { - b := &bytes.Buffer{} - b.Write(signaturePadding) - io.WriteString(b, context) - b.Write(transcript.Sum(nil)) - return b.Bytes() - } - h := sigHash.New() - h.Write(signaturePadding) - io.WriteString(h, context) - h.Write(transcript.Sum(nil)) - return h.Sum(nil) -} - -// typeAndHashFromSignatureScheme returns the corresponding signature type and -// crypto.Hash for a given TLS SignatureScheme. 
-func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { - switch signatureAlgorithm { - case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: - sigType = signaturePKCS1v15 - case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: - sigType = signatureRSAPSS - case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: - sigType = signatureECDSA - case Ed25519: - sigType = signatureEd25519 - default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) - } - switch signatureAlgorithm { - case PKCS1WithSHA1, ECDSAWithSHA1: - hash = crypto.SHA1 - case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: - hash = crypto.SHA256 - case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: - hash = crypto.SHA384 - case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: - hash = crypto.SHA512 - case Ed25519: - hash = directSigning - default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) - } - return sigType, hash, nil -} - -// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for -// a given public key used with TLS 1.0 and 1.1, before the introduction of -// signature algorithm negotiation. -func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { - switch pub.(type) { - case *rsa.PublicKey: - return signaturePKCS1v15, crypto.MD5SHA1, nil - case *ecdsa.PublicKey: - return signatureECDSA, crypto.SHA1, nil - case ed25519.PublicKey: - // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, - // but it requires holding on to a handshake transcript to do a - // full signature, and not even OpenSSL bothers with the - // complexity, so we can't even test it properly. - return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") - default: - return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) - } -} - -var rsaSignatureSchemes = []struct { - scheme SignatureScheme - minModulusBytes int - maxVersion uint16 -}{ - // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires - // emLen >= hLen + sLen + 2 - {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, - // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires - // emLen >= len(prefix) + hLen + 11 - // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. - {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, - {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, - {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, - {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, -} - -// signatureSchemesForCertificate returns the list of supported SignatureSchemes -// for a given certificate, based on the public key and the protocol version, -// and optionally filtered by its explicit SupportedSignatureAlgorithms. -// -// This function must be kept in sync with supportedSignatureAlgorithms. 
-func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { - priv, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil - } - - var sigAlgs []SignatureScheme - switch pub := priv.Public().(type) { - case *ecdsa.PublicKey: - if version != VersionTLS13 { - // In TLS 1.2 and earlier, ECDSA algorithms are not - // constrained to a single curve. - sigAlgs = []SignatureScheme{ - ECDSAWithP256AndSHA256, - ECDSAWithP384AndSHA384, - ECDSAWithP521AndSHA512, - ECDSAWithSHA1, - } - break - } - switch pub.Curve { - case elliptic.P256(): - sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} - case elliptic.P384(): - sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} - case elliptic.P521(): - sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} - default: - return nil - } - case *rsa.PublicKey: - size := pub.Size() - sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) - for _, candidate := range rsaSignatureSchemes { - if size >= candidate.minModulusBytes && version <= candidate.maxVersion { - sigAlgs = append(sigAlgs, candidate.scheme) - } - } - case ed25519.PublicKey: - sigAlgs = []SignatureScheme{Ed25519} - default: - return nil - } - - if cert.SupportedSignatureAlgorithms != nil { - var filteredSigAlgs []SignatureScheme - for _, sigAlg := range sigAlgs { - if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { - filteredSigAlgs = append(filteredSigAlgs, sigAlg) - } - } - return filteredSigAlgs - } - return sigAlgs -} - -// selectSignatureScheme picks a SignatureScheme from the peer's preference list -// that works with the selected certificate. It's only called for protocol -// versions that support signature algorithms, so TLS 1.2 and 1.3. -func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { - supportedAlgs := signatureSchemesForCertificate(vers, c) - if len(supportedAlgs) == 0 { - return 0, unsupportedCertificateError(c) - } - if len(peerAlgs) == 0 && vers == VersionTLS12 { - // For TLS 1.2, if the client didn't send signature_algorithms then we - // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. - peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} - } - // Pick signature scheme in the peer's preference order, as our - // preference order is not configurable. - for _, preferredAlg := range peerAlgs { - if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { - return preferredAlg, nil - } - } - return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") -} - -// unsupportedCertificateError returns a helpful error for certificates with -// an unsupported private key. 
-func unsupportedCertificateError(cert *Certificate) error { - switch cert.PrivateKey.(type) { - case rsa.PrivateKey, ecdsa.PrivateKey: - return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", - cert.PrivateKey, cert.PrivateKey) - case *ed25519.PrivateKey: - return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") - } - - signer, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", - cert.PrivateKey) - } - - switch pub := signer.Public().(type) { - case *ecdsa.PublicKey: - switch pub.Curve { - case elliptic.P256(): - case elliptic.P384(): - case elliptic.P521(): - default: - return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) - } - case *rsa.PublicKey: - return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") - case ed25519.PublicKey: - default: - return fmt.Errorf("tls: unsupported certificate key (%T)", pub) - } - - if cert.SupportedSignatureAlgorithms != nil { - return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") - } - - return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/cipher_suites.go b/vendor/github.com/marten-seemann/qtls-go1-16/cipher_suites.go deleted file mode 100644 index 78f80107ff0..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/cipher_suites.go +++ /dev/null @@ -1,532 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/aes" - "crypto/cipher" - "crypto/des" - "crypto/hmac" - "crypto/rc4" - "crypto/sha1" - "crypto/sha256" - "crypto/x509" - "fmt" - "hash" - - "golang.org/x/crypto/chacha20poly1305" -) - -// CipherSuite is a TLS cipher suite. Note that most functions in this package -// accept and expose cipher suite IDs instead of this type. -type CipherSuite struct { - ID uint16 - Name string - - // Supported versions is the list of TLS protocol versions that can - // negotiate this cipher suite. - SupportedVersions []uint16 - - // Insecure is true if the cipher suite has known security issues - // due to its primitives, design, or implementation. - Insecure bool -} - -var ( - supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} - supportedOnlyTLS12 = []uint16{VersionTLS12} - supportedOnlyTLS13 = []uint16{VersionTLS13} -) - -// CipherSuites returns a list of cipher suites currently implemented by this -// package, excluding those with security issues, which are returned by -// InsecureCipherSuites. -// -// The list is sorted by ID. Note that the default cipher suites selected by -// this package might depend on logic that can't be captured by a static list. 
-func CipherSuites() []*CipherSuite { - return []*CipherSuite{ - {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, - {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - - {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, - {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, - {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, - - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, - } -} - -// InsecureCipherSuites returns a list of cipher suites currently implemented by -// this package and which have security issues. -// -// Most applications should not use the cipher suites in this list, and should -// only use those returned by CipherSuites. -func InsecureCipherSuites() []*CipherSuite { - // RC4 suites are broken because RC4 is. - // CBC-SHA256 suites have no Lucky13 countermeasures. - return []*CipherSuite{ - {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - } -} - -// CipherSuiteName returns the standard name for the passed cipher suite ID -// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation -// of the ID value if the cipher suite is not implemented by this package. 
-func CipherSuiteName(id uint16) string { - for _, c := range CipherSuites() { - if c.ID == id { - return c.Name - } - } - for _, c := range InsecureCipherSuites() { - if c.ID == id { - return c.Name - } - } - return fmt.Sprintf("0x%04X", id) -} - -// a keyAgreement implements the client and server side of a TLS key agreement -// protocol by generating and processing key exchange messages. -type keyAgreement interface { - // On the server side, the first two methods are called in order. - - // In the case that the key agreement protocol doesn't use a - // ServerKeyExchange message, generateServerKeyExchange can return nil, - // nil. - generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) - processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) - - // On the client side, the next two methods are called in order. - - // This method may not be called if the server doesn't send a - // ServerKeyExchange message. - processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error - generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) -} - -const ( - // suiteECDHE indicates that the cipher suite involves elliptic curve - // Diffie-Hellman. This means that it should only be selected when the - // client indicates that it supports ECC with a curve and point format - // that we're happy with. - suiteECDHE = 1 << iota - // suiteECSign indicates that the cipher suite involves an ECDSA or - // EdDSA signature and therefore may only be selected when the server's - // certificate is ECDSA or EdDSA. If this is not set then the cipher suite - // is RSA based. - suiteECSign - // suiteTLS12 indicates that the cipher suite should only be advertised - // and accepted when using TLS 1.2. - suiteTLS12 - // suiteSHA384 indicates that the cipher suite uses SHA384 as the - // handshake hash. - suiteSHA384 - // suiteDefaultOff indicates that this cipher suite is not included by - // default. - suiteDefaultOff -) - -// A cipherSuite is a specific combination of key agreement, cipher and MAC function. -type cipherSuite struct { - id uint16 - // the lengths, in bytes, of the key material needed for each component. - keyLen int - macLen int - ivLen int - ka func(version uint16) keyAgreement - // flags is a bitmask of the suite* values, above. - flags int - cipher func(key, iv []byte, isRead bool) interface{} - mac func(key []byte) hash.Hash - aead func(key, fixedNonce []byte) aead -} - -var cipherSuites = []*cipherSuite{ - // Ciphersuite order is chosen so that ECDHE comes before plain RSA and - // AEADs are the top preference. 
- {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, - {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, - {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, - {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12 | suiteDefaultOff, cipherAES, macSHA256, nil}, - {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, - {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, - {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, - - // RC4-based cipher suites are disabled by default. - {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, suiteDefaultOff, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE | suiteDefaultOff, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteDefaultOff, cipherRC4, macSHA1, nil}, -} - -// selectCipherSuite returns the first cipher suite from ids which is also in -// supportedIDs and passes the ok filter. -func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { - for _, id := range ids { - candidate := cipherSuiteByID(id) - if candidate == nil || !ok(candidate) { - continue - } - - for _, suppID := range supportedIDs { - if id == suppID { - return candidate - } - } - } - return nil -} - -// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash -// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. 
-type cipherSuiteTLS13 struct { - id uint16 - keyLen int - aead func(key, fixedNonce []byte) aead - hash crypto.Hash -} - -type CipherSuiteTLS13 struct { - ID uint16 - KeyLen int - Hash crypto.Hash - AEAD func(key, fixedNonce []byte) cipher.AEAD -} - -func (c *CipherSuiteTLS13) IVLen() int { - return aeadNonceLength -} - -var cipherSuitesTLS13 = []*cipherSuiteTLS13{ - {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, - {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, - {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, -} - -func cipherRC4(key, iv []byte, isRead bool) interface{} { - cipher, _ := rc4.NewCipher(key) - return cipher -} - -func cipher3DES(key, iv []byte, isRead bool) interface{} { - block, _ := des.NewTripleDESCipher(key) - if isRead { - return cipher.NewCBCDecrypter(block, iv) - } - return cipher.NewCBCEncrypter(block, iv) -} - -func cipherAES(key, iv []byte, isRead bool) interface{} { - block, _ := aes.NewCipher(key) - if isRead { - return cipher.NewCBCDecrypter(block, iv) - } - return cipher.NewCBCEncrypter(block, iv) -} - -// macSHA1 returns a SHA-1 based constant time MAC. -func macSHA1(key []byte) hash.Hash { - return hmac.New(newConstantTimeHash(sha1.New), key) -} - -// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and -// is currently only used in disabled-by-default cipher suites. -func macSHA256(key []byte) hash.Hash { - return hmac.New(sha256.New, key) -} - -type aead interface { - cipher.AEAD - - // explicitNonceLen returns the number of bytes of explicit nonce - // included in each record. This is eight for older AEADs and - // zero for modern ones. - explicitNonceLen() int -} - -const ( - aeadNonceLength = 12 - noncePrefixLength = 4 -) - -// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to -// each call. -type prefixNonceAEAD struct { - // nonce contains the fixed part of the nonce in the first four bytes. - nonce [aeadNonceLength]byte - aead cipher.AEAD -} - -func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } -func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } - -func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { - copy(f.nonce[4:], nonce) - return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) -} - -func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { - copy(f.nonce[4:], nonce) - return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) -} - -// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce -// before each call. 
-type xorNonceAEAD struct { - nonceMask [aeadNonceLength]byte - aead cipher.AEAD -} - -func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number -func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } - -func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - - return result -} - -func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - - return result, err -} - -func aeadAESGCM(key, noncePrefix []byte) aead { - if len(noncePrefix) != noncePrefixLength { - panic("tls: internal error: wrong nonce length") - } - aes, err := aes.NewCipher(key) - if err != nil { - panic(err) - } - aead, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } - - ret := &prefixNonceAEAD{aead: aead} - copy(ret.nonce[:], noncePrefix) - return ret -} - -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return aeadAESGCMTLS13(key, fixedNonce) -} - -func aeadAESGCMTLS13(key, nonceMask []byte) aead { - if len(nonceMask) != aeadNonceLength { - panic("tls: internal error: wrong nonce length") - } - aes, err := aes.NewCipher(key) - if err != nil { - panic(err) - } - aead, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } - - ret := &xorNonceAEAD{aead: aead} - copy(ret.nonceMask[:], nonceMask) - return ret -} - -func aeadChaCha20Poly1305(key, nonceMask []byte) aead { - if len(nonceMask) != aeadNonceLength { - panic("tls: internal error: wrong nonce length") - } - aead, err := chacha20poly1305.New(key) - if err != nil { - panic(err) - } - - ret := &xorNonceAEAD{aead: aead} - copy(ret.nonceMask[:], nonceMask) - return ret -} - -type constantTimeHash interface { - hash.Hash - ConstantTimeSum(b []byte) []byte -} - -// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces -// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. -type cthWrapper struct { - h constantTimeHash -} - -func (c *cthWrapper) Size() int { return c.h.Size() } -func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } -func (c *cthWrapper) Reset() { c.h.Reset() } -func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } -func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } - -func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { - return func() hash.Hash { - return &cthWrapper{h().(constantTimeHash)} - } -} - -// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. 
-func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { - h.Reset() - h.Write(seq) - h.Write(header) - h.Write(data) - res := h.Sum(out) - if extra != nil { - h.Write(extra) - } - return res -} - -func rsaKA(version uint16) keyAgreement { - return rsaKeyAgreement{} -} - -func ecdheECDSAKA(version uint16) keyAgreement { - return &ecdheKeyAgreement{ - isRSA: false, - version: version, - } -} - -func ecdheRSAKA(version uint16) keyAgreement { - return &ecdheKeyAgreement{ - isRSA: true, - version: version, - } -} - -// mutualCipherSuite returns a cipherSuite given a list of supported -// ciphersuites and the id requested by the peer. -func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { - for _, id := range have { - if id == want { - return cipherSuiteByID(id) - } - } - return nil -} - -func cipherSuiteByID(id uint16) *cipherSuite { - for _, cipherSuite := range cipherSuites { - if cipherSuite.id == id { - return cipherSuite - } - } - return nil -} - -func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { - for _, id := range have { - if id == want { - return cipherSuiteTLS13ByID(id) - } - } - return nil -} - -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { - for _, cipherSuite := range cipherSuitesTLS13 { - if cipherSuite.id == id { - return cipherSuite - } - } - return nil -} - -// A list of cipher suite IDs that are, or have been, implemented by this -// package. -// -// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml -const ( - // TLS 1.0 - 1.2 cipher suites. - TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 - TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a - TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f - TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 - TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c - TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c - TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a - TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 - TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 - - // TLS 1.3 cipher suites. - TLS_AES_128_GCM_SHA256 uint16 = 0x1301 - TLS_AES_256_GCM_SHA384 uint16 = 0x1302 - TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 - - // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator - // that the client is doing version fallback. See RFC 7507. - TLS_FALLBACK_SCSV uint16 = 0x5600 - - // Legacy names for the corresponding cipher suites with the correct _SHA256 - // suffix, retained for backward compatibility. 
- TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 -) diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/common.go b/vendor/github.com/marten-seemann/qtls-go1-16/common.go deleted file mode 100644 index 266f93fef8b..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/common.go +++ /dev/null @@ -1,1576 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "container/list" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/sha512" - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "io" - "net" - "sort" - "strings" - "sync" - "time" -) - -const ( - VersionTLS10 = 0x0301 - VersionTLS11 = 0x0302 - VersionTLS12 = 0x0303 - VersionTLS13 = 0x0304 - - // Deprecated: SSLv3 is cryptographically broken, and is no longer - // supported by this package. See golang.org/issue/32716. - VersionSSL30 = 0x0300 -) - -const ( - maxPlaintext = 16384 // maximum plaintext payload length - maxCiphertext = 16384 + 2048 // maximum ciphertext payload length - maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 - recordHeaderLen = 5 // record header length - maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) - maxUselessRecords = 16 // maximum number of consecutive non-advancing records -) - -// TLS record types. -type recordType uint8 - -const ( - recordTypeChangeCipherSpec recordType = 20 - recordTypeAlert recordType = 21 - recordTypeHandshake recordType = 22 - recordTypeApplicationData recordType = 23 -) - -// TLS handshake message types. -const ( - typeHelloRequest uint8 = 0 - typeClientHello uint8 = 1 - typeServerHello uint8 = 2 - typeNewSessionTicket uint8 = 4 - typeEndOfEarlyData uint8 = 5 - typeEncryptedExtensions uint8 = 8 - typeCertificate uint8 = 11 - typeServerKeyExchange uint8 = 12 - typeCertificateRequest uint8 = 13 - typeServerHelloDone uint8 = 14 - typeCertificateVerify uint8 = 15 - typeClientKeyExchange uint8 = 16 - typeFinished uint8 = 20 - typeCertificateStatus uint8 = 22 - typeKeyUpdate uint8 = 24 - typeNextProtocol uint8 = 67 // Not IANA assigned - typeMessageHash uint8 = 254 // synthetic message -) - -// TLS compression types. 
-const ( - compressionNone uint8 = 0 -) - -type Extension struct { - Type uint16 - Data []byte -} - -// TLS extension numbers -const ( - extensionServerName uint16 = 0 - extensionStatusRequest uint16 = 5 - extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 - extensionSupportedPoints uint16 = 11 - extensionSignatureAlgorithms uint16 = 13 - extensionALPN uint16 = 16 - extensionSCT uint16 = 18 - extensionSessionTicket uint16 = 35 - extensionPreSharedKey uint16 = 41 - extensionEarlyData uint16 = 42 - extensionSupportedVersions uint16 = 43 - extensionCookie uint16 = 44 - extensionPSKModes uint16 = 45 - extensionCertificateAuthorities uint16 = 47 - extensionSignatureAlgorithmsCert uint16 = 50 - extensionKeyShare uint16 = 51 - extensionRenegotiationInfo uint16 = 0xff01 -) - -// TLS signaling cipher suite values -const ( - scsvRenegotiation uint16 = 0x00ff -) - -type EncryptionLevel uint8 - -const ( - EncryptionHandshake EncryptionLevel = iota - Encryption0RTT - EncryptionApplication -) - -// CurveID is a tls.CurveID -type CurveID = tls.CurveID - -const ( - CurveP256 CurveID = 23 - CurveP384 CurveID = 24 - CurveP521 CurveID = 25 - X25519 CurveID = 29 -) - -// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. -type keyShare struct { - group CurveID - data []byte -} - -// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. -const ( - pskModePlain uint8 = 0 - pskModeDHE uint8 = 1 -) - -// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved -// session. See RFC 8446, Section 4.2.11. -type pskIdentity struct { - label []byte - obfuscatedTicketAge uint32 -} - -// TLS Elliptic Curve Point Formats -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 -const ( - pointFormatUncompressed uint8 = 0 -) - -// TLS CertificateStatusType (RFC 3546) -const ( - statusTypeOCSP uint8 = 1 -) - -// Certificate types (for certificateRequestMsg) -const ( - certTypeRSASign = 1 - certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. -) - -// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with -// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. -const ( - signaturePKCS1v15 uint8 = iota + 225 - signatureRSAPSS - signatureECDSA - signatureEd25519 -) - -// directSigning is a standard Hash value that signals that no pre-hashing -// should be performed, and that the input should be signed directly. It is the -// hash function associated with the Ed25519 signature scheme. -var directSigning crypto.Hash = 0 - -// supportedSignatureAlgorithms contains the signature and hash algorithms that -// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ -// CertificateRequest. The two fields are merged to match with TLS 1.3. -// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. -var supportedSignatureAlgorithms = []SignatureScheme{ - PSSWithSHA256, - ECDSAWithP256AndSHA256, - Ed25519, - PSSWithSHA384, - PSSWithSHA512, - PKCS1WithSHA256, - PKCS1WithSHA384, - PKCS1WithSHA512, - ECDSAWithP384AndSHA384, - ECDSAWithP521AndSHA512, - PKCS1WithSHA1, - ECDSAWithSHA1, -} - -// helloRetryRequestRandom is set as the Random value of a ServerHello -// to signal that the message is actually a HelloRetryRequest. -var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. 
- 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, - 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, - 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -} - -const ( - // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server - // random as a downgrade protection if the server would be capable of - // negotiating a higher version. See RFC 8446, Section 4.1.3. - downgradeCanaryTLS12 = "DOWNGRD\x01" - downgradeCanaryTLS11 = "DOWNGRD\x00" -) - -// testingOnlyForceDowngradeCanary is set in tests to force the server side to -// include downgrade canaries even if it's using its highers supported version. -var testingOnlyForceDowngradeCanary bool - -type ConnectionState = tls.ConnectionState - -// ConnectionState records basic TLS details about the connection. -type connectionState struct { - // Version is the TLS version used by the connection (e.g. VersionTLS12). - Version uint16 - - // HandshakeComplete is true if the handshake has concluded. - HandshakeComplete bool - - // DidResume is true if this connection was successfully resumed from a - // previous session with a session ticket or similar mechanism. - DidResume bool - - // CipherSuite is the cipher suite negotiated for the connection (e.g. - // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). - CipherSuite uint16 - - // NegotiatedProtocol is the application protocol negotiated with ALPN. - NegotiatedProtocol string - - // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. - // - // Deprecated: this value is always true. - NegotiatedProtocolIsMutual bool - - // ServerName is the value of the Server Name Indication extension sent by - // the client. It's available both on the server and on the client side. - ServerName string - - // PeerCertificates are the parsed certificates sent by the peer, in the - // order in which they were sent. The first element is the leaf certificate - // that the connection is verified against. - // - // On the client side, it can't be empty. On the server side, it can be - // empty if Config.ClientAuth is not RequireAnyClientCert or - // RequireAndVerifyClientCert. - PeerCertificates []*x509.Certificate - - // VerifiedChains is a list of one or more chains where the first element is - // PeerCertificates[0] and the last element is from Config.RootCAs (on the - // client side) or Config.ClientCAs (on the server side). - // - // On the client side, it's set if Config.InsecureSkipVerify is false. On - // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven - // (and the peer provided a certificate) or RequireAndVerifyClientCert. - VerifiedChains [][]*x509.Certificate - - // SignedCertificateTimestamps is a list of SCTs provided by the peer - // through the TLS handshake for the leaf certificate, if any. - SignedCertificateTimestamps [][]byte - - // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) - // response provided by the peer for the leaf certificate, if any. - OCSPResponse []byte - - // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, - // Section 3). This value will be nil for TLS 1.3 connections and for all - // resumed connections. - // - // Deprecated: there are conditions in which this value might not be unique - // to a connection. See the Security Considerations sections of RFC 5705 and - // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. - TLSUnique []byte - - // ekm is a closure exposed via ExportKeyingMaterial. 
- ekm func(label string, context []byte, length int) ([]byte, error) -} - -type ConnectionStateWith0RTT struct { - ConnectionState - - Used0RTT bool // true if 0-RTT was both offered and accepted -} - -// ClientAuthType is tls.ClientAuthType -type ClientAuthType = tls.ClientAuthType - -const ( - NoClientCert = tls.NoClientCert - RequestClientCert = tls.RequestClientCert - RequireAnyClientCert = tls.RequireAnyClientCert - VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven - RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert -) - -// requiresClientCert reports whether the ClientAuthType requires a client -// certificate to be provided. -func requiresClientCert(c ClientAuthType) bool { - switch c { - case RequireAnyClientCert, RequireAndVerifyClientCert: - return true - default: - return false - } -} - -// ClientSessionState contains the state needed by clients to resume TLS -// sessions. -type ClientSessionState = tls.ClientSessionState - -type clientSessionState struct { - sessionTicket []uint8 // Encrypted ticket used for session resumption with server - vers uint16 // TLS version negotiated for the session - cipherSuite uint16 // Ciphersuite negotiated for the session - masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret - serverCertificates []*x509.Certificate // Certificate chain presented by the server - verifiedChains [][]*x509.Certificate // Certificate chains we built for verification - receivedAt time.Time // When the session ticket was received from the server - ocspResponse []byte // Stapled OCSP response presented by the server - scts [][]byte // SCTs presented by the server - - // TLS 1.3 fields. - nonce []byte // Ticket nonce sent by the server, to derive PSK - useBy time.Time // Expiration of the ticket lifetime as set by the server - ageAdd uint32 // Random obfuscation factor for sending the ticket age -} - -// ClientSessionCache is a cache of ClientSessionState objects that can be used -// by a client to resume a TLS session with a given server. ClientSessionCache -// implementations should expect to be called concurrently from different -// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not -// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which -// are supported via this interface. -//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/marten-seemann/qtls-go1-15 ClientSessionCache" -type ClientSessionCache = tls.ClientSessionCache - -// SignatureScheme is a tls.SignatureScheme -type SignatureScheme = tls.SignatureScheme - -const ( - // RSASSA-PKCS1-v1_5 algorithms. - PKCS1WithSHA256 SignatureScheme = 0x0401 - PKCS1WithSHA384 SignatureScheme = 0x0501 - PKCS1WithSHA512 SignatureScheme = 0x0601 - - // RSASSA-PSS algorithms with public key OID rsaEncryption. - PSSWithSHA256 SignatureScheme = 0x0804 - PSSWithSHA384 SignatureScheme = 0x0805 - PSSWithSHA512 SignatureScheme = 0x0806 - - // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. - ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 - ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 - ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 - - // EdDSA algorithms. - Ed25519 SignatureScheme = 0x0807 - - // Legacy signature and hash algorithms for TLS 1.2. 
- PKCS1WithSHA1 SignatureScheme = 0x0201 - ECDSAWithSHA1 SignatureScheme = 0x0203 -) - -// ClientHelloInfo contains information from a ClientHello message in order to -// guide application logic in the GetCertificate and GetConfigForClient callbacks. -type ClientHelloInfo = tls.ClientHelloInfo - -type clientHelloInfo struct { - // CipherSuites lists the CipherSuites supported by the client (e.g. - // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). - CipherSuites []uint16 - - // ServerName indicates the name of the server requested by the client - // in order to support virtual hosting. ServerName is only set if the - // client is using SNI (see RFC 4366, Section 3.1). - ServerName string - - // SupportedCurves lists the elliptic curves supported by the client. - // SupportedCurves is set only if the Supported Elliptic Curves - // Extension is being used (see RFC 4492, Section 5.1.1). - SupportedCurves []CurveID - - // SupportedPoints lists the point formats supported by the client. - // SupportedPoints is set only if the Supported Point Formats Extension - // is being used (see RFC 4492, Section 5.1.2). - SupportedPoints []uint8 - - // SignatureSchemes lists the signature and hash schemes that the client - // is willing to verify. SignatureSchemes is set only if the Signature - // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). - SignatureSchemes []SignatureScheme - - // SupportedProtos lists the application protocols supported by the client. - // SupportedProtos is set only if the Application-Layer Protocol - // Negotiation Extension is being used (see RFC 7301, Section 3.1). - // - // Servers can select a protocol by setting Config.NextProtos in a - // GetConfigForClient return value. - SupportedProtos []string - - // SupportedVersions lists the TLS versions supported by the client. - // For TLS versions less than 1.3, this is extrapolated from the max - // version advertised by the client, so values other than the greatest - // might be rejected if used. - SupportedVersions []uint16 - - // Conn is the underlying net.Conn for the connection. Do not read - // from, or write to, this connection; that will cause the TLS - // connection to fail. - Conn net.Conn - - // config is embedded by the GetCertificate or GetConfigForClient caller, - // for use with SupportsCertificate. - config *Config -} - -// CertificateRequestInfo contains information from a server's -// CertificateRequest message, which is used to demand a certificate and proof -// of control from a client. -type CertificateRequestInfo = tls.CertificateRequestInfo - -type certificateRequestInfo struct { - // AcceptableCAs contains zero or more, DER-encoded, X.501 - // Distinguished Names. These are the names of root or intermediate CAs - // that the server wishes the returned certificate to be signed by. An - // empty slice indicates that the server has no preference. - AcceptableCAs [][]byte - - // SignatureSchemes lists the signature schemes that the server is - // willing to verify. - SignatureSchemes []SignatureScheme - - // Version is the TLS version that was negotiated for this connection. - Version uint16 -} - -// RenegotiationSupport enumerates the different levels of support for TLS -// renegotiation. TLS renegotiation is the act of performing subsequent -// handshakes on a connection after the first. This significantly complicates -// the state machine and has been the source of numerous, subtle security -// issues. 
Initiating a renegotiation is not supported, but support for -// accepting renegotiation requests may be enabled. -// -// Even when enabled, the server may not change its identity between handshakes -// (i.e. the leaf certificate must be the same). Additionally, concurrent -// handshake and application data flow is not permitted so renegotiation can -// only be used with protocols that synchronise with the renegotiation, such as -// HTTPS. -// -// Renegotiation is not defined in TLS 1.3. -type RenegotiationSupport = tls.RenegotiationSupport - -const ( - // RenegotiateNever disables renegotiation. - RenegotiateNever = tls.RenegotiateNever - - // RenegotiateOnceAsClient allows a remote server to request - // renegotiation once per connection. - RenegotiateOnceAsClient = tls.RenegotiateOnceAsClient - - // RenegotiateFreelyAsClient allows a remote server to repeatedly - // request renegotiation. - RenegotiateFreelyAsClient = tls.RenegotiateFreelyAsClient -) - -// A Config structure is used to configure a TLS client or server. -// After one has been passed to a TLS function it must not be -// modified. A Config may be reused; the tls package will also not -// modify it. -type Config = tls.Config - -type config struct { - // Rand provides the source of entropy for nonces and RSA blinding. - // If Rand is nil, TLS uses the cryptographic random reader in package - // crypto/rand. - // The Reader must be safe for use by multiple goroutines. - Rand io.Reader - - // Time returns the current time as the number of seconds since the epoch. - // If Time is nil, TLS uses time.Now. - Time func() time.Time - - // Certificates contains one or more certificate chains to present to the - // other side of the connection. The first certificate compatible with the - // peer's requirements is selected automatically. - // - // Server configurations must set one of Certificates, GetCertificate or - // GetConfigForClient. Clients doing client-authentication may set either - // Certificates or GetClientCertificate. - // - // Note: if there are multiple Certificates, and they don't have the - // optional field Leaf set, certificate selection will incur a significant - // per-handshake performance cost. - Certificates []Certificate - - // NameToCertificate maps from a certificate name to an element of - // Certificates. Note that a certificate name can be of the form - // '*.example.com' and so doesn't have to be a domain name as such. - // - // Deprecated: NameToCertificate only allows associating a single - // certificate with a given name. Leave this field nil to let the library - // select the first compatible chain from Certificates. - NameToCertificate map[string]*Certificate - - // GetCertificate returns a Certificate based on the given - // ClientHelloInfo. It will only be called if the client supplies SNI - // information or if Certificates is empty. - // - // If GetCertificate is nil or returns nil, then the certificate is - // retrieved from NameToCertificate. If NameToCertificate is nil, the - // best element of Certificates will be used. - GetCertificate func(*ClientHelloInfo) (*Certificate, error) - - // GetClientCertificate, if not nil, is called when a server requests a - // certificate from a client. If set, the contents of Certificates will - // be ignored. - // - // If GetClientCertificate returns an error, the handshake will be - // aborted and that error will be returned. Otherwise - // GetClientCertificate must return a non-nil Certificate. 
If - // Certificate.Certificate is empty then no certificate will be sent to - // the server. If this is unacceptable to the server then it may abort - // the handshake. - // - // GetClientCertificate may be called multiple times for the same - // connection if renegotiation occurs or if TLS 1.3 is in use. - GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) - - // GetConfigForClient, if not nil, is called after a ClientHello is - // received from a client. It may return a non-nil Config in order to - // change the Config that will be used to handle this connection. If - // the returned Config is nil, the original Config will be used. The - // Config returned by this callback may not be subsequently modified. - // - // If GetConfigForClient is nil, the Config passed to Server() will be - // used for all connections. - // - // If SessionTicketKey was explicitly set on the returned Config, or if - // SetSessionTicketKeys was called on the returned Config, those keys will - // be used. Otherwise, the original Config keys will be used (and possibly - // rotated if they are automatically managed). - GetConfigForClient func(*ClientHelloInfo) (*Config, error) - - // VerifyPeerCertificate, if not nil, is called after normal - // certificate verification by either a TLS client or server. It - // receives the raw ASN.1 certificates provided by the peer and also - // any verified chains that normal processing found. If it returns a - // non-nil error, the handshake is aborted and that error results. - // - // If normal verification fails then the handshake will abort before - // considering this callback. If normal verification is disabled by - // setting InsecureSkipVerify, or (for a server) when ClientAuth is - // RequestClientCert or RequireAnyClientCert, then this callback will - // be considered but the verifiedChains argument will always be nil. - VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - - // VerifyConnection, if not nil, is called after normal certificate - // verification and after VerifyPeerCertificate by either a TLS client - // or server. If it returns a non-nil error, the handshake is aborted - // and that error results. - // - // If normal verification fails then the handshake will abort before - // considering this callback. This callback will run for all connections - // regardless of InsecureSkipVerify or ClientAuth settings. - VerifyConnection func(ConnectionState) error - - // RootCAs defines the set of root certificate authorities - // that clients use when verifying server certificates. - // If RootCAs is nil, TLS uses the host's root CA set. - RootCAs *x509.CertPool - - // NextProtos is a list of supported application level protocols, in - // order of preference. - NextProtos []string - - // ServerName is used to verify the hostname on the returned - // certificates unless InsecureSkipVerify is given. It is also included - // in the client's handshake to support virtual hosting unless it is - // an IP address. - ServerName string - - // ClientAuth determines the server's policy for - // TLS Client Authentication. The default is NoClientCert. - ClientAuth ClientAuthType - - // ClientCAs defines the set of root certificate authorities - // that servers use if required to verify a client certificate - // by the policy in ClientAuth. - ClientCAs *x509.CertPool - - // InsecureSkipVerify controls whether a client verifies the server's - // certificate chain and host name. 
If InsecureSkipVerify is true, crypto/tls - // accepts any certificate presented by the server and any host name in that - // certificate. In this mode, TLS is susceptible to machine-in-the-middle - // attacks unless custom verification is used. This should be used only for - // testing or in combination with VerifyConnection or VerifyPeerCertificate. - InsecureSkipVerify bool - - // CipherSuites is a list of supported cipher suites for TLS versions up to - // TLS 1.2. If CipherSuites is nil, a default list of secure cipher suites - // is used, with a preference order based on hardware performance. The - // default cipher suites might change over Go versions. Note that TLS 1.3 - // ciphersuites are not configurable. - CipherSuites []uint16 - - // PreferServerCipherSuites controls whether the server selects the - // client's most preferred ciphersuite, or the server's most preferred - // ciphersuite. If true then the server's preference, as expressed in - // the order of elements in CipherSuites, is used. - PreferServerCipherSuites bool - - // SessionTicketsDisabled may be set to true to disable session ticket and - // PSK (resumption) support. Note that on clients, session ticket support is - // also disabled if ClientSessionCache is nil. - SessionTicketsDisabled bool - - // SessionTicketKey is used by TLS servers to provide session resumption. - // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled - // with random data before the first server handshake. - // - // Deprecated: if this field is left at zero, session ticket keys will be - // automatically rotated every day and dropped after seven days. For - // customizing the rotation schedule or synchronizing servers that are - // terminating connections for the same host, use SetSessionTicketKeys. - SessionTicketKey [32]byte - - // ClientSessionCache is a cache of ClientSessionState entries for TLS - // session resumption. It is only used by clients. - ClientSessionCache ClientSessionCache - - // MinVersion contains the minimum TLS version that is acceptable. - // If zero, TLS 1.0 is currently taken as the minimum. - MinVersion uint16 - - // MaxVersion contains the maximum TLS version that is acceptable. - // If zero, the maximum version supported by this package is used, - // which is currently TLS 1.3. - MaxVersion uint16 - - // CurvePreferences contains the elliptic curves that will be used in - // an ECDHE handshake, in preference order. If empty, the default will - // be used. The client will use the first preference as the type for - // its key share in TLS 1.3. This may change in the future. - CurvePreferences []CurveID - - // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. - // When true, the largest possible TLS record size is always used. When - // false, the size of TLS records may be adjusted in an attempt to - // improve latency. - DynamicRecordSizingDisabled bool - - // Renegotiation controls what types of renegotiation are supported. - // The default, none, is correct for the vast majority of applications. - Renegotiation RenegotiationSupport - - // KeyLogWriter optionally specifies a destination for TLS master secrets - // in NSS key log format that can be used to allow external programs - // such as Wireshark to decrypt TLS connections. - // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. - // Use of KeyLogWriter compromises security and should only be - // used for debugging. 
- KeyLogWriter io.Writer - - // mutex protects sessionTicketKeys and autoSessionTicketKeys. - mutex sync.RWMutex - // sessionTicketKeys contains zero or more ticket keys. If set, it means the - // the keys were set with SessionTicketKey or SetSessionTicketKeys. The - // first key is used for new tickets and any subsequent keys can be used to - // decrypt old tickets. The slice contents are not protected by the mutex - // and are immutable. - sessionTicketKeys []ticketKey - // autoSessionTicketKeys is like sessionTicketKeys but is owned by the - // auto-rotation logic. See Config.ticketKeys. - autoSessionTicketKeys []ticketKey -} - -// A RecordLayer handles encrypting and decrypting of TLS messages. -type RecordLayer interface { - SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) - SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) - ReadHandshakeMessage() ([]byte, error) - WriteRecord([]byte) (int, error) - SendAlert(uint8) -} - -type ExtraConfig struct { - // GetExtensions, if not nil, is called before a message that allows - // sending of extensions is sent. - // Currently only implemented for the ClientHello message (for the client) - // and for the EncryptedExtensions message (for the server). - // Only valid for TLS 1.3. - GetExtensions func(handshakeMessageType uint8) []Extension - - // ReceivedExtensions, if not nil, is called when a message that allows the - // inclusion of extensions is received. - // It is called with an empty slice of extensions, if the message didn't - // contain any extensions. - // Currently only implemented for the ClientHello message (sent by the - // client) and for the EncryptedExtensions message (sent by the server). - // Only valid for TLS 1.3. - ReceivedExtensions func(handshakeMessageType uint8, exts []Extension) - - // AlternativeRecordLayer is used by QUIC - AlternativeRecordLayer RecordLayer - - // Enforce the selection of a supported application protocol. - // Only works for TLS 1.3. - // If enabled, client and server have to agree on an application protocol. - // Otherwise, connection establishment fails. - EnforceNextProtoSelection bool - - // If MaxEarlyData is greater than 0, the client will be allowed to send early - // data when resuming a session. - // Requires the AlternativeRecordLayer to be set. - // - // It has no meaning on the client. - MaxEarlyData uint32 - - // The Accept0RTT callback is called when the client offers 0-RTT. - // The server then has to decide if it wants to accept or reject 0-RTT. - // It is only used for servers. - Accept0RTT func(appData []byte) bool - - // 0RTTRejected is called when the server rejectes 0-RTT. - // It is only used for clients. - Rejected0RTT func() - - // If set, the client will export the 0-RTT key when resuming a session that - // allows sending of early data. - // Requires the AlternativeRecordLayer to be set. - // - // It has no meaning to the server. - Enable0RTT bool - - // Is called when the client saves a session ticket to the session ticket. - // This gives the application the opportunity to save some data along with the ticket, - // which can be restored when the session ticket is used. - GetAppDataForSessionState func() []byte - - // Is called when the client uses a session ticket. - // Restores the application data that was saved earlier on GetAppDataForSessionTicket. - SetAppDataFromSessionState func([]byte) -} - -// Clone clones. 
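
// Illustrative sketch, not from the vendored file or this patch: how a QUIC
// stack might populate the ExtraConfig defined above. "recordLayer" stands in
// for the caller's own implementation of the RecordLayer interface; the
// concrete values below are assumptions for illustration, not quic-go's
// actual configuration.
func exampleExtraConfig(recordLayer RecordLayer) *ExtraConfig {
    return &ExtraConfig{
        AlternativeRecordLayer:    recordLayer, // QUIC supplies its own record layer
        EnforceNextProtoSelection: true,        // fail the handshake unless ALPN agrees (TLS 1.3 only)
        MaxEarlyData:              0xffffffff,  // advertise 0-RTT support on the server side
        Accept0RTT:                func(appData []byte) bool { return true },
        Rejected0RTT:              func() { /* client side: fall back to 1-RTT */ },
        GetAppDataForSessionState: func() []byte { return nil },
        SetAppDataFromSessionState: func([]byte) {},
    }
}
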
-func (c *ExtraConfig) Clone() *ExtraConfig { - return &ExtraConfig{ - GetExtensions: c.GetExtensions, - ReceivedExtensions: c.ReceivedExtensions, - AlternativeRecordLayer: c.AlternativeRecordLayer, - EnforceNextProtoSelection: c.EnforceNextProtoSelection, - MaxEarlyData: c.MaxEarlyData, - Enable0RTT: c.Enable0RTT, - Accept0RTT: c.Accept0RTT, - Rejected0RTT: c.Rejected0RTT, - GetAppDataForSessionState: c.GetAppDataForSessionState, - SetAppDataFromSessionState: c.SetAppDataFromSessionState, - } -} - -func (c *ExtraConfig) usesAlternativeRecordLayer() bool { - return c != nil && c.AlternativeRecordLayer != nil -} - -const ( - // ticketKeyNameLen is the number of bytes of identifier that is prepended to - // an encrypted session ticket in order to identify the key used to encrypt it. - ticketKeyNameLen = 16 - - // ticketKeyLifetime is how long a ticket key remains valid and can be used to - // resume a client connection. - ticketKeyLifetime = 7 * 24 * time.Hour // 7 days - - // ticketKeyRotation is how often the server should rotate the session ticket key - // that is used for new tickets. - ticketKeyRotation = 24 * time.Hour -) - -// ticketKey is the internal representation of a session ticket key. -type ticketKey struct { - // keyName is an opaque byte string that serves to identify the session - // ticket key. It's exposed as plaintext in every session ticket. - keyName [ticketKeyNameLen]byte - aesKey [16]byte - hmacKey [16]byte - // created is the time at which this ticket key was created. See Config.ticketKeys. - created time.Time -} - -// ticketKeyFromBytes converts from the external representation of a session -// ticket key to a ticketKey. Externally, session ticket keys are 32 random -// bytes and this function expands that into sufficient name and key material. -func (c *config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { - hashed := sha512.Sum512(b[:]) - copy(key.keyName[:], hashed[:ticketKeyNameLen]) - copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16]) - copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32]) - key.created = c.time() - return key -} - -// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session -// ticket, and the lifetime we set for tickets we send. -const maxSessionTicketLifetime = 7 * 24 * time.Hour - -// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is -// being used concurrently by a TLS client or server. 
-func (c *config) Clone() *config { - if c == nil { - return nil - } - c.mutex.RLock() - defer c.mutex.RUnlock() - return &config{ - Rand: c.Rand, - Time: c.Time, - Certificates: c.Certificates, - NameToCertificate: c.NameToCertificate, - GetCertificate: c.GetCertificate, - GetClientCertificate: c.GetClientCertificate, - GetConfigForClient: c.GetConfigForClient, - VerifyPeerCertificate: c.VerifyPeerCertificate, - VerifyConnection: c.VerifyConnection, - RootCAs: c.RootCAs, - NextProtos: c.NextProtos, - ServerName: c.ServerName, - ClientAuth: c.ClientAuth, - ClientCAs: c.ClientCAs, - InsecureSkipVerify: c.InsecureSkipVerify, - CipherSuites: c.CipherSuites, - PreferServerCipherSuites: c.PreferServerCipherSuites, - SessionTicketsDisabled: c.SessionTicketsDisabled, - SessionTicketKey: c.SessionTicketKey, - ClientSessionCache: c.ClientSessionCache, - MinVersion: c.MinVersion, - MaxVersion: c.MaxVersion, - CurvePreferences: c.CurvePreferences, - DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, - Renegotiation: c.Renegotiation, - KeyLogWriter: c.KeyLogWriter, - sessionTicketKeys: c.sessionTicketKeys, - autoSessionTicketKeys: c.autoSessionTicketKeys, - } -} - -// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was -// randomized for backwards compatibility but is not in use. -var deprecatedSessionTicketKey = []byte("DEPRECATED") - -// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is -// randomized if empty, and that sessionTicketKeys is populated from it otherwise. -func (c *config) initLegacySessionTicketKeyRLocked() { - // Don't write if SessionTicketKey is already defined as our deprecated string, - // or if it is defined by the user but sessionTicketKeys is already set. - if c.SessionTicketKey != [32]byte{} && - (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { - return - } - - // We need to write some data, so get an exclusive lock and re-check any conditions. - c.mutex.RUnlock() - defer c.mutex.RLock() - c.mutex.Lock() - defer c.mutex.Unlock() - if c.SessionTicketKey == [32]byte{} { - if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { - panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) - } - // Write the deprecated prefix at the beginning so we know we created - // it. This key with the DEPRECATED prefix isn't used as an actual - // session ticket key, and is only randomized in case the application - // reuses it for some reason. - copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) - } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { - c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} - } - -} - -// ticketKeys returns the ticketKeys for this connection. -// If configForClient has explicitly set keys, those will -// be returned. Otherwise, the keys on c will be used and -// may be rotated if auto-managed. -// During rotation, any expired session ticket keys are deleted from -// c.sessionTicketKeys. If the session ticket key that is currently -// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) -// is not fresh, then a new session ticket key will be -// created and prepended to c.sessionTicketKeys. -func (c *config) ticketKeys(configForClient *config) []ticketKey { - // If the ConfigForClient callback returned a Config with explicitly set - // keys, use those, otherwise just use the original Config. 
- if configForClient != nil { - configForClient.mutex.RLock() - if configForClient.SessionTicketsDisabled { - return nil - } - configForClient.initLegacySessionTicketKeyRLocked() - if len(configForClient.sessionTicketKeys) != 0 { - ret := configForClient.sessionTicketKeys - configForClient.mutex.RUnlock() - return ret - } - configForClient.mutex.RUnlock() - } - - c.mutex.RLock() - defer c.mutex.RUnlock() - if c.SessionTicketsDisabled { - return nil - } - c.initLegacySessionTicketKeyRLocked() - if len(c.sessionTicketKeys) != 0 { - return c.sessionTicketKeys - } - // Fast path for the common case where the key is fresh enough. - if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation { - return c.autoSessionTicketKeys - } - - // autoSessionTicketKeys are managed by auto-rotation. - c.mutex.RUnlock() - defer c.mutex.RLock() - c.mutex.Lock() - defer c.mutex.Unlock() - // Re-check the condition in case it changed since obtaining the new lock. - if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation { - var newKey [32]byte - if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil { - panic(fmt.Sprintf("unable to generate random session ticket key: %v", err)) - } - valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1) - valid = append(valid, c.ticketKeyFromBytes(newKey)) - for _, k := range c.autoSessionTicketKeys { - // While rotating the current key, also remove any expired ones. - if c.time().Sub(k.created) < ticketKeyLifetime { - valid = append(valid, k) - } - } - c.autoSessionTicketKeys = valid - } - return c.autoSessionTicketKeys -} - -// SetSessionTicketKeys updates the session ticket keys for a server. -// -// The first key will be used when creating new tickets, while all keys can be -// used for decrypting tickets. It is safe to call this function while the -// server is running in order to rotate the session ticket keys. The function -// will panic if keys is empty. -// -// Calling this function will turn off automatic session ticket key rotation. -// -// If multiple servers are terminating connections for the same host they should -// all have the same session ticket keys. If the session ticket keys leaks, -// previously recorded and future TLS connections using those keys might be -// compromised. 
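
// Illustrative sketch, not from the vendored file or this patch: manual
// rotation of session ticket keys from the application side, mirroring the
// automatic daily rotation implemented above. Config is the tls.Config alias
// declared earlier in this file, and the seven-key window is an assumption
// chosen to match ticketKeyLifetime / ticketKeyRotation.
func exampleRotateTicketKeys(cfg *Config, stop <-chan struct{}) {
    var keys [][32]byte
    ticker := time.NewTicker(ticketKeyRotation)
    defer ticker.Stop()
    for {
        var k [32]byte
        if _, err := io.ReadFull(rand.Reader, k[:]); err != nil {
            panic(fmt.Sprintf("tls: unable to generate session ticket key: %v", err))
        }
        // Newest key first: it encrypts new tickets, older keys only decrypt.
        keys = append([][32]byte{k}, keys...)
        if len(keys) > 7 {
            keys = keys[:7]
        }
        cfg.SetSessionTicketKeys(keys) // disables automatic rotation, as documented above
        select {
        case <-ticker.C:
        case <-stop:
            return
        }
    }
}
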
-func (c *config) SetSessionTicketKeys(keys [][32]byte) { - if len(keys) == 0 { - panic("tls: keys must have at least one key") - } - - newKeys := make([]ticketKey, len(keys)) - for i, bytes := range keys { - newKeys[i] = c.ticketKeyFromBytes(bytes) - } - - c.mutex.Lock() - c.sessionTicketKeys = newKeys - c.mutex.Unlock() -} - -func (c *config) rand() io.Reader { - r := c.Rand - if r == nil { - return rand.Reader - } - return r -} - -func (c *config) time() time.Time { - t := c.Time - if t == nil { - t = time.Now - } - return t() -} - -func (c *config) cipherSuites() []uint16 { - s := c.CipherSuites - if s == nil { - s = defaultCipherSuites() - } - return s -} - -var supportedVersions = []uint16{ - VersionTLS13, - VersionTLS12, - VersionTLS11, - VersionTLS10, -} - -func (c *config) supportedVersions() []uint16 { - versions := make([]uint16, 0, len(supportedVersions)) - for _, v := range supportedVersions { - if c != nil && c.MinVersion != 0 && v < c.MinVersion { - continue - } - if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { - continue - } - versions = append(versions, v) - } - return versions -} - -func (c *config) maxSupportedVersion() uint16 { - supportedVersions := c.supportedVersions() - if len(supportedVersions) == 0 { - return 0 - } - return supportedVersions[0] -} - -// supportedVersionsFromMax returns a list of supported versions derived from a -// legacy maximum version value. Note that only versions supported by this -// library are returned. Any newer peer will use supportedVersions anyway. -func supportedVersionsFromMax(maxVersion uint16) []uint16 { - versions := make([]uint16, 0, len(supportedVersions)) - for _, v := range supportedVersions { - if v > maxVersion { - continue - } - versions = append(versions, v) - } - return versions -} - -var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} - -func (c *config) curvePreferences() []CurveID { - if c == nil || len(c.CurvePreferences) == 0 { - return defaultCurvePreferences - } - return c.CurvePreferences -} - -func (c *config) supportsCurve(curve CurveID) bool { - for _, cc := range c.curvePreferences() { - if cc == curve { - return true - } - } - return false -} - -// mutualVersion returns the protocol version to use given the advertised -// versions of the peer. Priority is given to the peer preference order. -func (c *config) mutualVersion(peerVersions []uint16) (uint16, bool) { - supportedVersions := c.supportedVersions() - for _, peerVersion := range peerVersions { - for _, v := range supportedVersions { - if v == peerVersion { - return v, true - } - } - } - return 0, false -} - -var errNoCertificates = errors.New("tls: no certificates configured") - -// getCertificate returns the best certificate for the given ClientHelloInfo, -// defaulting to the first element of c.Certificates. -func (c *config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { - if c.GetCertificate != nil && - (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { - cert, err := c.GetCertificate(clientHello) - if cert != nil || err != nil { - return cert, err - } - } - - if len(c.Certificates) == 0 { - return nil, errNoCertificates - } - - if len(c.Certificates) == 1 { - // There's only one choice, so no point doing any work. 
- return &c.Certificates[0], nil - } - - if c.NameToCertificate != nil { - name := strings.ToLower(clientHello.ServerName) - if cert, ok := c.NameToCertificate[name]; ok { - return cert, nil - } - if len(name) > 0 { - labels := strings.Split(name, ".") - labels[0] = "*" - wildcardName := strings.Join(labels, ".") - if cert, ok := c.NameToCertificate[wildcardName]; ok { - return cert, nil - } - } - } - - for _, cert := range c.Certificates { - if err := clientHello.SupportsCertificate(&cert); err == nil { - return &cert, nil - } - } - - // If nothing matches, return the first certificate. - return &c.Certificates[0], nil -} - -// SupportsCertificate returns nil if the provided certificate is supported by -// the client that sent the ClientHello. Otherwise, it returns an error -// describing the reason for the incompatibility. -// -// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate -// callback, this method will take into account the associated Config. Note that -// if GetConfigForClient returns a different Config, the change can't be -// accounted for by this method. -// -// This function will call x509.ParseCertificate unless c.Leaf is set, which can -// incur a significant performance cost. -func (chi *clientHelloInfo) SupportsCertificate(c *Certificate) error { - // Note we don't currently support certificate_authorities nor - // signature_algorithms_cert, and don't check the algorithms of the - // signatures on the chain (which anyway are a SHOULD, see RFC 8446, - // Section 4.4.2.2). - - config := chi.config - if config == nil { - config = &Config{} - } - conf := fromConfig(config) - vers, ok := conf.mutualVersion(chi.SupportedVersions) - if !ok { - return errors.New("no mutually supported protocol versions") - } - - // If the client specified the name they are trying to connect to, the - // certificate needs to be valid for it. - if chi.ServerName != "" { - x509Cert, err := leafCertificate(c) - if err != nil { - return fmt.Errorf("failed to parse certificate: %w", err) - } - if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { - return fmt.Errorf("certificate is not valid for requested server name: %w", err) - } - } - - // supportsRSAFallback returns nil if the certificate and connection support - // the static RSA key exchange, and unsupported otherwise. The logic for - // supporting static RSA is completely disjoint from the logic for - // supporting signed key exchanges, so we just check it as a fallback. - supportsRSAFallback := func(unsupported error) error { - // TLS 1.3 dropped support for the static RSA key exchange. - if vers == VersionTLS13 { - return unsupported - } - // The static RSA key exchange works by decrypting a challenge with the - // RSA private key, not by signing, so check the PrivateKey implements - // crypto.Decrypter, like *rsa.PrivateKey does. - if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { - if _, ok := priv.Public().(*rsa.PublicKey); !ok { - return unsupported - } - } else { - return unsupported - } - // Finally, there needs to be a mutual cipher suite that uses the static - // RSA key exchange instead of ECDHE. 
- rsaCipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { - if c.flags&suiteECDHE != 0 { - return false - } - if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true - }) - if rsaCipherSuite == nil { - return unsupported - } - return nil - } - - // If the client sent the signature_algorithms extension, ensure it supports - // schemes we can use with this certificate and TLS version. - if len(chi.SignatureSchemes) > 0 { - if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { - return supportsRSAFallback(err) - } - } - - // In TLS 1.3 we are done because supported_groups is only relevant to the - // ECDHE computation, point format negotiation is removed, cipher suites are - // only relevant to the AEAD choice, and static RSA does not exist. - if vers == VersionTLS13 { - return nil - } - - // The only signed key exchange we support is ECDHE. - if !supportsECDHE(conf, chi.SupportedCurves, chi.SupportedPoints) { - return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) - } - - var ecdsaCipherSuite bool - if priv, ok := c.PrivateKey.(crypto.Signer); ok { - switch pub := priv.Public().(type) { - case *ecdsa.PublicKey: - var curve CurveID - switch pub.Curve { - case elliptic.P256(): - curve = CurveP256 - case elliptic.P384(): - curve = CurveP384 - case elliptic.P521(): - curve = CurveP521 - default: - return supportsRSAFallback(unsupportedCertificateError(c)) - } - var curveOk bool - for _, c := range chi.SupportedCurves { - if c == curve && conf.supportsCurve(c) { - curveOk = true - break - } - } - if !curveOk { - return errors.New("client doesn't support certificate curve") - } - ecdsaCipherSuite = true - case ed25519.PublicKey: - if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { - return errors.New("connection doesn't support Ed25519") - } - ecdsaCipherSuite = true - case *rsa.PublicKey: - default: - return supportsRSAFallback(unsupportedCertificateError(c)) - } - } else { - return supportsRSAFallback(unsupportedCertificateError(c)) - } - - // Make sure that there is a mutually supported cipher suite that works with - // this certificate. Cipher suite selection will then apply the logic in - // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. - cipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { - if c.flags&suiteECDHE == 0 { - return false - } - if c.flags&suiteECSign != 0 { - if !ecdsaCipherSuite { - return false - } - } else { - if ecdsaCipherSuite { - return false - } - } - if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true - }) - if cipherSuite == nil { - return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) - } - - return nil -} - -// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate -// from the CommonName and SubjectAlternateName fields of each of the leaf -// certificates. -// -// Deprecated: NameToCertificate only allows associating a single certificate -// with a given name. Leave that field nil to let the library select the first -// compatible chain from Certificates. 
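
// Illustrative sketch, not from the vendored file or this patch: the
// GetCertificate route that the deprecation notice above points callers
// towards, instead of filling NameToCertificate by hand. certsByName and
// fallback are assumed to be supplied by the caller; the wildcard handling
// follows the same pattern as getCertificate above.
func exampleSNISelector(certsByName map[string]*Certificate, fallback *Certificate) func(*ClientHelloInfo) (*Certificate, error) {
    return func(chi *ClientHelloInfo) (*Certificate, error) {
        name := strings.ToLower(chi.ServerName)
        if cert, ok := certsByName[name]; ok {
            return cert, nil
        }
        // Try a wildcard match, e.g. "a.example.com" -> "*.example.com".
        if labels := strings.Split(name, "."); len(labels) > 1 {
            labels[0] = "*"
            if cert, ok := certsByName[strings.Join(labels, ".")]; ok {
                return cert, nil
            }
        }
        return fallback, nil
    }
}
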
-func (c *config) BuildNameToCertificate() { - c.NameToCertificate = make(map[string]*Certificate) - for i := range c.Certificates { - cert := &c.Certificates[i] - x509Cert, err := leafCertificate(cert) - if err != nil { - continue - } - // If SANs are *not* present, some clients will consider the certificate - // valid for the name in the Common Name. - if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { - c.NameToCertificate[x509Cert.Subject.CommonName] = cert - } - for _, san := range x509Cert.DNSNames { - c.NameToCertificate[san] = cert - } - } -} - -const ( - keyLogLabelTLS12 = "CLIENT_RANDOM" - keyLogLabelEarlyTraffic = "CLIENT_EARLY_TRAFFIC_SECRET" - keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" - keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" - keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" - keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" -) - -func (c *config) writeKeyLog(label string, clientRandom, secret []byte) error { - if c.KeyLogWriter == nil { - return nil - } - - logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) - - writerMutex.Lock() - _, err := c.KeyLogWriter.Write(logLine) - writerMutex.Unlock() - - return err -} - -// writerMutex protects all KeyLogWriters globally. It is rarely enabled, -// and is only for debugging, so a global mutex saves space. -var writerMutex sync.Mutex - -// A Certificate is a chain of one or more certificates, leaf first. -type Certificate = tls.Certificate - -// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing -// the corresponding c.Certificate[0]. -func leafCertificate(c *Certificate) (*x509.Certificate, error) { - if c.Leaf != nil { - return c.Leaf, nil - } - return x509.ParseCertificate(c.Certificate[0]) -} - -type handshakeMessage interface { - marshal() []byte - unmarshal([]byte) bool -} - -// lruSessionCache is a ClientSessionCache implementation that uses an LRU -// caching strategy. -type lruSessionCache struct { - sync.Mutex - - m map[string]*list.Element - q *list.List - capacity int -} - -type lruSessionCacheEntry struct { - sessionKey string - state *ClientSessionState -} - -// NewLRUClientSessionCache returns a ClientSessionCache with the given -// capacity that uses an LRU strategy. If capacity is < 1, a default capacity -// is used instead. -func NewLRUClientSessionCache(capacity int) ClientSessionCache { - const defaultSessionCacheCapacity = 64 - - if capacity < 1 { - capacity = defaultSessionCacheCapacity - } - return &lruSessionCache{ - m: make(map[string]*list.Element), - q: list.New(), - capacity: capacity, - } -} - -// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry -// corresponding to sessionKey is removed from the cache instead. -func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { - c.Lock() - defer c.Unlock() - - if elem, ok := c.m[sessionKey]; ok { - if cs == nil { - c.q.Remove(elem) - delete(c.m, sessionKey) - } else { - entry := elem.Value.(*lruSessionCacheEntry) - entry.state = cs - c.q.MoveToFront(elem) - } - return - } - - if c.q.Len() < c.capacity { - entry := &lruSessionCacheEntry{sessionKey, cs} - c.m[sessionKey] = c.q.PushFront(entry) - return - } - - elem := c.q.Back() - entry := elem.Value.(*lruSessionCacheEntry) - delete(c.m, entry.sessionKey) - entry.sessionKey = sessionKey - entry.state = cs - c.q.MoveToFront(elem) - c.m[sessionKey] = elem -} - -// Get returns the ClientSessionState value associated with a given key. 
It -// returns (nil, false) if no value is found. -func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { - c.Lock() - defer c.Unlock() - - if elem, ok := c.m[sessionKey]; ok { - c.q.MoveToFront(elem) - return elem.Value.(*lruSessionCacheEntry).state, true - } - return nil, false -} - -var emptyConfig Config - -func defaultConfig() *Config { - return &emptyConfig -} - -var ( - once sync.Once - varDefaultCipherSuites []uint16 - varDefaultCipherSuitesTLS13 []uint16 -) - -func defaultCipherSuites() []uint16 { - once.Do(initDefaultCipherSuites) - return varDefaultCipherSuites -} - -func defaultCipherSuitesTLS13() []uint16 { - once.Do(initDefaultCipherSuites) - return varDefaultCipherSuitesTLS13 -} - -func initDefaultCipherSuites() { - var topCipherSuites []uint16 - - if hasAESGCMHardwareSupport { - // If AES-GCM hardware is provided then prioritise AES-GCM - // cipher suites. - topCipherSuites = []uint16{ - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - } - varDefaultCipherSuitesTLS13 = []uint16{ - TLS_AES_128_GCM_SHA256, - TLS_CHACHA20_POLY1305_SHA256, - TLS_AES_256_GCM_SHA384, - } - } else { - // Without AES-GCM hardware, we put the ChaCha20-Poly1305 - // cipher suites first. - topCipherSuites = []uint16{ - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - } - varDefaultCipherSuitesTLS13 = []uint16{ - TLS_CHACHA20_POLY1305_SHA256, - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, - } - } - - varDefaultCipherSuites = make([]uint16, 0, len(cipherSuites)) - varDefaultCipherSuites = append(varDefaultCipherSuites, topCipherSuites...) - -NextCipherSuite: - for _, suite := range cipherSuites { - if suite.flags&suiteDefaultOff != 0 { - continue - } - for _, existing := range varDefaultCipherSuites { - if existing == suite.id { - continue NextCipherSuite - } - } - varDefaultCipherSuites = append(varDefaultCipherSuites, suite.id) - } -} - -func unexpectedMessageError(wanted, got interface{}) error { - return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) -} - -func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { - for _, s := range supportedSignatureAlgorithms { - if s == sigAlg { - return true - } - } - return false -} - -var aesgcmCiphers = map[uint16]bool{ - // 1.2 - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, - // 1.3 - TLS_AES_128_GCM_SHA256: true, - TLS_AES_256_GCM_SHA384: true, -} - -var nonAESGCMAEADCiphers = map[uint16]bool{ - // 1.2 - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, - // 1.3 - TLS_CHACHA20_POLY1305_SHA256: true, -} - -// aesgcmPreferred returns whether the first valid cipher in the preference list -// is an AES-GCM cipher, implying the peer has hardware support for it. 
-func aesgcmPreferred(ciphers []uint16) bool { - for _, cID := range ciphers { - c := cipherSuiteByID(cID) - if c == nil { - c13 := cipherSuiteTLS13ByID(cID) - if c13 == nil { - continue - } - return aesgcmCiphers[cID] - } - return aesgcmCiphers[cID] - } - return false -} - -// deprioritizeAES reorders cipher preference lists by rearranging -// adjacent AEAD ciphers such that AES-GCM based ciphers are moved -// after other AEAD ciphers. It returns a fresh slice. -func deprioritizeAES(ciphers []uint16) []uint16 { - reordered := make([]uint16, len(ciphers)) - copy(reordered, ciphers) - sort.SliceStable(reordered, func(i, j int) bool { - return nonAESGCMAEADCiphers[reordered[i]] && aesgcmCiphers[reordered[j]] - }) - return reordered -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/common_js.go b/vendor/github.com/marten-seemann/qtls-go1-16/common_js.go deleted file mode 100644 index 97e6ecefd77..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/common_js.go +++ /dev/null @@ -1,12 +0,0 @@ -// +build js - -package qtls - -var ( - hasGCMAsmAMD64 = false - hasGCMAsmARM64 = false - // Keep in sync with crypto/aes/cipher_s390x.go. - hasGCMAsmS390X = false - - hasAESGCMHardwareSupport = false -) diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/common_nojs.go b/vendor/github.com/marten-seemann/qtls-go1-16/common_nojs.go deleted file mode 100644 index 5e56e0fb36b..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/common_nojs.go +++ /dev/null @@ -1,20 +0,0 @@ -// +build !js - -package qtls - -import ( - "runtime" - - "golang.org/x/sys/cpu" -) - -var ( - hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ - hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL - // Keep in sync with crypto/aes/cipher_s390x.go. - hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) - - hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || - runtime.GOARCH == "arm64" && hasGCMAsmARM64 || - runtime.GOARCH == "s390x" && hasGCMAsmS390X -) diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/conn.go b/vendor/github.com/marten-seemann/qtls-go1-16/conn.go deleted file mode 100644 index fa5eb3f1062..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/conn.go +++ /dev/null @@ -1,1536 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// TLS low level connection and record layer - -package qtls - -import ( - "bytes" - "crypto/cipher" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "net" - "sync" - "sync/atomic" - "time" -) - -// A Conn represents a secured connection. -// It implements the net.Conn interface. -type Conn struct { - // constant - conn net.Conn - isClient bool - handshakeFn func() error // (*Conn).clientHandshake or serverHandshake - - // handshakeStatus is 1 if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - // This field is only to be accessed with sync/atomic. - handshakeStatus uint32 - // constant after handshake; protected by handshakeMutex - handshakeMutex sync.Mutex - handshakeErr error // error resulting from handshake - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *config // configuration passed to constructor - // handshakes counts the number of handshakes performed on the - // connection so far. 
If renegotiation is disabled then this is either - // zero or one. - extraConfig *ExtraConfig - - handshakes int - didResume bool // whether this connection was a session resumption - cipherSuite uint16 - ocspResponse []byte // stapled OCSP response - scts [][]byte // signed certificate timestamps from server - peerCertificates []*x509.Certificate - // verifiedChains contains the certificate chains that we built, as - // opposed to the ones presented by the server. - verifiedChains [][]*x509.Certificate - // serverName contains the server name indicated by the client, if any. - serverName string - // secureRenegotiation is true if the server echoed the secure - // renegotiation extension. (This is meaningless as a server because - // renegotiation is not supported in that case.) - secureRenegotiation bool - // ekm is a closure for exporting keying material. - ekm func(label string, context []byte, length int) ([]byte, error) - // For the client: - // resumptionSecret is the resumption_master_secret for handling - // NewSessionTicket messages. nil if config.SessionTicketsDisabled. - // For the server: - // resumptionSecret is the resumption_master_secret for generating - // NewSessionTicket messages. Only used when the alternative record - // layer is set. nil if config.SessionTicketsDisabled. - resumptionSecret []byte - - // ticketKeys is the set of active session ticket keys for this - // connection. The first one is used to encrypt new tickets and - // all are tried to decrypt tickets. - ticketKeys []ticketKey - - // clientFinishedIsFirst is true if the client sent the first Finished - // message during the most recent handshake. This is recorded because - // the first transmitted Finished message is the tls-unique - // channel-binding value. - clientFinishedIsFirst bool - - // closeNotifyErr is any error from sending the alertCloseNotify record. - closeNotifyErr error - // closeNotifySent is true if the Conn attempted to send an - // alertCloseNotify record. - closeNotifySent bool - - // clientFinished and serverFinished contain the Finished message sent - // by the client or server in the most recent handshake. This is - // retained to support the renegotiation extension and tls-unique - // channel-binding. - clientFinished [12]byte - serverFinished [12]byte - - // clientProtocol is the negotiated ALPN protocol. - clientProtocol string - - // input/output - in, out halfConn - rawInput bytes.Buffer // raw input, starting with a record header - input bytes.Reader // application data waiting to be read, from rawInput.Next - hand bytes.Buffer // handshake data waiting to be read - buffering bool // whether records are buffered in sendBuf - sendBuf []byte // a buffer of records waiting to be sent - - // bytesSent counts the bytes of application data sent. - // packetsSent counts packets. - bytesSent int64 - packetsSent int64 - - // retryCount counts the number of consecutive non-advancing records - // received by Conn.readRecord. That is, records that neither advance the - // handshake, nor deliver application data. Protected by in.Mutex. - retryCount int - - // activeCall is an atomic int32; the low bit is whether Close has - // been called. the rest of the bits are the number of goroutines - // in Conn.Write. - activeCall int32 - - used0RTT bool - - tmp [16]byte -} - -// Access to net.Conn methods. -// Cannot just embed net.Conn because that would -// export the struct field too. - -// LocalAddr returns the local network address. 
-func (c *Conn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -// RemoteAddr returns the remote network address. -func (c *Conn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline sets the read and write deadlines associated with the connection. -// A zero value for t means Read and Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -// SetReadDeadline sets the read deadline on the underlying connection. -// A zero value for t means Read will not time out. -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -// SetWriteDeadline sets the write deadline on the underlying connection. -// A zero value for t means Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -// A halfConn represents one direction of the record layer -// connection, either sending or receiving. -type halfConn struct { - sync.Mutex - - err error // first permanent error - version uint16 // protocol version - cipher interface{} // cipher algorithm - mac hash.Hash - seq [8]byte // 64-bit sequence number - - scratchBuf [13]byte // to avoid allocs; interface method args escape - - nextCipher interface{} // next encryption state - nextMac hash.Hash // next MAC algorithm - - trafficSecret []byte // current TLS 1.3 traffic secret - - setKeyCallback func(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) -} - -type permanentError struct { - err net.Error -} - -func (e *permanentError) Error() string { return e.err.Error() } -func (e *permanentError) Unwrap() error { return e.err } -func (e *permanentError) Timeout() bool { return e.err.Timeout() } -func (e *permanentError) Temporary() bool { return false } - -func (hc *halfConn) setErrorLocked(err error) error { - if e, ok := err.(net.Error); ok { - hc.err = &permanentError{err: e} - } else { - hc.err = err - } - return hc.err -} - -// prepareCipherSpec sets the encryption and MAC states -// that a subsequent changeCipherSpec will use. -func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) { - hc.version = version - hc.nextCipher = cipher - hc.nextMac = mac -} - -// changeCipherSpec changes the encryption and MAC states -// to the ones previously passed to prepareCipherSpec. -func (hc *halfConn) changeCipherSpec() error { - if hc.nextCipher == nil || hc.version == VersionTLS13 { - return alertInternalError - } - hc.cipher = hc.nextCipher - hc.mac = hc.nextMac - hc.nextCipher = nil - hc.nextMac = nil - for i := range hc.seq { - hc.seq[i] = 0 - } - return nil -} - -func (hc *halfConn) exportKey(encLevel EncryptionLevel, suite *cipherSuiteTLS13, trafficSecret []byte) { - if hc.setKeyCallback != nil { - s := &CipherSuiteTLS13{ - ID: suite.id, - KeyLen: suite.keyLen, - Hash: suite.hash, - AEAD: func(key, fixedNonce []byte) cipher.AEAD { return suite.aead(key, fixedNonce) }, - } - hc.setKeyCallback(encLevel, s, trafficSecret) - } -} - -func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { - hc.trafficSecret = secret - key, iv := suite.trafficKey(secret) - hc.cipher = suite.aead(key, iv) - for i := range hc.seq { - hc.seq[i] = 0 - } -} - -// incSeq increments the sequence number. 
-func (hc *halfConn) incSeq() { - for i := 7; i >= 0; i-- { - hc.seq[i]++ - if hc.seq[i] != 0 { - return - } - } - - // Not allowed to let sequence number wrap. - // Instead, must renegotiate before it does. - // Not likely enough to bother. - panic("TLS: sequence number wraparound") -} - -// explicitNonceLen returns the number of bytes of explicit nonce or IV included -// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 -// and in certain AEAD modes in TLS 1.2. -func (hc *halfConn) explicitNonceLen() int { - if hc.cipher == nil { - return 0 - } - - switch c := hc.cipher.(type) { - case cipher.Stream: - return 0 - case aead: - return c.explicitNonceLen() - case cbcMode: - // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. - if hc.version >= VersionTLS11 { - return c.BlockSize() - } - return 0 - default: - panic("unknown cipher type") - } -} - -// extractPadding returns, in constant time, the length of the padding to remove -// from the end of payload. It also returns a byte which is equal to 255 if the -// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. -func extractPadding(payload []byte) (toRemove int, good byte) { - if len(payload) < 1 { - return 0, 0 - } - - paddingLen := payload[len(payload)-1] - t := uint(len(payload)-1) - uint(paddingLen) - // if len(payload) >= (paddingLen - 1) then the MSB of t is zero - good = byte(int32(^t) >> 31) - - // The maximum possible padding length plus the actual length field - toCheck := 256 - // The length of the padded data is public, so we can use an if here - if toCheck > len(payload) { - toCheck = len(payload) - } - - for i := 0; i < toCheck; i++ { - t := uint(paddingLen) - uint(i) - // if i <= paddingLen then the MSB of t is zero - mask := byte(int32(^t) >> 31) - b := payload[len(payload)-1-i] - good &^= mask&paddingLen ^ mask&b - } - - // We AND together the bits of good and replicate the result across - // all the bits. - good &= good << 4 - good &= good << 2 - good &= good << 1 - good = uint8(int8(good) >> 7) - - // Zero the padding length on error. This ensures any unchecked bytes - // are included in the MAC. Otherwise, an attacker that could - // distinguish MAC failures from padding failures could mount an attack - // similar to POODLE in SSL 3.0: given a good ciphertext that uses a - // full block's worth of padding, replace the final block with another - // block. If the MAC check passed but the padding check failed, the - // last byte of that block decrypted to the block size. - // - // See also macAndPaddingGood logic below. - paddingLen &= good - - toRemove = int(paddingLen) + 1 - return -} - -func roundUp(a, b int) int { - return a + (b-a%b)%b -} - -// cbcMode is an interface for block ciphers using cipher block chaining. -type cbcMode interface { - cipher.BlockMode - SetIV([]byte) -} - -// decrypt authenticates and decrypts the record if protection is active at -// this stage. The returned plaintext might overlap with the input. -func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { - var plaintext []byte - typ := recordType(record[0]) - payload := record[recordHeaderLen:] - - // In TLS 1.3, change_cipher_spec messages are to be ignored without being - // decrypted. See RFC 8446, Appendix D.4. 
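As an illustrative aside (not part of the vendored file): the record sequence number above is an 8-byte big-endian counter that is reset whenever keys change and is never allowed to wrap. A minimal sketch of the same increment, written with encoding/binary:

package main

import (
	"encoding/binary"
	"fmt"
)

// incSeq64 mirrors halfConn.incSeq: bump a 64-bit big-endian counter kept in
// an [8]byte, and refuse to wrap around.
func incSeq64(seq *[8]byte) {
	n := binary.BigEndian.Uint64(seq[:])
	if n == ^uint64(0) {
		panic("TLS: sequence number wraparound")
	}
	binary.BigEndian.PutUint64(seq[:], n+1)
}

func main() {
	var seq [8]byte
	incSeq64(&seq)
	incSeq64(&seq)
	fmt.Printf("%x\n", seq) // 0000000000000002
}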
- if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { - return payload, typ, nil - } - - paddingGood := byte(255) - paddingLen := 0 - - explicitNonceLen := hc.explicitNonceLen() - - if hc.cipher != nil { - switch c := hc.cipher.(type) { - case cipher.Stream: - c.XORKeyStream(payload, payload) - case aead: - if len(payload) < explicitNonceLen { - return nil, 0, alertBadRecordMAC - } - nonce := payload[:explicitNonceLen] - if len(nonce) == 0 { - nonce = hc.seq[:] - } - payload = payload[explicitNonceLen:] - - var additionalData []byte - if hc.version == VersionTLS13 { - additionalData = record[:recordHeaderLen] - } else { - additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) - additionalData = append(additionalData, record[:3]...) - n := len(payload) - c.Overhead() - additionalData = append(additionalData, byte(n>>8), byte(n)) - } - - var err error - plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) - if err != nil { - return nil, 0, alertBadRecordMAC - } - case cbcMode: - blockSize := c.BlockSize() - minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) - if len(payload)%blockSize != 0 || len(payload) < minPayload { - return nil, 0, alertBadRecordMAC - } - - if explicitNonceLen > 0 { - c.SetIV(payload[:explicitNonceLen]) - payload = payload[explicitNonceLen:] - } - c.CryptBlocks(payload, payload) - - // In a limited attempt to protect against CBC padding oracles like - // Lucky13, the data past paddingLen (which is secret) is passed to - // the MAC function as extra data, to be fed into the HMAC after - // computing the digest. This makes the MAC roughly constant time as - // long as the digest computation is constant time and does not - // affect the subsequent write, modulo cache effects. - paddingLen, paddingGood = extractPadding(payload) - default: - panic("unknown cipher type") - } - - if hc.version == VersionTLS13 { - if typ != recordTypeApplicationData { - return nil, 0, alertUnexpectedMessage - } - if len(plaintext) > maxPlaintext+1 { - return nil, 0, alertRecordOverflow - } - // Remove padding and find the ContentType scanning from the end. - for i := len(plaintext) - 1; i >= 0; i-- { - if plaintext[i] != 0 { - typ = recordType(plaintext[i]) - plaintext = plaintext[:i] - break - } - if i == 0 { - return nil, 0, alertUnexpectedMessage - } - } - } - } else { - plaintext = payload - } - - if hc.mac != nil { - macSize := hc.mac.Size() - if len(payload) < macSize { - return nil, 0, alertBadRecordMAC - } - - n := len(payload) - macSize - paddingLen - n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } - record[3] = byte(n >> 8) - record[4] = byte(n) - remoteMAC := payload[n : n+macSize] - localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) - - // This is equivalent to checking the MACs and paddingGood - // separately, but in constant-time to prevent distinguishing - // padding failures from MAC failures. Depending on what value - // of paddingLen was returned on bad padding, distinguishing - // bad MAC from bad padding can lead to an attack. - // - // See also the logic at the end of extractPadding. 
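A minimal standalone sketch (assuming only the RFC 8446 TLSInnerPlaintext layout) of the padding-stripping loop above: the real ContentType is the last non-zero byte, and everything after it is zero padding.

package main

import (
	"errors"
	"fmt"
)

// unpadTLS13 scans from the end, as the decrypt loop above does: drop the zero
// padding, take the last non-zero byte as the record type, and keep the rest
// as the content (RFC 8446, Section 5.4).
func unpadTLS13(plaintext []byte) (content []byte, typ byte, err error) {
	for i := len(plaintext) - 1; i >= 0; i-- {
		if plaintext[i] != 0 {
			return plaintext[:i], plaintext[i], nil
		}
	}
	return nil, 0, errors.New("tls: no non-zero content type found")
}

func main() {
	content, typ, _ := unpadTLS13([]byte{'p', 'i', 'n', 'g', 23, 0, 0, 0})
	fmt.Printf("type=%d content=%q\n", typ, content) // type=23 content="ping"
}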
- macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) - if macAndPaddingGood != 1 { - return nil, 0, alertBadRecordMAC - } - - plaintext = payload[:n] - } - - hc.incSeq() - return plaintext, typ, nil -} - -func (c *Conn) setAlternativeRecordLayer() { - if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - c.in.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetReadKey - c.out.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetWriteKey - } -} - -// sliceForAppend extends the input slice by n bytes. head is the full extended -// slice, while tail is the appended part. If the original slice has sufficient -// capacity no allocation is performed. -func sliceForAppend(in []byte, n int) (head, tail []byte) { - if total := len(in) + n; cap(in) >= total { - head = in[:total] - } else { - head = make([]byte, total) - copy(head, in) - } - tail = head[len(in):] - return -} - -// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and -// appends it to record, which must already contain the record header. -func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { - if hc.cipher == nil { - return append(record, payload...), nil - } - - var explicitNonce []byte - if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { - record, explicitNonce = sliceForAppend(record, explicitNonceLen) - if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { - // The AES-GCM construction in TLS has an explicit nonce so that the - // nonce can be random. However, the nonce is only 8 bytes which is - // too small for a secure, random nonce. Therefore we use the - // sequence number as the nonce. The 3DES-CBC construction also has - // an 8 bytes nonce but its nonces must be unpredictable (see RFC - // 5246, Appendix F.3), forcing us to use randomness. That's not - // 3DES' biggest problem anyway because the birthday bound on block - // collision is reached first due to its similarly small block size - // (see the Sweet32 attack). - copy(explicitNonce, hc.seq[:]) - } else { - if _, err := io.ReadFull(rand, explicitNonce); err != nil { - return nil, err - } - } - } - - var dst []byte - switch c := hc.cipher.(type) { - case cipher.Stream: - mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) - record, dst = sliceForAppend(record, len(payload)+len(mac)) - c.XORKeyStream(dst[:len(payload)], payload) - c.XORKeyStream(dst[len(payload):], mac) - case aead: - nonce := explicitNonce - if len(nonce) == 0 { - nonce = hc.seq[:] - } - - if hc.version == VersionTLS13 { - record = append(record, payload...) - - // Encrypt the actual ContentType and replace the plaintext one. - record = append(record, record[0]) - record[0] = byte(recordTypeApplicationData) - - n := len(payload) + 1 + c.Overhead() - record[3] = byte(n >> 8) - record[4] = byte(n) - - record = c.Seal(record[:recordHeaderLen], - nonce, record[recordHeaderLen:], record[:recordHeaderLen]) - } else { - additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) - additionalData = append(additionalData, record[:recordHeaderLen]...) 
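The TLS 1.3 branch of encrypt() seals the record with the 5-byte header as additional data, after appending the real ContentType and fixing the length field to cover the tag. A standalone sketch with crypto/cipher, using an all-zero key and nonce purely for illustration (real values come from the key schedule and per-record sequence number):

package main

import (
	"crypto/aes"
	"crypto/cipher"
	"fmt"
)

func main() {
	key := make([]byte, 16)   // illustrative only
	nonce := make([]byte, 12) // normally the static IV XORed with the sequence number
	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		panic(err)
	}

	payload := []byte("hello")
	record := []byte{23, 3, 3, 0, 0} // application_data, legacy version TLS 1.2
	record = append(record, payload...)
	record = append(record, 23) // the real ContentType rides inside the ciphertext
	n := len(payload) + 1 + aead.Overhead()
	record[3], record[4] = byte(n>>8), byte(n)

	// The header is authenticated as additional data, as in encrypt() above.
	sealed := aead.Seal(record[:5:5], nonce, record[5:], record[:5])
	fmt.Println(len(sealed) == 5+n) // true
}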
- record = c.Seal(record, nonce, payload, additionalData) - } - case cbcMode: - mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) - blockSize := c.BlockSize() - plaintextLen := len(payload) + len(mac) - paddingLen := blockSize - plaintextLen%blockSize - record, dst = sliceForAppend(record, plaintextLen+paddingLen) - copy(dst, payload) - copy(dst[len(payload):], mac) - for i := plaintextLen; i < len(dst); i++ { - dst[i] = byte(paddingLen - 1) - } - if len(explicitNonce) > 0 { - c.SetIV(explicitNonce) - } - c.CryptBlocks(dst, dst) - default: - panic("unknown cipher type") - } - - // Update length to include nonce, MAC and any block padding needed. - n := len(record) - recordHeaderLen - record[3] = byte(n >> 8) - record[4] = byte(n) - hc.incSeq() - - return record, nil -} - -// RecordHeaderError is returned when a TLS record header is invalid. -type RecordHeaderError struct { - // Msg contains a human readable string that describes the error. - Msg string - // RecordHeader contains the five bytes of TLS record header that - // triggered the error. - RecordHeader [5]byte - // Conn provides the underlying net.Conn in the case that a client - // sent an initial handshake that didn't look like TLS. - // It is nil if there's already been a handshake or a TLS alert has - // been written to the connection. - Conn net.Conn -} - -func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } - -func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { - err.Msg = msg - err.Conn = conn - copy(err.RecordHeader[:], c.rawInput.Bytes()) - return err -} - -func (c *Conn) readRecord() error { - return c.readRecordOrCCS(false) -} - -func (c *Conn) readChangeCipherSpec() error { - return c.readRecordOrCCS(true) -} - -// readRecordOrCCS reads one or more TLS records from the connection and -// updates the record layer state. Some invariants: -// * c.in must be locked -// * c.input must be empty -// During the handshake one and only one of the following will happen: -// - c.hand grows -// - c.in.changeCipherSpec is called -// - an error is returned -// After the handshake one and only one of the following will happen: -// - c.hand grows -// - c.input is set -// - an error is returned -func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { - if c.in.err != nil { - return c.in.err - } - handshakeComplete := c.handshakeComplete() - - // This function modifies c.rawInput, which owns the c.input memory. - if c.input.Len() != 0 { - return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data")) - } - c.input.Reset(nil) - - // Read header, payload. - if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil { - // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify - // is an error, but popular web sites seem to do this, so we accept it - // if and only if at the record boundary. - if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 { - err = io.EOF - } - if e, ok := err.(net.Error); !ok || !e.Temporary() { - c.in.setErrorLocked(err) - } - return err - } - hdr := c.rawInput.Bytes()[:recordHeaderLen] - typ := recordType(hdr[0]) - - // No valid TLS record has a type of 0x80, however SSLv2 handshakes - // start with a uint16 length where the MSB is set and the first record - // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests - // an SSLv2 client. 
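The 5-byte header read just above is all the framing TLS provides: a ContentType, a legacy version, and a big-endian length, decoded a few lines below. A self-contained sketch of that decoding:

package main

import "fmt"

// parseRecordHeader mirrors how readRecordOrCCS derives typ, vers and n from
// hdr before deciding whether to pull the payload off the wire.
func parseRecordHeader(hdr [5]byte) (typ byte, vers uint16, n int) {
	typ = hdr[0]
	vers = uint16(hdr[1])<<8 | uint16(hdr[2])
	n = int(hdr[3])<<8 | int(hdr[4])
	return
}

func main() {
	typ, vers, n := parseRecordHeader([5]byte{22, 3, 3, 0, 122})
	// type 22 is handshake; 0x0303 is the frozen TLS 1.2 record version.
	fmt.Printf("type=%d version=%#04x length=%d\n", typ, vers, n)
}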
- if !handshakeComplete && typ == 0x80 { - c.sendAlert(alertProtocolVersion) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) - } - - vers := uint16(hdr[1])<<8 | uint16(hdr[2]) - n := int(hdr[3])<<8 | int(hdr[4]) - if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { - c.sendAlert(alertProtocolVersion) - msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) - } - if !c.haveVers { - // First message, be extra suspicious: this might not be a TLS - // client. Bail out before reading a full 'body', if possible. - // The current max version is 3.3 so if the version is >= 16.0, - // it's probably not real. - if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { - return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) - } - } - if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { - c.sendAlert(alertRecordOverflow) - msg := fmt.Sprintf("oversized record received with length %d", n) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) - } - if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { - if e, ok := err.(net.Error); !ok || !e.Temporary() { - c.in.setErrorLocked(err) - } - return err - } - - // Process message. - record := c.rawInput.Next(recordHeaderLen + n) - data, typ, err := c.in.decrypt(record) - if err != nil { - return c.in.setErrorLocked(c.sendAlert(err.(alert))) - } - if len(data) > maxPlaintext { - return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) - } - - // Application Data messages are always protected. - if c.in.cipher == nil && typ == recordTypeApplicationData { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { - // This is a state-advancing message: reset the retry count. - c.retryCount = 0 - } - - // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. - if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - switch typ { - default: - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - - case recordTypeAlert: - if len(data) != 2 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - if alert(data[1]) == alertCloseNotify { - return c.in.setErrorLocked(io.EOF) - } - if c.vers == VersionTLS13 { - return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) - } - switch data[0] { - case alertLevelWarning: - // Drop the record on the floor and retry. - return c.retryReadRecord(expectChangeCipherSpec) - case alertLevelError: - return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) - default: - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - case recordTypeChangeCipherSpec: - if len(data) != 1 || data[0] != 1 { - return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) - } - // Handshake messages are not allowed to fragment across the CCS. - if c.hand.Len() > 0 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - // In TLS 1.3, change_cipher_spec records are ignored until the - // Finished. See RFC 8446, Appendix D.4. 
Note that according to Section - // 5, a server can send a ChangeCipherSpec before its ServerHello, when - // c.vers is still unset. That's not useful though and suspicious if the - // server then selects a lower protocol version, so don't allow that. - if c.vers == VersionTLS13 { - return c.retryReadRecord(expectChangeCipherSpec) - } - if !expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - if err := c.in.changeCipherSpec(); err != nil { - return c.in.setErrorLocked(c.sendAlert(err.(alert))) - } - - case recordTypeApplicationData: - if !handshakeComplete || expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - // Some OpenSSL servers send empty records in order to randomize the - // CBC IV. Ignore a limited number of empty records. - if len(data) == 0 { - return c.retryReadRecord(expectChangeCipherSpec) - } - // Note that data is owned by c.rawInput, following the Next call above, - // to avoid copying the plaintext. This is safe because c.rawInput is - // not read from or written to until c.input is drained. - c.input.Reset(data) - - case recordTypeHandshake: - if len(data) == 0 || expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - c.hand.Write(data) - } - - return nil -} - -// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like -// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. -func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { - c.retryCount++ - if c.retryCount > maxUselessRecords { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: too many ignored records")) - } - return c.readRecordOrCCS(expectChangeCipherSpec) -} - -// atLeastReader reads from R, stopping with EOF once at least N bytes have been -// read. It is different from an io.LimitedReader in that it doesn't cut short -// the last Read call, and in that it considers an early EOF an error. -type atLeastReader struct { - R io.Reader - N int64 -} - -func (r *atLeastReader) Read(p []byte) (int, error) { - if r.N <= 0 { - return 0, io.EOF - } - n, err := r.R.Read(p) - r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 - if r.N > 0 && err == io.EOF { - return n, io.ErrUnexpectedEOF - } - if r.N <= 0 && err == nil { - return n, io.EOF - } - return n, err -} - -// readFromUntil reads from r into c.rawInput until c.rawInput contains -// at least n bytes or else returns an error. -func (c *Conn) readFromUntil(r io.Reader, n int) error { - if c.rawInput.Len() >= n { - return nil - } - needs := n - c.rawInput.Len() - // There might be extra input waiting on the wire. Make a best effort - // attempt to fetch it so that it can be used in (*Conn).Read to - // "predict" closeNotify alerts. - c.rawInput.Grow(needs + bytes.MinRead) - _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)}) - return err -} - -// sendAlert sends a TLS alert message. -func (c *Conn) sendAlertLocked(err alert) error { - switch err { - case alertNoRenegotiation, alertCloseNotify: - c.tmp[0] = alertLevelWarning - default: - c.tmp[0] = alertLevelError - } - c.tmp[1] = byte(err) - - _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) - if err == alertCloseNotify { - // closeNotify is a special case in that it isn't an error. - return writeErr - } - - return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) -} - -// sendAlert sends a TLS alert message. 
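A short usage sketch for the atLeastReader defined above: paired with bytes.Buffer.ReadFrom (as readFromUntil does), it stops as soon as the requested amount has been buffered and turns a short read into io.ErrUnexpectedEOF. The type is reproduced verbatim so the example runs on its own.

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

type atLeastReader struct {
	R io.Reader
	N int64
}

func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	n, err := r.R.Read(p)
	r.N -= int64(n)
	if r.N > 0 && err == io.EOF {
		return n, io.ErrUnexpectedEOF
	}
	if r.N <= 0 && err == nil {
		return n, io.EOF
	}
	return n, err
}

func main() {
	var buf bytes.Buffer
	_, err := buf.ReadFrom(&atLeastReader{R: strings.NewReader("0123456789"), N: 5})
	fmt.Println(buf.Len() >= 5, err) // true <nil>

	_, err = buf.ReadFrom(&atLeastReader{R: strings.NewReader("xy"), N: 5})
	fmt.Println(err) // unexpected EOF
}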
-func (c *Conn) sendAlert(err alert) error { - if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - c.extraConfig.AlternativeRecordLayer.SendAlert(uint8(err)) - return &net.OpError{Op: "local error", Err: err} - } - - c.out.Lock() - defer c.out.Unlock() - return c.sendAlertLocked(err) -} - -const ( - // tcpMSSEstimate is a conservative estimate of the TCP maximum segment - // size (MSS). A constant is used, rather than querying the kernel for - // the actual MSS, to avoid complexity. The value here is the IPv6 - // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 - // bytes) and a TCP header with timestamps (32 bytes). - tcpMSSEstimate = 1208 - - // recordSizeBoostThreshold is the number of bytes of application data - // sent after which the TLS record size will be increased to the - // maximum. - recordSizeBoostThreshold = 128 * 1024 -) - -// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the -// next application data record. There is the following trade-off: -// -// - For latency-sensitive applications, such as web browsing, each TLS -// record should fit in one TCP segment. -// - For throughput-sensitive applications, such as large file transfers, -// larger TLS records better amortize framing and encryption overheads. -// -// A simple heuristic that works well in practice is to use small records for -// the first 1MB of data, then use larger records for subsequent data, and -// reset back to smaller records after the connection becomes idle. See "High -// Performance Web Networking", Chapter 4, or: -// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ -// -// In the interests of simplicity and determinism, this code does not attempt -// to reset the record size once the connection is idle, however. -func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { - if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { - return maxPlaintext - } - - if c.bytesSent >= recordSizeBoostThreshold { - return maxPlaintext - } - - // Subtract TLS overheads to get the maximum payload size. - payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() - if c.out.cipher != nil { - switch ciph := c.out.cipher.(type) { - case cipher.Stream: - payloadBytes -= c.out.mac.Size() - case cipher.AEAD: - payloadBytes -= ciph.Overhead() - case cbcMode: - blockSize := ciph.BlockSize() - // The payload must fit in a multiple of blockSize, with - // room for at least one padding byte. - payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 - // The MAC is appended before padding so affects the - // payload size directly. - payloadBytes -= c.out.mac.Size() - default: - panic("unknown cipher type") - } - } - if c.vers == VersionTLS13 { - payloadBytes-- // encrypted ContentType - } - - // Allow packet growth in arithmetic progression up to max. - pkt := c.packetsSent - c.packetsSent++ - if pkt > 1000 { - return maxPlaintext // avoid overflow in multiply below - } - - n := payloadBytes * int(pkt+1) - if n > maxPlaintext { - n = maxPlaintext - } - return n -} - -func (c *Conn) write(data []byte) (int, error) { - if c.buffering { - c.sendBuf = append(c.sendBuf, data...) 
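A simplified, standalone sketch of the record-sizing heuristic implemented in maxPayloadSizeForWrite above: records start near one TCP segment, grow linearly with the number of packets sent, and jump to the 16 KiB maximum once enough data has gone out. The per-record overhead constant here is an illustrative stand-in for the header/nonce/MAC subtraction done in the real function.

package main

import "fmt"

const (
	tcpMSSEstimate           = 1208
	recordSizeBoostThreshold = 128 * 1024
	maxPlaintext             = 16384
	perRecordOverhead        = 5 + 8 + 16 // header + explicit nonce + AEAD tag (illustrative)
)

func nextRecordSize(bytesSent, packetsSent int64) int {
	if bytesSent >= recordSizeBoostThreshold {
		return maxPlaintext
	}
	n := (tcpMSSEstimate - perRecordOverhead) * int(packetsSent+1)
	if n > maxPlaintext {
		n = maxPlaintext
	}
	return n
}

func main() {
	var sent int64
	for pkt := int64(0); pkt < 5; pkt++ {
		size := nextRecordSize(sent, pkt)
		fmt.Println(size) // 1179 2358 3537 4716 5895
		sent += int64(size)
	}
}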
- return len(data), nil - } - - n, err := c.conn.Write(data) - c.bytesSent += int64(n) - return n, err -} - -func (c *Conn) flush() (int, error) { - if len(c.sendBuf) == 0 { - return 0, nil - } - - n, err := c.conn.Write(c.sendBuf) - c.bytesSent += int64(n) - c.sendBuf = nil - c.buffering = false - return n, err -} - -// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. -var outBufPool = sync.Pool{ - New: func() interface{} { - return new([]byte) - }, -} - -// writeRecordLocked writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { - outBufPtr := outBufPool.Get().(*[]byte) - outBuf := *outBufPtr - defer func() { - // You might be tempted to simplify this by just passing &outBuf to Put, - // but that would make the local copy of the outBuf slice header escape - // to the heap, causing an allocation. Instead, we keep around the - // pointer to the slice header returned by Get, which is already on the - // heap, and overwrite and return that. - *outBufPtr = outBuf - outBufPool.Put(outBufPtr) - }() - - var n int - for len(data) > 0 { - m := len(data) - if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { - m = maxPayload - } - - _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) - outBuf[0] = byte(typ) - vers := c.vers - if vers == 0 { - // Some TLS servers fail if the record version is - // greater than TLS 1.0 for the initial ClientHello. - vers = VersionTLS10 - } else if vers == VersionTLS13 { - // TLS 1.3 froze the record layer version to 1.2. - // See RFC 8446, Section 5.1. - vers = VersionTLS12 - } - outBuf[1] = byte(vers >> 8) - outBuf[2] = byte(vers) - outBuf[3] = byte(m >> 8) - outBuf[4] = byte(m) - - var err error - outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) - if err != nil { - return n, err - } - if _, err := c.write(outBuf); err != nil { - return n, err - } - n += m - data = data[m:] - } - - if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { - if err := c.out.changeCipherSpec(); err != nil { - return n, c.sendAlertLocked(err.(alert)) - } - } - - return n, nil -} - -// writeRecord writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { - if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - if typ == recordTypeChangeCipherSpec { - return len(data), nil - } - return c.extraConfig.AlternativeRecordLayer.WriteRecord(data) - } - - c.out.Lock() - defer c.out.Unlock() - - return c.writeRecordLocked(typ, data) -} - -// readHandshake reads the next handshake message from -// the record layer. 
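The outBufPool comment above describes a subtle allocation trap. A minimal reproduction of the pattern it recommends: keep the *[]byte returned by Get and overwrite what it points to before Put, so the local slice header never escapes to the heap.

package main

import (
	"fmt"
	"sync"
)

var bufPool = sync.Pool{
	New: func() interface{} { return new([]byte) },
}

// withScratch borrows a pooled buffer, lets f use it from length zero, and
// stores the (possibly grown) slice back through the same pointer.
func withScratch(f func(buf []byte) []byte) {
	ptr := bufPool.Get().(*[]byte)
	buf := f((*ptr)[:0])
	*ptr = buf
	bufPool.Put(ptr)
}

func main() {
	withScratch(func(buf []byte) []byte {
		return append(buf, "record bytes"...)
	})
	fmt.Println("buffer returned to the pool without re-boxing the slice header")
}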
-func (c *Conn) readHandshake() (interface{}, error) { - var data []byte - if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - var err error - data, err = c.extraConfig.AlternativeRecordLayer.ReadHandshakeMessage() - if err != nil { - return nil, err - } - } else { - for c.hand.Len() < 4 { - if err := c.readRecord(); err != nil { - return nil, err - } - } - - data = c.hand.Bytes() - n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - if n > maxHandshake { - c.sendAlertLocked(alertInternalError) - return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) - } - for c.hand.Len() < 4+n { - if err := c.readRecord(); err != nil { - return nil, err - } - } - data = c.hand.Next(4 + n) - } - var m handshakeMessage - switch data[0] { - case typeHelloRequest: - m = new(helloRequestMsg) - case typeClientHello: - m = new(clientHelloMsg) - case typeServerHello: - m = new(serverHelloMsg) - case typeNewSessionTicket: - if c.vers == VersionTLS13 { - m = new(newSessionTicketMsgTLS13) - } else { - m = new(newSessionTicketMsg) - } - case typeCertificate: - if c.vers == VersionTLS13 { - m = new(certificateMsgTLS13) - } else { - m = new(certificateMsg) - } - case typeCertificateRequest: - if c.vers == VersionTLS13 { - m = new(certificateRequestMsgTLS13) - } else { - m = &certificateRequestMsg{ - hasSignatureAlgorithm: c.vers >= VersionTLS12, - } - } - case typeCertificateStatus: - m = new(certificateStatusMsg) - case typeServerKeyExchange: - m = new(serverKeyExchangeMsg) - case typeServerHelloDone: - m = new(serverHelloDoneMsg) - case typeClientKeyExchange: - m = new(clientKeyExchangeMsg) - case typeCertificateVerify: - m = &certificateVerifyMsg{ - hasSignatureAlgorithm: c.vers >= VersionTLS12, - } - case typeFinished: - m = new(finishedMsg) - case typeEncryptedExtensions: - m = new(encryptedExtensionsMsg) - case typeEndOfEarlyData: - m = new(endOfEarlyDataMsg) - case typeKeyUpdate: - m = new(keyUpdateMsg) - default: - return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - // The handshake message unmarshalers - // expect to be able to keep references to data, - // so pass in a fresh copy that won't be overwritten. - data = append([]byte(nil), data...) - - if !m.unmarshal(data) { - return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - return m, nil -} - -var ( - errShutdown = errors.New("tls: protocol is shutdown") -) - -// Write writes data to the connection. -// -// As Write calls Handshake, in order to prevent indefinite blocking a deadline -// must be set for both Read and Write before Write is called when the handshake -// has not yet completed. See SetDeadline, SetReadDeadline, and -// SetWriteDeadline. -func (c *Conn) Write(b []byte) (int, error) { - // interlock with Close below - for { - x := atomic.LoadInt32(&c.activeCall) - if x&1 != 0 { - return 0, net.ErrClosed - } - if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { - break - } - } - defer atomic.AddInt32(&c.activeCall, -2) - - if err := c.Handshake(); err != nil { - return 0, err - } - - c.out.Lock() - defer c.out.Unlock() - - if err := c.out.err; err != nil { - return 0, err - } - - if !c.handshakeComplete() { - return 0, alertInternalError - } - - if c.closeNotifySent { - return 0, errShutdown - } - - // TLS 1.0 is susceptible to a chosen-plaintext - // attack when using block mode ciphers due to predictable IVs. 
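readHandshake above first waits for the 4-byte handshake header: one type byte and a 24-bit big-endian length. A small sketch of that framing; the 65536 bound is a stand-in for the package's maxHandshake limit.

package main

import (
	"errors"
	"fmt"
)

func parseHandshakeHeader(data []byte) (typ byte, length int, err error) {
	if len(data) < 4 {
		return 0, 0, errors.New("tls: need 4 bytes of handshake header")
	}
	typ = data[0]
	length = int(data[1])<<16 | int(data[2])<<8 | int(data[3])
	if length > 65536 { // stand-in for maxHandshake
		return 0, 0, errors.New("tls: handshake message too large")
	}
	return typ, length, nil
}

func main() {
	typ, n, err := parseHandshakeHeader([]byte{1, 0, 1, 44}) // ClientHello with 300 bytes of body
	fmt.Println(typ, n, err)                                 // 1 300 <nil>
}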
- // This can be prevented by splitting each Application Data - // record into two records, effectively randomizing the IV. - // - // https://www.openssl.org/~bodo/tls-cbc.txt - // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 - // https://www.imperialviolet.org/2012/01/15/beastfollowup.html - - var m int - if len(b) > 1 && c.vers == VersionTLS10 { - if _, ok := c.out.cipher.(cipher.BlockMode); ok { - n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) - if err != nil { - return n, c.out.setErrorLocked(err) - } - m, b = 1, b[1:] - } - } - - n, err := c.writeRecordLocked(recordTypeApplicationData, b) - return n + m, c.out.setErrorLocked(err) -} - -// handleRenegotiation processes a HelloRequest handshake message. -func (c *Conn) handleRenegotiation() error { - if c.vers == VersionTLS13 { - return errors.New("tls: internal error: unexpected renegotiation") - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - helloReq, ok := msg.(*helloRequestMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(helloReq, msg) - } - - if !c.isClient { - return c.sendAlert(alertNoRenegotiation) - } - - switch c.config.Renegotiation { - case RenegotiateNever: - return c.sendAlert(alertNoRenegotiation) - case RenegotiateOnceAsClient: - if c.handshakes > 1 { - return c.sendAlert(alertNoRenegotiation) - } - case RenegotiateFreelyAsClient: - // Ok. - default: - c.sendAlert(alertInternalError) - return errors.New("tls: unknown Renegotiation value") - } - - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - atomic.StoreUint32(&c.handshakeStatus, 0) - if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { - c.handshakes++ - } - return c.handshakeErr -} - -func (c *Conn) HandlePostHandshakeMessage() error { - return c.handlePostHandshakeMessage() -} - -// handlePostHandshakeMessage processes a handshake message arrived after the -// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. -func (c *Conn) handlePostHandshakeMessage() error { - if c.vers != VersionTLS13 { - return c.handleRenegotiation() - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - c.retryCount++ - if c.retryCount > maxUselessRecords { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) - } - - switch msg := msg.(type) { - case *newSessionTicketMsgTLS13: - return c.handleNewSessionTicket(msg) - case *keyUpdateMsg: - return c.handleKeyUpdate(msg) - default: - c.sendAlert(alertUnexpectedMessage) - return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) - } -} - -func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { - cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) - if cipherSuite == nil { - return c.in.setErrorLocked(c.sendAlert(alertInternalError)) - } - - newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) - c.in.setTrafficSecret(cipherSuite, newSecret) - - if keyUpdate.updateRequested { - c.out.Lock() - defer c.out.Unlock() - - msg := &keyUpdateMsg{} - _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) - if err != nil { - // Surface the error at the next write. - c.out.setErrorLocked(err) - return nil - } - - newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) - c.out.setTrafficSecret(cipherSuite, newSecret) - } - - return nil -} - -// Read reads data from the connection. 
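The TLS 1.0 CBC branch at the top of Write above performs the classic 1/n-1 record split: send one byte in its own record so the next record's IV is no longer predictable. Stripped of the record layer, the control flow looks like this sketch:

package main

import "fmt"

// splitForBEAST writes b as two application-data records (1 byte, then the
// rest), mirroring the m/b bookkeeping in Conn.Write above.
func splitForBEAST(b []byte, writeRecord func([]byte) (int, error)) (int, error) {
	var m int
	if len(b) > 1 {
		if _, err := writeRecord(b[:1]); err != nil {
			return 0, err
		}
		m, b = 1, b[1:]
	}
	n, err := writeRecord(b)
	return n + m, err
}

func main() {
	var records [][]byte
	n, _ := splitForBEAST([]byte("secret"), func(p []byte) (int, error) {
		records = append(records, append([]byte(nil), p...))
		return len(p), nil
	})
	fmt.Println(n, len(records)) // 6 2
}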
-// -// As Read calls Handshake, in order to prevent indefinite blocking a deadline -// must be set for both Read and Write before Read is called when the handshake -// has not yet completed. See SetDeadline, SetReadDeadline, and -// SetWriteDeadline. -func (c *Conn) Read(b []byte) (int, error) { - if err := c.Handshake(); err != nil { - return 0, err - } - if len(b) == 0 { - // Put this after Handshake, in case people were calling - // Read(nil) for the side effect of the Handshake. - return 0, nil - } - - c.in.Lock() - defer c.in.Unlock() - - for c.input.Len() == 0 { - if err := c.readRecord(); err != nil { - return 0, err - } - for c.hand.Len() > 0 { - if err := c.handlePostHandshakeMessage(); err != nil { - return 0, err - } - } - } - - n, _ := c.input.Read(b) - - // If a close-notify alert is waiting, read it so that we can return (n, - // EOF) instead of (n, nil), to signal to the HTTP response reading - // goroutine that the connection is now closed. This eliminates a race - // where the HTTP response reading goroutine would otherwise not observe - // the EOF until its next read, by which time a client goroutine might - // have already tried to reuse the HTTP connection for a new request. - // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 - if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && - recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { - if err := c.readRecord(); err != nil { - return n, err // will be io.EOF on closeNotify - } - } - - return n, nil -} - -// Close closes the connection. -func (c *Conn) Close() error { - // Interlock with Conn.Write above. - var x int32 - for { - x = atomic.LoadInt32(&c.activeCall) - if x&1 != 0 { - return net.ErrClosed - } - if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { - break - } - } - if x != 0 { - // io.Writer and io.Closer should not be used concurrently. - // If Close is called while a Write is currently in-flight, - // interpret that as a sign that this Close is really just - // being used to break the Write and/or clean up resources and - // avoid sending the alertCloseNotify, which may block - // waiting on handshakeMutex or the c.out mutex. - return c.conn.Close() - } - - var alertErr error - if c.handshakeComplete() { - if err := c.closeNotify(); err != nil { - alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) - } - } - - if err := c.conn.Close(); err != nil { - return err - } - return alertErr -} - -var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") - -// CloseWrite shuts down the writing side of the connection. It should only be -// called once the handshake has completed and does not call CloseWrite on the -// underlying connection. Most callers should just use Close. -func (c *Conn) CloseWrite() error { - if !c.handshakeComplete() { - return errEarlyCloseWrite - } - - return c.closeNotify() -} - -func (c *Conn) closeNotify() error { - c.out.Lock() - defer c.out.Unlock() - - if !c.closeNotifySent { - // Set a Write Deadline to prevent possibly blocking forever. - c.SetWriteDeadline(time.Now().Add(time.Second * 5)) - c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) - c.closeNotifySent = true - // Any subsequent writes will fail. - c.SetWriteDeadline(time.Now()) - } - return c.closeNotifyErr -} - -// Handshake runs the client or server handshake -// protocol if it has not yet been run. 
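The Write/Close interlock above packs two things into one int32: bit 0 records that Close was called, and every in-flight Write adds 2. A stripped-down sketch of the same scheme, kept on the sync/atomic functions the file itself uses:

package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

type connState struct{ activeCall int32 }

var errClosed = errors.New("use of closed connection")

func (c *connState) beginWrite() error {
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			return nil
		}
	}
}

func (c *connState) endWrite() { atomic.AddInt32(&c.activeCall, -2) }

// close sets the closed bit and reports whether any Write was in flight,
// which is when Close above skips sending close_notify.
func (c *connState) close() (hadWriters bool) {
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
			return x != 0
		}
	}
}

func main() {
	var c connState
	_ = c.beginWrite()
	fmt.Println(c.close()) // true: a Write is in flight
	c.endWrite()
}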
-// -// Most uses of this package need not call Handshake explicitly: the -// first Read or Write will call it automatically. -// -// For control over canceling or setting a timeout on a handshake, use -// the Dialer's DialContext method. -func (c *Conn) Handshake() error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - if err := c.handshakeErr; err != nil { - return err - } - if c.handshakeComplete() { - return nil - } - - c.in.Lock() - defer c.in.Unlock() - - c.handshakeErr = c.handshakeFn() - if c.handshakeErr == nil { - c.handshakes++ - } else { - // If an error occurred during the handshake try to flush the - // alert that might be left in the buffer. - c.flush() - } - - if c.handshakeErr == nil && !c.handshakeComplete() { - c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") - } - - return c.handshakeErr -} - -// ConnectionState returns basic TLS details about the connection. -func (c *Conn) ConnectionState() ConnectionState { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - return c.connectionStateLocked() -} - -// ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection. -func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - return ConnectionStateWith0RTT{ - ConnectionState: c.connectionStateLocked(), - Used0RTT: c.used0RTT, - } -} - -func (c *Conn) connectionStateLocked() ConnectionState { - var state connectionState - state.HandshakeComplete = c.handshakeComplete() - state.Version = c.vers - state.NegotiatedProtocol = c.clientProtocol - state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = true - state.ServerName = c.serverName - state.CipherSuite = c.cipherSuite - state.PeerCertificates = c.peerCertificates - state.VerifiedChains = c.verifiedChains - state.SignedCertificateTimestamps = c.scts - state.OCSPResponse = c.ocspResponse - if !c.didResume && c.vers != VersionTLS13 { - if c.clientFinishedIsFirst { - state.TLSUnique = c.clientFinished[:] - } else { - state.TLSUnique = c.serverFinished[:] - } - } - if c.config.Renegotiation != RenegotiateNever { - state.ekm = noExportedKeyingMaterial - } else { - state.ekm = c.ekm - } - return toConnectionState(state) -} - -// OCSPResponse returns the stapled OCSP response from the TLS server, if -// any. (Only valid for client connections.) -func (c *Conn) OCSPResponse() []byte { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - return c.ocspResponse -} - -// VerifyHostname checks that the peer certificate chain is valid for -// connecting to host. If so, it returns nil; if not, it returns an error -// describing the problem. 
-func (c *Conn) VerifyHostname(host string) error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - if !c.isClient { - return errors.New("tls: VerifyHostname called on TLS server connection") - } - if !c.handshakeComplete() { - return errors.New("tls: handshake has not yet been performed") - } - if len(c.verifiedChains) == 0 { - return errors.New("tls: handshake did not verify certificate chain") - } - return c.peerCertificates[0].VerifyHostname(host) -} - -func (c *Conn) handshakeComplete() bool { - return atomic.LoadUint32(&c.handshakeStatus) == 1 -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client.go deleted file mode 100644 index a447061ae05..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client.go +++ /dev/null @@ -1,1105 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "net" - "strings" - "sync/atomic" - "time" - - "golang.org/x/crypto/cryptobyte" -) - -const clientSessionStateVersion = 1 - -type clientHandshakeState struct { - c *Conn - serverHello *serverHelloMsg - hello *clientHelloMsg - suite *cipherSuite - finishedHash finishedHash - masterSecret []byte - session *clientSessionState -} - -func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { - config := c.config - if len(config.ServerName) == 0 && !config.InsecureSkipVerify { - return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") - } - - nextProtosLength := 0 - for _, proto := range config.NextProtos { - if l := len(proto); l == 0 || l > 255 { - return nil, nil, errors.New("tls: invalid NextProtos value") - } else { - nextProtosLength += 1 + l - } - } - if nextProtosLength > 0xffff { - return nil, nil, errors.New("tls: NextProtos values too large") - } - - var supportedVersions []uint16 - var clientHelloVersion uint16 - if c.extraConfig.usesAlternativeRecordLayer() { - if config.maxSupportedVersion() < VersionTLS13 { - return nil, nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") - } - // Only offer TLS 1.3 when QUIC is used. - supportedVersions = []uint16{VersionTLS13} - clientHelloVersion = VersionTLS13 - } else { - supportedVersions = config.supportedVersions() - if len(supportedVersions) == 0 { - return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") - } - clientHelloVersion = config.maxSupportedVersion() - } - - // The version at the beginning of the ClientHello was capped at TLS 1.2 - // for compatibility reasons. The supported_versions extension is used - // to negotiate versions now. See RFC 8446, Section 4.2.1. 
- if clientHelloVersion > VersionTLS12 { - clientHelloVersion = VersionTLS12 - } - - hello := &clientHelloMsg{ - vers: clientHelloVersion, - compressionMethods: []uint8{compressionNone}, - random: make([]byte, 32), - ocspStapling: true, - scts: true, - serverName: hostnameInSNI(config.ServerName), - supportedCurves: config.curvePreferences(), - supportedPoints: []uint8{pointFormatUncompressed}, - secureRenegotiationSupported: true, - alpnProtocols: config.NextProtos, - supportedVersions: supportedVersions, - } - - if c.handshakes > 0 { - hello.secureRenegotiation = c.clientFinished[:] - } - - possibleCipherSuites := config.cipherSuites() - hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) - - // add non-TLS 1.3 cipher suites - if c.config.MinVersion <= VersionTLS12 { - for _, suiteId := range possibleCipherSuites { - for _, suite := range cipherSuites { - if suite.id != suiteId { - continue - } - // Don't advertise TLS 1.2-only cipher suites unless - // we're attempting TLS 1.2. - if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { - break - } - hello.cipherSuites = append(hello.cipherSuites, suiteId) - break - } - } - } - - _, err := io.ReadFull(config.rand(), hello.random) - if err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - - // A random session ID is used to detect when the server accepted a ticket - // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as - // a compatibility measure (see RFC 8446, Section 4.1.2). - if c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { - hello.sessionId = make([]byte, 32) - if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - } - - if hello.vers >= VersionTLS12 { - hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms - } - - var params ecdheParameters - if hello.supportedVersions[0] == VersionTLS13 { - var hasTLS13CipherSuite bool - // add TLS 1.3 cipher suites - for _, suiteID := range possibleCipherSuites { - for _, suite := range cipherSuitesTLS13 { - if suite.id == suiteID { - hasTLS13CipherSuite = true - hello.cipherSuites = append(hello.cipherSuites, suiteID) - } - } - } - if !hasTLS13CipherSuite { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13()...) - } - - curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err = generateECDHEParameters(config.rand(), curveID) - if err != nil { - return nil, nil, err - } - hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} - } - - if hello.supportedVersions[0] == VersionTLS13 && c.extraConfig != nil && c.extraConfig.GetExtensions != nil { - hello.additionalExtensions = c.extraConfig.GetExtensions(typeClientHello) - } - - return hello, params, nil -} - -func (c *Conn) clientHandshake() (err error) { - if c.config == nil { - c.config = fromConfig(defaultConfig()) - } - c.setAlternativeRecordLayer() - - // This may be a renegotiation handshake, in which case some fields - // need to be reset. 
- c.didResume = false - - hello, ecdheParams, err := c.makeClientHello() - if err != nil { - return err - } - c.serverName = hello.serverName - - cacheKey, session, earlySecret, binderKey := c.loadSession(hello) - if cacheKey != "" && session != nil { - var deletedTicket bool - if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT { - // don't reuse a session ticket that enabled 0-RTT - c.config.ClientSessionCache.Put(cacheKey, nil) - deletedTicket = true - - if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil { - h := suite.hash.New() - h.Write(hello.marshal()) - clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h) - c.out.exportKey(Encryption0RTT, suite, clientEarlySecret) - if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil { - c.sendAlert(alertInternalError) - return err - } - } - } - if !deletedTicket { - defer func() { - // If we got a handshake failure when resuming a session, throw away - // the session ticket. See RFC 5077, Section 3.2. - // - // RFC 8446 makes no mention of dropping tickets on failure, but it - // does require servers to abort on invalid binders, so we need to - // delete tickets to recover from a corrupted PSK. - if err != nil { - c.config.ClientSessionCache.Put(cacheKey, nil) - } - }() - } - } - - if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - serverHello, ok := msg.(*serverHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - - if err := c.pickTLSVersion(serverHello); err != nil { - return err - } - - // If we are negotiating a protocol version that's lower than what we - // support, check for the server downgrade canaries. - // See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion() - tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 - tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 - if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || - maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") - } - - if c.vers == VersionTLS13 { - hs := &clientHandshakeStateTLS13{ - c: c, - serverHello: serverHello, - hello: hello, - ecdheParams: ecdheParams, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, - } - - // In TLS 1.3, session tickets are delivered after the handshake. - return hs.handshake() - } - - hs := &clientHandshakeState{ - c: c, - serverHello: serverHello, - hello: hello, - session: session, - } - - if err := hs.handshake(); err != nil { - return err - } - - // If we had a successful handshake and hs.session is different from - // the one already cached - cache a new one. 
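The downgrade check above compares the last 8 bytes of ServerHello.random against the RFC 8446, Section 4.1.3 sentinels. The literal values below are reproduced for illustration, and the version bookkeeping is collapsed into two booleans; the real check distinguishes which canary applies to which negotiated version.

package main

import "fmt"

var (
	downgradeCanaryTLS12 = "DOWNGRD\x01"
	downgradeCanaryTLS11 = "DOWNGRD\x00"
)

func downgradeDetected(serverRandom [32]byte, weSupportTLS13, negotiatedTLS12OrBelow bool) bool {
	if !weSupportTLS13 || !negotiatedTLS12OrBelow {
		return false
	}
	tail := string(serverRandom[24:])
	return tail == downgradeCanaryTLS12 || tail == downgradeCanaryTLS11
}

func main() {
	var random [32]byte
	copy(random[24:], downgradeCanaryTLS12)
	fmt.Println(downgradeDetected(random, true, true)) // true: abort with illegal_parameter
}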
- if cacheKey != "" && hs.session != nil && session != hs.session { - c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session)) - } - - return nil -} - -// extract the app data saved in the session.nonce, -// and set the session.nonce to the actual nonce value -func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max early data */, []byte /* app data */, bool /* ok */) { - s := cryptobyte.String(session.nonce) - var version uint16 - if !s.ReadUint16(&version) { - return 0, nil, false - } - if version != clientSessionStateVersion { - return 0, nil, false - } - var maxEarlyData uint32 - if !s.ReadUint32(&maxEarlyData) { - return 0, nil, false - } - var appData []byte - if !readUint16LengthPrefixed(&s, &appData) { - return 0, nil, false - } - var nonce []byte - if !readUint16LengthPrefixed(&s, &nonce) { - return 0, nil, false - } - session.nonce = nonce - return maxEarlyData, appData, true -} - -func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, - session *clientSessionState, earlySecret, binderKey []byte) { - if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return "", nil, nil, nil - } - - hello.ticketSupported = true - - if hello.supportedVersions[0] == VersionTLS13 { - // Require DHE on resumption as it guarantees forward secrecy against - // compromise of the session ticket key. See RFC 8446, Section 4.2.9. - hello.pskModes = []uint8{pskModeDHE} - } - - // Session resumption is not allowed if renegotiating because - // renegotiation is primarily used to allow a client to send a client - // certificate, which would be skipped if session resumption occurred. - if c.handshakes != 0 { - return "", nil, nil, nil - } - - // Try to resume a previously negotiated TLS session, if available. - cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) - sess, ok := c.config.ClientSessionCache.Get(cacheKey) - if !ok || sess == nil { - return cacheKey, nil, nil, nil - } - session = fromClientSessionState(sess) - - var appData []byte - var maxEarlyData uint32 - if session.vers == VersionTLS13 { - var ok bool - maxEarlyData, appData, ok = c.decodeSessionState(session) - if !ok { // delete it, if parsing failed - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil - } - } - - // Check that version used for the previous session is still valid. - versOk := false - for _, v := range hello.supportedVersions { - if v == session.vers { - versOk = true - break - } - } - if !versOk { - return cacheKey, nil, nil, nil - } - - // Check that the cached server certificate is not expired, and that it's - // valid for the ServerName. This should be ensured by the cache key, but - // protect the application from a faulty ClientSessionCache implementation. - if !c.config.InsecureSkipVerify { - if len(session.verifiedChains) == 0 { - // The original connection had InsecureSkipVerify, while this doesn't. - return cacheKey, nil, nil, nil - } - serverCert := session.serverCertificates[0] - if c.config.time().After(serverCert.NotAfter) { - // Expired certificate, delete the entry. - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil - } - if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { - return cacheKey, nil, nil, nil - } - } - - if session.vers != VersionTLS13 { - // In TLS 1.2 the cipher suite must match the resumed session. Ensure we - // are still offering it. 
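decodeSessionState above packs extra fields into the ticket nonce with golang.org/x/crypto/cryptobyte. A round trip of that layout (a uint16 version, a uint32 max early data value, then two uint16-length-prefixed byte strings), assuming only the encoding shown in the function:

package main

import (
	"fmt"

	"golang.org/x/crypto/cryptobyte"
)

func main() {
	var b cryptobyte.Builder
	b.AddUint16(1) // clientSessionStateVersion
	b.AddUint32(16384)
	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes([]byte("app data")) })
	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes([]byte("nonce")) })
	blob, _ := b.Bytes()

	s := cryptobyte.String(blob)
	var version uint16
	var maxEarlyData uint32
	var appData, nonce cryptobyte.String
	ok := s.ReadUint16(&version) &&
		s.ReadUint32(&maxEarlyData) &&
		s.ReadUint16LengthPrefixed(&appData) &&
		s.ReadUint16LengthPrefixed(&nonce)
	fmt.Println(ok, version, maxEarlyData, string(appData), string(nonce))
	// true 1 16384 app data nonce
}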
- if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { - return cacheKey, nil, nil, nil - } - - hello.sessionTicket = session.sessionTicket - return - } - - // Check that the session ticket is not expired. - if c.config.time().After(session.useBy) { - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil - } - - // In TLS 1.3 the KDF hash must match the resumed session. Ensure we - // offer at least one cipher suite with that hash. - cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) - if cipherSuite == nil { - return cacheKey, nil, nil, nil - } - cipherSuiteOk := false - for _, offeredID := range hello.cipherSuites { - offeredSuite := cipherSuiteTLS13ByID(offeredID) - if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { - cipherSuiteOk = true - break - } - } - if !cipherSuiteOk { - return cacheKey, nil, nil, nil - } - - // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. - ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) - identity := pskIdentity{ - label: session.sessionTicket, - obfuscatedTicketAge: ticketAge + session.ageAdd, - } - hello.pskIdentities = []pskIdentity{identity} - hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} - - // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. - psk := cipherSuite.expandLabel(session.masterSecret, "resumption", - session.nonce, cipherSuite.hash.Size()) - earlySecret = cipherSuite.extract(psk, nil) - binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) - if c.extraConfig != nil { - hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0 - } - transcript := cipherSuite.hash.New() - transcript.Write(hello.marshalWithoutBinders()) - pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} - hello.updateBinders(pskBinders) - - if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { - c.extraConfig.SetAppDataFromSessionState(appData) - } - return -} - -func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { - peerVersion := serverHello.vers - if serverHello.supportedVersion != 0 { - peerVersion = serverHello.supportedVersion - } - - vers, ok := c.config.mutualVersion([]uint16{peerVersion}) - if !ok { - c.sendAlert(alertProtocolVersion) - return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) - } - - c.vers = vers - c.haveVers = true - c.in.version = vers - c.out.version = vers - - return nil -} - -// Does the handshake, either a full one or resumes old session. Requires hs.c, -// hs.hello, hs.serverHello, and, optionally, hs.session to be set. -func (hs *clientHandshakeState) handshake() error { - c := hs.c - - isResume, err := hs.processServerHello() - if err != nil { - return err - } - - hs.finishedHash = newFinishedHash(c.vers, hs.suite) - - // No signatures of the handshake are needed in a resumption. - // Otherwise, in a full handshake, if we don't have any certificates - // configured then we will never send a CertificateVerify message and - // thus no signatures are needed in that case either. 
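The pskIdentity built above carries the obfuscated ticket age from RFC 8446, Section 4.2.11: the client-observed age in milliseconds plus the server-chosen ageAdd, relying on uint32 wraparound for the masking. As a tiny sketch:

package main

import (
	"fmt"
	"time"
)

func obfuscatedTicketAge(receivedAt, now time.Time, ageAdd uint32) uint32 {
	age := uint32(now.Sub(receivedAt) / time.Millisecond)
	return age + ageAdd // wraps modulo 2^32, as intended
}

func main() {
	received := time.Now().Add(-90 * time.Second)
	fmt.Println(obfuscatedTicketAge(received, time.Now(), 0xdeadbeef))
}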
- if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { - hs.finishedHash.discardHandshakeBuffer() - } - - hs.finishedHash.Write(hs.hello.marshal()) - hs.finishedHash.Write(hs.serverHello.marshal()) - - c.buffering = true - c.didResume = isResume - if isResume { - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.readSessionTicket(); err != nil { - return err - } - if err := hs.readFinished(c.serverFinished[:]); err != nil { - return err - } - c.clientFinishedIsFirst = false - // Make sure the connection is still being verified whether or not this - // is a resumption. Resumptions currently don't reverify certificates so - // they don't call verifyServerCertificate. See Issue 31641. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - if err := hs.sendFinished(c.clientFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - } else { - if err := hs.doFullHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.sendFinished(c.clientFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - c.clientFinishedIsFirst = true - if err := hs.readSessionTicket(); err != nil { - return err - } - if err := hs.readFinished(c.serverFinished[:]); err != nil { - return err - } - } - - c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) - atomic.StoreUint32(&c.handshakeStatus, 1) - - return nil -} - -func (hs *clientHandshakeState) pickCipherSuite() error { - if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { - hs.c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server chose an unconfigured cipher suite") - } - - hs.c.cipherSuite = hs.suite.id - return nil -} - -func (hs *clientHandshakeState) doFullHandshake() error { - c := hs.c - - msg, err := c.readHandshake() - if err != nil { - return err - } - certMsg, ok := msg.(*certificateMsg) - if !ok || len(certMsg.certificates) == 0 { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - hs.finishedHash.Write(certMsg.marshal()) - - msg, err = c.readHandshake() - if err != nil { - return err - } - - cs, ok := msg.(*certificateStatusMsg) - if ok { - // RFC4366 on Certificate Status Request: - // The server MAY return a "certificate_status" message. - - if !hs.serverHello.ocspStapling { - // If a server returns a "CertificateStatus" message, then the - // server MUST have included an extension of type "status_request" - // with empty "extension_data" in the extended server hello. - - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: received unexpected CertificateStatus message") - } - hs.finishedHash.Write(cs.marshal()) - - c.ocspResponse = cs.response - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - - if c.handshakes == 0 { - // If this is the first handshake on a connection, process and - // (optionally) verify the server's certificates. - if err := c.verifyServerCertificate(certMsg.certificates); err != nil { - return err - } - } else { - // This is a renegotiation handshake. We require that the - // server's identity (i.e. leaf certificate) is unchanged and - // thus any previous trust decision is still valid. 
- // - // See https://mitls.org/pages/attacks/3SHAKE for the - // motivation behind this requirement. - if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: server's identity changed during renegotiation") - } - } - - keyAgreement := hs.suite.ka(c.vers) - - skx, ok := msg.(*serverKeyExchangeMsg) - if ok { - hs.finishedHash.Write(skx.marshal()) - err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) - if err != nil { - c.sendAlert(alertUnexpectedMessage) - return err - } - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - - var chainToSend *Certificate - var certRequested bool - certReq, ok := msg.(*certificateRequestMsg) - if ok { - certRequested = true - hs.finishedHash.Write(certReq.marshal()) - - cri := certificateRequestInfoFromMsg(c.vers, certReq) - if chainToSend, err = c.getClientCertificate(cri); err != nil { - c.sendAlert(alertInternalError) - return err - } - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - - shd, ok := msg.(*serverHelloDoneMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(shd, msg) - } - hs.finishedHash.Write(shd.marshal()) - - // If the server requested a certificate then we have to send a - // Certificate message, even if it's empty because we don't have a - // certificate to send. - if certRequested { - certMsg = new(certificateMsg) - certMsg.certificates = chainToSend.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { - return err - } - } - - preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - if ckx != nil { - hs.finishedHash.Write(ckx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { - return err - } - } - - if chainToSend != nil && len(chainToSend.Certificate) > 0 { - certVerify := &certificateVerifyMsg{} - - key, ok := chainToSend.PrivateKey.(crypto.Signer) - if !ok { - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) - } - - var sigType uint8 - var sigHash crypto.Hash - if c.vers >= VersionTLS12 { - signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - certVerify.hasSignatureAlgorithm = true - certVerify.signatureAlgorithm = signatureAlgorithm - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - } - - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - hs.finishedHash.Write(certVerify.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { - 
return err - } - } - - hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) - if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: failed to write to key log: " + err.Error()) - } - - hs.finishedHash.discardHandshakeBuffer() - - return nil -} - -func (hs *clientHandshakeState) establishKeys() error { - c := hs.c - - clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := - keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) - var clientCipher, serverCipher interface{} - var clientHash, serverHash hash.Hash - if hs.suite.cipher != nil { - clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) - clientHash = hs.suite.mac(clientMAC) - serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) - serverHash = hs.suite.mac(serverMAC) - } else { - clientCipher = hs.suite.aead(clientKey, clientIV) - serverCipher = hs.suite.aead(serverKey, serverIV) - } - - c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) - c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) - return nil -} - -func (hs *clientHandshakeState) serverResumedSession() bool { - // If the server responded with the same sessionId then it means the - // sessionTicket is being used to resume a TLS session. - return hs.session != nil && hs.hello.sessionId != nil && - bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) -} - -func (hs *clientHandshakeState) processServerHello() (bool, error) { - c := hs.c - - if err := hs.pickCipherSuite(); err != nil { - return false, err - } - - if hs.serverHello.compressionMethod != compressionNone { - c.sendAlert(alertUnexpectedMessage) - return false, errors.New("tls: server selected unsupported compression format") - } - - if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { - c.secureRenegotiation = true - if len(hs.serverHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: initial handshake had non-empty renegotiation extension") - } - } - - if c.handshakes > 0 && c.secureRenegotiation { - var expectedSecureRenegotiation [24]byte - copy(expectedSecureRenegotiation[:], c.clientFinished[:]) - copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) - if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: incorrect renegotiation extension contents") - } - } - - if hs.serverHello.alpnProtocol != "" { - if len(hs.hello.alpnProtocols) == 0 { - c.sendAlert(alertUnsupportedExtension) - return false, errors.New("tls: server advertised unrequested ALPN extension") - } - if mutualProtocol([]string{hs.serverHello.alpnProtocol}, hs.hello.alpnProtocols) == "" { - c.sendAlert(alertUnsupportedExtension) - return false, errors.New("tls: server selected unadvertised ALPN protocol") - } - c.clientProtocol = hs.serverHello.alpnProtocol - } - - c.scts = hs.serverHello.scts - - if !hs.serverResumedSession() { - return false, nil - } - - if hs.session.vers != c.vers { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server resumed a session with a different version") - } - - if hs.session.cipherSuite != hs.suite.id { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server resumed a 
session with a different cipher suite") - } - - // Restore masterSecret, peerCerts, and ocspResponse from previous state - hs.masterSecret = hs.session.masterSecret - c.peerCertificates = hs.session.serverCertificates - c.verifiedChains = hs.session.verifiedChains - c.ocspResponse = hs.session.ocspResponse - // Let the ServerHello SCTs override the session SCTs from the original - // connection, if any are provided - if len(c.scts) == 0 && len(hs.session.scts) != 0 { - c.scts = hs.session.scts - } - - return true, nil -} - -func (hs *clientHandshakeState) readFinished(out []byte) error { - c := hs.c - - if err := c.readChangeCipherSpec(); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - serverFinished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverFinished, msg) - } - - verify := hs.finishedHash.serverSum(hs.masterSecret) - if len(verify) != len(serverFinished.verifyData) || - subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server's Finished message was incorrect") - } - hs.finishedHash.Write(serverFinished.marshal()) - copy(out, verify) - return nil -} - -func (hs *clientHandshakeState) readSessionTicket() error { - if !hs.serverHello.ticketSupported { - return nil - } - - c := hs.c - msg, err := c.readHandshake() - if err != nil { - return err - } - sessionTicketMsg, ok := msg.(*newSessionTicketMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(sessionTicketMsg, msg) - } - hs.finishedHash.Write(sessionTicketMsg.marshal()) - - hs.session = &clientSessionState{ - sessionTicket: sessionTicketMsg.ticket, - vers: c.vers, - cipherSuite: hs.suite.id, - masterSecret: hs.masterSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - ocspResponse: c.ocspResponse, - scts: c.scts, - } - - return nil -} - -func (hs *clientHandshakeState) sendFinished(out []byte) error { - c := hs.c - - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { - return err - } - - finished := new(finishedMsg) - finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { - return err - } - copy(out, finished.verifyData) - return nil -} - -// verifyServerCertificate parses and verifies the provided chain, setting -// c.verifiedChains and c.peerCertificates or sending the appropriate alert. 
-func (c *Conn) verifyServerCertificate(certificates [][]byte) error { - certs := make([]*x509.Certificate, len(certificates)) - for i, asn1Data := range certificates { - cert, err := x509.ParseCertificate(asn1Data) - if err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to parse certificate from server: " + err.Error()) - } - certs[i] = cert - } - - if !c.config.InsecureSkipVerify { - opts := x509.VerifyOptions{ - Roots: c.config.RootCAs, - CurrentTime: c.config.time(), - DNSName: c.config.ServerName, - Intermediates: x509.NewCertPool(), - } - for _, cert := range certs[1:] { - opts.Intermediates.AddCert(cert) - } - var err error - c.verifiedChains, err = certs[0].Verify(opts) - if err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - switch certs[0].PublicKey.(type) { - case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: - break - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) - } - - c.peerCertificates = certs - - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - return nil -} - -// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS -// <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { - cri := &certificateRequestInfo{ - AcceptableCAs: certReq.certificateAuthorities, - Version: vers, - } - - var rsaAvail, ecAvail bool - for _, certType := range certReq.certificateTypes { - switch certType { - case certTypeRSASign: - rsaAvail = true - case certTypeECDSASign: - ecAvail = true - } - } - - if !certReq.hasSignatureAlgorithm { - // Prior to TLS 1.2, signature schemes did not exist. In this case we - // make up a list based on the acceptable certificate types, to help - // GetClientCertificate and SupportsCertificate select the right certificate. - // The hash part of the SignatureScheme is a lie here, because - // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. - switch { - case rsaAvail && ecAvail: - cri.SignatureSchemes = []SignatureScheme{ - ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, - PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, - } - case rsaAvail: - cri.SignatureSchemes = []SignatureScheme{ - PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, - } - case ecAvail: - cri.SignatureSchemes = []SignatureScheme{ - ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, - } - } - return toCertificateRequestInfo(cri) - } - - // Filter the signature schemes based on the certificate types. - // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). 
- cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) - for _, sigScheme := range certReq.supportedSignatureAlgorithms { - sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) - if err != nil { - continue - } - switch sigType { - case signatureECDSA, signatureEd25519: - if ecAvail { - cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) - } - case signatureRSAPSS, signaturePKCS1v15: - if rsaAvail { - cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) - } - } - } - - return toCertificateRequestInfo(cri) -} - -func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { - if c.config.GetClientCertificate != nil { - return c.config.GetClientCertificate(cri) - } - - for _, chain := range c.config.Certificates { - if err := cri.SupportsCertificate(&chain); err != nil { - continue - } - return &chain, nil - } - - // No acceptable certificate found. Don't send a certificate. - return new(Certificate), nil -} - -const clientSessionCacheKeyPrefix = "qtls-" - -// clientSessionCacheKey returns a key used to cache sessionTickets that could -// be used to resume previously negotiated TLS sessions with a server. -func clientSessionCacheKey(serverAddr net.Addr, config *config) string { - if len(config.ServerName) > 0 { - return clientSessionCacheKeyPrefix + config.ServerName - } - return clientSessionCacheKeyPrefix + serverAddr.String() -} - -// mutualProtocol finds the mutual ALPN protocol given list of possible -// protocols and a list of the preference order. -func mutualProtocol(protos, preferenceProtos []string) string { - for _, s := range preferenceProtos { - for _, c := range protos { - if s == c { - return s - } - } - } - return "" -} - -// hostnameInSNI converts name into an appropriate hostname for SNI. -// Literal IP addresses and absolute FQDNs are not permitted as SNI values. -// See RFC 6066, Section 3. -func hostnameInSNI(name string) string { - host := name - if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { - host = host[1 : len(host)-1] - } - if i := strings.LastIndex(host, "%"); i > 0 { - host = host[:i] - } - if net.ParseIP(host) != nil { - return "" - } - for len(name) > 0 && name[len(name)-1] == '.' { - name = name[:len(name)-1] - } - return name -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client_tls13.go deleted file mode 100644 index fb70cec0663..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_client_tls13.go +++ /dev/null @@ -1,740 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package qtls - -import ( - "bytes" - "crypto" - "crypto/hmac" - "crypto/rsa" - "encoding/binary" - "errors" - "fmt" - "hash" - "sync/atomic" - "time" - - "golang.org/x/crypto/cryptobyte" -) - -type clientHandshakeStateTLS13 struct { - c *Conn - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheParams ecdheParameters - - session *clientSessionState - earlySecret []byte - binderKey []byte - - certReq *certificateRequestMsgTLS13 - usingPSK bool - sentDummyCCS bool - suite *cipherSuiteTLS13 - transcript hash.Hash - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 -} - -// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, -// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. -func (hs *clientHandshakeStateTLS13) handshake() error { - c := hs.c - - // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, - // sections 4.1.2 and 4.1.3. - if c.handshakes > 0 { - c.sendAlert(alertProtocolVersion) - return errors.New("tls: server selected TLS 1.3 in a renegotiation") - } - - // Consistency check on the presence of a keyShare and its parameters. - if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { - return c.sendAlert(alertInternalError) - } - - if err := hs.checkServerHelloOrHRR(); err != nil { - return err - } - - hs.transcript = hs.suite.hash.New() - hs.transcript.Write(hs.hello.marshal()) - - if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - if err := hs.processHelloRetryRequest(); err != nil { - return err - } - } - - hs.transcript.Write(hs.serverHello.marshal()) - - c.buffering = true - if err := hs.processServerHello(); err != nil { - return err - } - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - if err := hs.establishHandshakeKeys(); err != nil { - return err - } - if err := hs.readServerParameters(); err != nil { - return err - } - if err := hs.readServerCertificate(); err != nil { - return err - } - if err := hs.readServerFinished(); err != nil { - return err - } - if err := hs.sendClientCertificate(); err != nil { - return err - } - if err := hs.sendClientFinished(); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - - atomic.StoreUint32(&c.handshakeStatus, 1) - - return nil -} - -// checkServerHelloOrHRR does validity checks that apply to both ServerHello and -// HelloRetryRequest messages. It sets hs.suite. 
-func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { - c := hs.c - - if hs.serverHello.supportedVersion == 0 { - c.sendAlert(alertMissingExtension) - return errors.New("tls: server selected TLS 1.3 using the legacy version field") - } - - if hs.serverHello.supportedVersion != VersionTLS13 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid version after a HelloRetryRequest") - } - - if hs.serverHello.vers != VersionTLS12 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an incorrect legacy version") - } - - if hs.serverHello.ocspStapling || - hs.serverHello.ticketSupported || - hs.serverHello.secureRenegotiationSupported || - len(hs.serverHello.secureRenegotiation) != 0 || - len(hs.serverHello.alpnProtocol) != 0 || - len(hs.serverHello.scts) != 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") - } - - if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server did not echo the legacy session ID") - } - - if hs.serverHello.compressionMethod != compressionNone { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported compression format") - } - - selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) - if hs.suite != nil && selectedSuite != hs.suite { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server changed cipher suite after a HelloRetryRequest") - } - if selectedSuite == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server chose an unconfigured cipher suite") - } - hs.suite = selectedSuite - c.cipherSuite = hs.suite.id - - return nil -} - -// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility -// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. -func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { - if hs.sentDummyCCS { - return nil - } - hs.sentDummyCCS = true - - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err -} - -// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and -// resends hs.hello, and reads the new ServerHello into hs.serverHello. -func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { - c := hs.c - - // The first ClientHello gets double-hashed into the transcript upon a - // HelloRetryRequest. (The idea is that the server might offload transcript - // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. - chHash := hs.transcript.Sum(nil) - hs.transcript.Reset() - hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - hs.transcript.Write(chHash) - hs.transcript.Write(hs.serverHello.marshal()) - - // The only HelloRetryRequest extensions we support are key_share and - // cookie, and clients must abort the handshake if the HRR would not result - // in any change in the ClientHello. 
- if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an unnecessary HelloRetryRequest message") - } - - if hs.serverHello.cookie != nil { - hs.hello.cookie = hs.serverHello.cookie - } - - if hs.serverHello.serverShare.group != 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: received malformed key_share extension") - } - - // If the server sent a key_share extension selecting a group, ensure it's - // a group we advertised but did not send a key share for, and send a key - // share for it this time. - if curveID := hs.serverHello.selectedGroup; curveID != 0 { - curveOK := false - for _, id := range hs.hello.supportedCurves { - if id == curveID { - curveOK = true - break - } - } - if !curveOK { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") - } - if hs.ecdheParams.CurveID() == curveID { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") - } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err := generateECDHEParameters(c.config.rand(), curveID) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - hs.ecdheParams = params - hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} - } - - hs.hello.raw = nil - if len(hs.hello.pskIdentities) > 0 { - pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) - if pskSuite == nil { - return c.sendAlert(alertInternalError) - } - if pskSuite.hash == hs.suite.hash { - // Update binders and obfuscated_ticket_age. - ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) - hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd - - transcript := hs.suite.hash.New() - transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - transcript.Write(chHash) - transcript.Write(hs.serverHello.marshal()) - transcript.Write(hs.hello.marshalWithoutBinders()) - pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} - hs.hello.updateBinders(pskBinders) - } else { - // Server selected a cipher suite incompatible with the PSK. 
- hs.hello.pskIdentities = nil - hs.hello.pskBinders = nil - } - } - - if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { - c.extraConfig.Rejected0RTT() - } - hs.hello.earlyData = false // disable 0-RTT - - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - serverHello, ok := msg.(*serverHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - hs.serverHello = serverHello - - if err := hs.checkServerHelloOrHRR(); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) processServerHello() error { - c := hs.c - - if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: server sent two HelloRetryRequest messages") - } - - if len(hs.serverHello.cookie) != 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server sent a cookie in a normal ServerHello") - } - - if hs.serverHello.selectedGroup != 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: malformed key_share extension") - } - - if hs.serverHello.serverShare.group == 0 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server did not send a key share") - } - if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") - } - - if !hs.serverHello.selectedIdentityPresent { - return nil - } - - if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid PSK") - } - - if len(hs.hello.pskIdentities) != 1 || hs.session == nil { - return c.sendAlert(alertInternalError) - } - pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) - if pskSuite == nil { - return c.sendAlert(alertInternalError) - } - if pskSuite.hash != hs.suite.hash { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid PSK and cipher suite pair") - } - - hs.usingPSK = true - c.didResume = true - c.peerCertificates = hs.session.serverCertificates - c.verifiedChains = hs.session.verifiedChains - c.ocspResponse = hs.session.ocspResponse - c.scts = hs.session.scts - return nil -} - -func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { - c := hs.c - - sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) - if sharedKey == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") - } - - earlySecret := hs.earlySecret - if !hs.usingPSK { - earlySecret = hs.suite.extract(nil, nil) - } - handshakeSecret := hs.suite.extract(sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) - - clientSecret := hs.suite.deriveSecret(handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) - c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret) - c.out.setTrafficSecret(hs.suite, clientSecret) - serverSecret := hs.suite.deriveSecret(handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) - c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret) - c.in.setTrafficSecret(hs.suite, serverSecret) - - err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = 
c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(handshakeSecret, "derived", nil)) - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerParameters() error { - c := hs.c - - msg, err := c.readHandshake() - if err != nil { - return err - } - - encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(encryptedExtensions, msg) - } - // Notify the caller if 0-RTT was rejected. - if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { - c.extraConfig.Rejected0RTT() - } - c.used0RTT = encryptedExtensions.earlyData - if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil { - hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions) - } - hs.transcript.Write(encryptedExtensions.marshal()) - - if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection { - if len(encryptedExtensions.alpnProtocol) == 0 { - // the server didn't select an ALPN - c.sendAlert(alertNoApplicationProtocol) - return errors.New("ALPN negotiation failed. Server didn't offer any protocols") - } - if mutualProtocol([]string{encryptedExtensions.alpnProtocol}, hs.c.config.NextProtos) == "" { - // the protocol selected by the server was not offered - c.sendAlert(alertNoApplicationProtocol) - return fmt.Errorf("ALPN negotiation failed. Server offered: %q", encryptedExtensions.alpnProtocol) - } - } - if encryptedExtensions.alpnProtocol != "" { - if len(hs.hello.alpnProtocols) == 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server advertised unrequested ALPN extension") - } - if mutualProtocol([]string{encryptedExtensions.alpnProtocol}, hs.hello.alpnProtocols) == "" { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server selected unadvertised ALPN protocol") - } - c.clientProtocol = encryptedExtensions.alpnProtocol - } - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerCertificate() error { - c := hs.c - - // Either a PSK or a certificate is always used, but not both. - // See RFC 8446, Section 4.1.1. - if hs.usingPSK { - // Make sure the connection is still being verified whether or not this - // is a resumption. Resumptions currently don't reverify certificates so - // they don't call verifyServerCertificate. See Issue 31641. 
- if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - return nil - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - certReq, ok := msg.(*certificateRequestMsgTLS13) - if ok { - hs.transcript.Write(certReq.marshal()) - - hs.certReq = certReq - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - - certMsg, ok := msg.(*certificateMsgTLS13) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - if len(certMsg.certificate.Certificate) == 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: received empty certificates message") - } - hs.transcript.Write(certMsg.marshal()) - - c.scts = certMsg.certificate.SignedCertificateTimestamps - c.ocspResponse = certMsg.certificate.OCSPStaple - - if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { - return err - } - - msg, err = c.readHandshake() - if err != nil { - return err - } - - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: certificate used with invalid signature algorithm") - } - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: certificate used with invalid signature algorithm") - } - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) - if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, - sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the server certificate: " + err.Error()) - } - - hs.transcript.Write(certVerify.marshal()) - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerFinished() error { - c := hs.c - - msg, err := c.readHandshake() - if err != nil { - return err - } - - finished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(finished, msg) - } - - expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) - if !hmac.Equal(expectedMAC, finished.verifyData) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid server finished hash") - } - - hs.transcript.Write(finished.marshal()) - - // Derive secrets that take context through the server Finished. 
- - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) - c.in.exportKey(EncryptionApplication, hs.suite, serverSecret) - c.in.setTrafficSecret(hs.suite, serverSecret) - - err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) - - return nil -} - -func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { - c := hs.c - - if hs.certReq == nil { - return nil - } - - cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{ - AcceptableCAs: hs.certReq.certificateAuthorities, - SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, - Version: c.vers, - })) - if err != nil { - return err - } - - certMsg := new(certificateMsgTLS13) - - certMsg.certificate = *cert - certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 - certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 - - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { - return err - } - - // If we sent an empty certificate message, skip the CertificateVerify. - if len(cert.Certificate) == 0 { - return nil - } - - certVerifyMsg := new(certificateVerifyMsg) - certVerifyMsg.hasSignatureAlgorithm = true - - certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) - if err != nil { - // getClientCertificate returned a certificate incompatible with the - // CertificateRequestInfo supported signature algorithms. 
- c.sendAlert(alertHandshakeFailure) - return err - } - - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) - if err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: failed to sign handshake: " + err.Error()) - } - certVerifyMsg.signature = sig - - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) sendClientFinished() error { - c := hs.c - - finished := &finishedMsg{ - verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), - } - - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { - return err - } - - c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) - c.out.setTrafficSecret(hs.suite, hs.trafficSecret) - - if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { - c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) - } - - return nil -} - -func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { - if !c.isClient { - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: received new session ticket from a client") - } - - if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return nil - } - - // See RFC 8446, Section 4.6.1. - if msg.lifetime == 0 { - return nil - } - lifetime := time.Duration(msg.lifetime) * time.Second - if lifetime > maxSessionTicketLifetime { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: received a session ticket with invalid lifetime") - } - - cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) - if cipherSuite == nil || c.resumptionSecret == nil { - return c.sendAlert(alertInternalError) - } - - // We need to save the max_early_data_size that the server sent us, in order - // to decide if we're going to try 0-RTT with this ticket. - // However, at the same time, the qtls.ClientSessionTicket needs to be equal to - // the tls.ClientSessionTicket, so we can't just add a new field to the struct. - // We therefore abuse the nonce field (which is a byte slice) - nonceWithEarlyData := make([]byte, len(msg.nonce)+4) - binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData) - copy(nonceWithEarlyData[4:], msg.nonce) - - var appData []byte - if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil { - appData = c.extraConfig.GetAppDataForSessionState() - } - var b cryptobyte.Builder - b.AddUint16(clientSessionStateVersion) // revision - b.AddUint32(msg.maxEarlyData) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(appData) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(msg.nonce) - }) - - // Save the resumption_master_secret and nonce instead of deriving the PSK - // to do the least amount of work on NewSessionTicket messages before we - // know if the ticket will be used. Forward secrecy of resumed connections - // is guaranteed by the requirement for pskModeDHE. 
- session := &clientSessionState{ - sessionTicket: msg.label, - vers: c.vers, - cipherSuite: c.cipherSuite, - masterSecret: c.resumptionSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - nonce: b.BytesOrPanic(), - useBy: c.config.time().Add(lifetime), - ageAdd: msg.ageAdd, - ocspResponse: c.ocspResponse, - scts: c.scts, - } - - cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) - c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session)) - - return nil -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_messages.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_messages.go deleted file mode 100644 index 1ab75762637..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_messages.go +++ /dev/null @@ -1,1832 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "fmt" - "strings" - - "golang.org/x/crypto/cryptobyte" -) - -// The marshalingFunction type is an adapter to allow the use of ordinary -// functions as cryptobyte.MarshalingValue. -type marshalingFunction func(b *cryptobyte.Builder) error - -func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { - return f(b) -} - -// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If -// the length of the sequence is not the value specified, it produces an error. -func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { - b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { - if len(v) != n { - return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) - } - b.AddBytes(v) - return nil - })) -} - -// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. -func addUint64(b *cryptobyte.Builder, v uint64) { - b.AddUint32(uint32(v >> 32)) - b.AddUint32(uint32(v)) -} - -// readUint64 decodes a big-endian, 64-bit value into out and advances over it. -// It reports whether the read was successful. -func readUint64(s *cryptobyte.String, out *uint64) bool { - var hi, lo uint32 - if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { - return false - } - *out = uint64(hi)<<32 | uint64(lo) - return true -} - -// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) -} - -// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) -} - -// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. 
-func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) -} - -type clientHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuites []uint16 - compressionMethods []uint8 - serverName string - ocspStapling bool - supportedCurves []CurveID - supportedPoints []uint8 - ticketSupported bool - sessionTicket []uint8 - supportedSignatureAlgorithms []SignatureScheme - supportedSignatureAlgorithmsCert []SignatureScheme - secureRenegotiationSupported bool - secureRenegotiation []byte - alpnProtocols []string - scts bool - supportedVersions []uint16 - cookie []byte - keyShares []keyShare - earlyData bool - pskModes []uint8 - pskIdentities []pskIdentity - pskBinders [][]byte - additionalExtensions []Extension -} - -func (m *clientHelloMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeClientHello) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.vers) - addBytesWithLength(b, m.random, 32) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionId) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, suite := range m.cipherSuites { - b.AddUint16(suite) - } - }) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.compressionMethods) - }) - - // If extensions aren't present, omit them. - var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.serverName) > 0 { - // RFC 6066, Section 3 - b.AddUint16(extensionServerName) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // name_type = host_name - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.serverName)) - }) - }) - }) - } - if m.ocspStapling { - // RFC 4366, Section 3.6 - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(1) // status_type = ocsp - b.AddUint16(0) // empty responder_id_list - b.AddUint16(0) // empty request_extensions - }) - } - if len(m.supportedCurves) > 0 { - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - b.AddUint16(extensionSupportedCurves) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, curve := range m.supportedCurves { - b.AddUint16(uint16(curve)) - } - }) - }) - } - if len(m.supportedPoints) > 0 { - // RFC 4492, Section 5.1.2 - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - if m.ticketSupported { - // RFC 5077, Section 3.2 - b.AddUint16(extensionSessionTicket) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionTicket) - }) - } - if len(m.supportedSignatureAlgorithms) > 0 { - // RFC 5246, Section 7.4.1.4.1 - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - // RFC 8446, Section 4.2.3 - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - 
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if m.secureRenegotiationSupported { - // RFC 5746, Section 3.2 - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocols) > 0 { - // RFC 7301, Section 3.1 - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, proto := range m.alpnProtocols { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(proto)) - }) - } - }) - }) - } - if m.scts { - // RFC 6962, Section 3.3.1 - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedVersions) > 0 { - // RFC 8446, Section 4.2.1 - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - for _, vers := range m.supportedVersions { - b.AddUint16(vers) - } - }) - }) - } - if len(m.cookie) > 0 { - // RFC 8446, Section 4.2.2 - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if len(m.keyShares) > 0 { - // RFC 8446, Section 4.2.8 - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ks := range m.keyShares { - b.AddUint16(uint16(ks.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ks.data) - }) - } - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - b.AddUint16(extensionEarlyData) - b.AddUint16(0) // empty extension_data - } - if len(m.pskModes) > 0 { - // RFC 8446, Section 4.2.9 - b.AddUint16(extensionPSKModes) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.pskModes) - }) - }) - } - for _, ext := range m.additionalExtensions { - b.AddUint16(ext.Type) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ext.Data) - }) - } - if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension - // RFC 8446, Section 4.2.11 - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, psk := range m.pskIdentities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(psk.label) - }) - b.AddUint32(psk.obfuscatedTicketAge) - } - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions - } - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -// marshalWithoutBinders returns the ClientHello through the -// PreSharedKeyExtension.identities field, according to RFC 8446, Section -// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. 
-func (m *clientHelloMsg) marshalWithoutBinders() []byte { - bindersLen := 2 // uint16 length prefix - for _, binder := range m.pskBinders { - bindersLen += 1 // uint8 length prefix - bindersLen += len(binder) - } - - fullMessage := m.marshal() - return fullMessage[:len(fullMessage)-bindersLen] -} - -// updateBinders updates the m.pskBinders field, if necessary updating the -// cached marshaled representation. The supplied binders must have the same -// length as the current m.pskBinders. -func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { - if len(pskBinders) != len(m.pskBinders) { - panic("tls: internal error: pskBinders length mismatch") - } - for i := range m.pskBinders { - if len(pskBinders[i]) != len(m.pskBinders[i]) { - panic("tls: internal error: pskBinders length mismatch") - } - } - m.pskBinders = pskBinders - if m.raw != nil { - lenWithoutBinders := len(m.marshalWithoutBinders()) - // TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported. - b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders]) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - if len(b.BytesOrPanic()) != len(m.raw) { - panic("tls: internal error: failed to update binders") - } - } -} - -func (m *clientHelloMsg) unmarshal(data []byte) bool { - *m = clientHelloMsg{raw: data} - s := cryptobyte.String(data) - - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || - !readUint8LengthPrefixed(&s, &m.sessionId) { - return false - } - - var cipherSuites cryptobyte.String - if !s.ReadUint16LengthPrefixed(&cipherSuites) { - return false - } - m.cipherSuites = []uint16{} - m.secureRenegotiationSupported = false - for !cipherSuites.Empty() { - var suite uint16 - if !cipherSuites.ReadUint16(&suite) { - return false - } - if suite == scsvRenegotiation { - m.secureRenegotiationSupported = true - } - m.cipherSuites = append(m.cipherSuites, suite) - } - - if !readUint8LengthPrefixed(&s, &m.compressionMethods) { - return false - } - - if s.Empty() { - // ClientHello is optionally followed by extension data - return true - } - - var extensions cryptobyte.String - if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - for !extensions.Empty() { - var ext uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&ext) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch ext { - case extensionServerName: - // RFC 6066, Section 3 - var nameList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { - return false - } - for !nameList.Empty() { - var nameType uint8 - var serverName cryptobyte.String - if !nameList.ReadUint8(&nameType) || - !nameList.ReadUint16LengthPrefixed(&serverName) || - serverName.Empty() { - return false - } - if nameType != 0 { - continue - } - if len(m.serverName) != 0 { - // Multiple names of the same name_type are prohibited. - return false - } - m.serverName = string(serverName) - // An SNI value may not include a trailing dot. 
- if strings.HasSuffix(m.serverName, ".") { - return false - } - } - case extensionStatusRequest: - // RFC 4366, Section 3.6 - var statusType uint8 - var ignored cryptobyte.String - if !extData.ReadUint8(&statusType) || - !extData.ReadUint16LengthPrefixed(&ignored) || - !extData.ReadUint16LengthPrefixed(&ignored) { - return false - } - m.ocspStapling = statusType == statusTypeOCSP - case extensionSupportedCurves: - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - var curves cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { - return false - } - for !curves.Empty() { - var curve uint16 - if !curves.ReadUint16(&curve) { - return false - } - m.supportedCurves = append(m.supportedCurves, CurveID(curve)) - } - case extensionSupportedPoints: - // RFC 4492, Section 5.1.2 - if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || - len(m.supportedPoints) == 0 { - return false - } - case extensionSessionTicket: - // RFC 5077, Section 3.2 - m.ticketSupported = true - extData.ReadBytes(&m.sessionTicket, len(extData)) - case extensionSignatureAlgorithms: - // RFC 5246, Section 7.4.1.4.1 - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithms = append( - m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) - } - case extensionSignatureAlgorithmsCert: - // RFC 8446, Section 4.2.3 - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithmsCert = append( - m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) - } - case extensionRenegotiationInfo: - // RFC 5746, Section 3.2 - if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { - return false - } - m.secureRenegotiationSupported = true - case extensionALPN: - // RFC 7301, Section 3.1 - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - for !protoList.Empty() { - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { - return false - } - m.alpnProtocols = append(m.alpnProtocols, string(proto)) - } - case extensionSCT: - // RFC 6962, Section 3.3.1 - m.scts = true - case extensionSupportedVersions: - // RFC 8446, Section 4.2.1 - var versList cryptobyte.String - if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { - return false - } - for !versList.Empty() { - var vers uint16 - if !versList.ReadUint16(&vers) { - return false - } - m.supportedVersions = append(m.supportedVersions, vers) - } - case extensionCookie: - // RFC 8446, Section 4.2.2 - if !readUint16LengthPrefixed(&extData, &m.cookie) || - len(m.cookie) == 0 { - return false - } - case extensionKeyShare: - // RFC 8446, Section 4.2.8 - var clientShares cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&clientShares) { - return false - } - for !clientShares.Empty() { - var ks keyShare - if !clientShares.ReadUint16((*uint16)(&ks.group)) || - !readUint16LengthPrefixed(&clientShares, &ks.data) || - len(ks.data) == 0 { - return false - } - m.keyShares = append(m.keyShares, ks) - } - case extensionEarlyData: - // RFC 8446, Section 4.2.10 - m.earlyData = true - case 
extensionPSKModes: - // RFC 8446, Section 4.2.9 - if !readUint8LengthPrefixed(&extData, &m.pskModes) { - return false - } - case extensionPreSharedKey: - // RFC 8446, Section 4.2.11 - if !extensions.Empty() { - return false // pre_shared_key must be the last extension - } - var identities cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { - return false - } - for !identities.Empty() { - var psk pskIdentity - if !readUint16LengthPrefixed(&identities, &psk.label) || - !identities.ReadUint32(&psk.obfuscatedTicketAge) || - len(psk.label) == 0 { - return false - } - m.pskIdentities = append(m.pskIdentities, psk) - } - var binders cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { - return false - } - for !binders.Empty() { - var binder []byte - if !readUint8LengthPrefixed(&binders, &binder) || - len(binder) == 0 { - return false - } - m.pskBinders = append(m.pskBinders, binder) - } - default: - m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type serverHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuite uint16 - compressionMethod uint8 - ocspStapling bool - ticketSupported bool - secureRenegotiationSupported bool - secureRenegotiation []byte - alpnProtocol string - scts [][]byte - supportedVersion uint16 - serverShare keyShare - selectedIdentityPresent bool - selectedIdentity uint16 - supportedPoints []uint8 - - // HelloRetryRequest extensions - cookie []byte - selectedGroup CurveID -} - -func (m *serverHelloMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeServerHello) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.vers) - addBytesWithLength(b, m.random, 32) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionId) - }) - b.AddUint16(m.cipherSuite) - b.AddUint8(m.compressionMethod) - - // If extensions aren't present, omit them. 
- var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.ticketSupported { - b.AddUint16(extensionSessionTicket) - b.AddUint16(0) // empty extension_data - } - if m.secureRenegotiationSupported { - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if len(m.scts) > 0 { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range m.scts { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - if m.supportedVersion != 0 { - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.supportedVersion) - }) - } - if m.serverShare.group != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.serverShare.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.serverShare.data) - }) - }) - } - if m.selectedIdentityPresent { - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.selectedIdentity) - }) - } - - if len(m.cookie) > 0 { - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if m.selectedGroup != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.selectedGroup)) - }) - } - if len(m.supportedPoints) > 0 { - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions - } - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *serverHelloMsg) unmarshal(data []byte) bool { - *m = serverHelloMsg{raw: data} - s := cryptobyte.String(data) - - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || - !readUint8LengthPrefixed(&s, &m.sessionId) || - !s.ReadUint16(&m.cipherSuite) || - !s.ReadUint8(&m.compressionMethod) { - return false - } - - if s.Empty() { - // ServerHello is optionally followed by extension data - return true - } - - var extensions cryptobyte.String - if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionStatusRequest: - m.ocspStapling = true - case extensionSessionTicket: - m.ticketSupported = true - case extensionRenegotiationInfo: - if 
!readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { - return false - } - m.secureRenegotiationSupported = true - case extensionALPN: - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || - proto.Empty() || !protoList.Empty() { - return false - } - m.alpnProtocol = string(proto) - case extensionSCT: - var sctList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { - return false - } - for !sctList.Empty() { - var sct []byte - if !readUint16LengthPrefixed(&sctList, &sct) || - len(sct) == 0 { - return false - } - m.scts = append(m.scts, sct) - } - case extensionSupportedVersions: - if !extData.ReadUint16(&m.supportedVersion) { - return false - } - case extensionCookie: - if !readUint16LengthPrefixed(&extData, &m.cookie) || - len(m.cookie) == 0 { - return false - } - case extensionKeyShare: - // This extension has different formats in SH and HRR, accept either - // and let the handshake logic decide. See RFC 8446, Section 4.2.8. - if len(extData) == 2 { - if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { - return false - } - } else { - if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || - !readUint16LengthPrefixed(&extData, &m.serverShare.data) { - return false - } - } - case extensionPreSharedKey: - m.selectedIdentityPresent = true - if !extData.ReadUint16(&m.selectedIdentity) { - return false - } - case extensionSupportedPoints: - // RFC 4492, Section 5.1.2 - if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || - len(m.supportedPoints) == 0 { - return false - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type encryptedExtensionsMsg struct { - raw []byte - alpnProtocol string - earlyData bool - - additionalExtensions []Extension -} - -func (m *encryptedExtensionsMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeEncryptedExtensions) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - b.AddUint16(extensionEarlyData) - b.AddUint16(0) // empty extension_data - } - for _, ext := range m.additionalExtensions { - b.AddUint16(ext.Type) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ext.Data) - }) - } - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { - *m = encryptedExtensionsMsg{raw: data} - s := cryptobyte.String(data) - - var extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - for !extensions.Empty() { - var ext uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&ext) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch ext { - case extensionALPN: - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - var proto 
cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || - proto.Empty() || !protoList.Empty() { - return false - } - m.alpnProtocol = string(proto) - case extensionEarlyData: - m.earlyData = true - default: - m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type endOfEarlyDataMsg struct{} - -func (m *endOfEarlyDataMsg) marshal() []byte { - x := make([]byte, 4) - x[0] = typeEndOfEarlyData - return x -} - -func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type keyUpdateMsg struct { - raw []byte - updateRequested bool -} - -func (m *keyUpdateMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeKeyUpdate) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - if m.updateRequested { - b.AddUint8(1) - } else { - b.AddUint8(0) - } - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *keyUpdateMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - var updateRequested uint8 - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8(&updateRequested) || !s.Empty() { - return false - } - switch updateRequested { - case 0: - m.updateRequested = false - case 1: - m.updateRequested = true - default: - return false - } - return true -} - -type newSessionTicketMsgTLS13 struct { - raw []byte - lifetime uint32 - ageAdd uint32 - nonce []byte - label []byte - maxEarlyData uint32 -} - -func (m *newSessionTicketMsgTLS13) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeNewSessionTicket) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint32(m.lifetime) - b.AddUint32(m.ageAdd) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.nonce) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.label) - }) - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.maxEarlyData > 0 { - b.AddUint16(extensionEarlyData) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint32(m.maxEarlyData) - }) - } - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { - *m = newSessionTicketMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint32(&m.lifetime) || - !s.ReadUint32(&m.ageAdd) || - !readUint8LengthPrefixed(&s, &m.nonce) || - !readUint16LengthPrefixed(&s, &m.label) || - !s.ReadUint16LengthPrefixed(&extensions) || - !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionEarlyData: - if !extData.ReadUint32(&m.maxEarlyData) { - return false - } - default: - // Ignore unknown extensions. 
- continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type certificateRequestMsgTLS13 struct { - raw []byte - ocspStapling bool - scts bool - supportedSignatureAlgorithms []SignatureScheme - supportedSignatureAlgorithmsCert []SignatureScheme - certificateAuthorities [][]byte -} - -func (m *certificateRequestMsgTLS13) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateRequest) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - // certificate_request_context (SHALL be zero length unless used for - // post-handshake authentication) - b.AddUint8(0) - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.scts { - // RFC 8446, Section 4.4.2.1 makes no mention of - // signed_certificate_timestamp in CertificateRequest, but - // "Extensions in the Certificate message from the client MUST - // correspond to extensions in the CertificateRequest message - // from the server." and it appears in the table in Section 4.2. - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedSignatureAlgorithms) > 0 { - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.certificateAuthorities) > 0 { - b.AddUint16(extensionCertificateAuthorities) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ca := range m.certificateAuthorities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ca) - }) - } - }) - }) - } - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { - *m = certificateRequestMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var context, extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || - !s.ReadUint16LengthPrefixed(&extensions) || - !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionStatusRequest: - m.ocspStapling = true - case extensionSCT: - m.scts = true - case extensionSignatureAlgorithms: - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithms = append( - m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) - } - case extensionSignatureAlgorithmsCert: - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for 
!sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithmsCert = append( - m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) - } - case extensionCertificateAuthorities: - var auths cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { - return false - } - for !auths.Empty() { - var ca []byte - if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { - return false - } - m.certificateAuthorities = append(m.certificateAuthorities, ca) - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type certificateMsg struct { - raw []byte - certificates [][]byte -} - -func (m *certificateMsg) marshal() (x []byte) { - if m.raw != nil { - return m.raw - } - - var i int - for _, slice := range m.certificates { - i += len(slice) - } - - length := 3 + 3*len(m.certificates) + i - x = make([]byte, 4+length) - x[0] = typeCertificate - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - - certificateOctets := length - 3 - x[4] = uint8(certificateOctets >> 16) - x[5] = uint8(certificateOctets >> 8) - x[6] = uint8(certificateOctets) - - y := x[7:] - for _, slice := range m.certificates { - y[0] = uint8(len(slice) >> 16) - y[1] = uint8(len(slice) >> 8) - y[2] = uint8(len(slice)) - copy(y[3:], slice) - y = y[3+len(slice):] - } - - m.raw = x - return -} - -func (m *certificateMsg) unmarshal(data []byte) bool { - if len(data) < 7 { - return false - } - - m.raw = data - certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) - if uint32(len(data)) != certsLen+7 { - return false - } - - numCerts := 0 - d := data[7:] - for certsLen > 0 { - if len(d) < 4 { - return false - } - certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) - if uint32(len(d)) < 3+certLen { - return false - } - d = d[3+certLen:] - certsLen -= 3 + certLen - numCerts++ - } - - m.certificates = make([][]byte, numCerts) - d = data[7:] - for i := 0; i < numCerts; i++ { - certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) - m.certificates[i] = d[3 : 3+certLen] - d = d[3+certLen:] - } - - return true -} - -type certificateMsgTLS13 struct { - raw []byte - certificate Certificate - ocspStapling bool - scts bool -} - -func (m *certificateMsgTLS13) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificate) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // certificate_request_context - - certificate := m.certificate - if !m.ocspStapling { - certificate.OCSPStaple = nil - } - if !m.scts { - certificate.SignedCertificateTimestamps = nil - } - marshalCertificate(b, certificate) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - for i, cert := range certificate.Certificate { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(cert) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if i > 0 { - // This library only supports OCSP and SCT for leaf certificates. 
- return - } - if certificate.OCSPStaple != nil { - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(statusTypeOCSP) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(certificate.OCSPStaple) - }) - }) - } - if certificate.SignedCertificateTimestamps != nil { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range certificate.SignedCertificateTimestamps { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - }) - } - }) -} - -func (m *certificateMsgTLS13) unmarshal(data []byte) bool { - *m = certificateMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var context cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || - !unmarshalCertificate(&s, &m.certificate) || - !s.Empty() { - return false - } - - m.scts = m.certificate.SignedCertificateTimestamps != nil - m.ocspStapling = m.certificate.OCSPStaple != nil - - return true -} - -func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { - var certList cryptobyte.String - if !s.ReadUint24LengthPrefixed(&certList) { - return false - } - for !certList.Empty() { - var cert []byte - var extensions cryptobyte.String - if !readUint24LengthPrefixed(&certList, &cert) || - !certList.ReadUint16LengthPrefixed(&extensions) { - return false - } - certificate.Certificate = append(certificate.Certificate, cert) - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - if len(certificate.Certificate) > 1 { - // This library only supports OCSP and SCT for leaf certificates. - continue - } - - switch extension { - case extensionStatusRequest: - var statusType uint8 - if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || - len(certificate.OCSPStaple) == 0 { - return false - } - case extensionSCT: - var sctList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { - return false - } - for !sctList.Empty() { - var sct []byte - if !readUint16LengthPrefixed(&sctList, &sct) || - len(sct) == 0 { - return false - } - certificate.SignedCertificateTimestamps = append( - certificate.SignedCertificateTimestamps, sct) - } - default: - // Ignore unknown extensions. 
- continue - } - - if !extData.Empty() { - return false - } - } - } - return true -} - -type serverKeyExchangeMsg struct { - raw []byte - key []byte -} - -func (m *serverKeyExchangeMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - length := len(m.key) - x := make([]byte, length+4) - x[0] = typeServerKeyExchange - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - copy(x[4:], m.key) - - m.raw = x - return x -} - -func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 4 { - return false - } - m.key = data[4:] - return true -} - -type certificateStatusMsg struct { - raw []byte - response []byte -} - -func (m *certificateStatusMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateStatus) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(statusTypeOCSP) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.response) - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *certificateStatusMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - var statusType uint8 - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !readUint24LengthPrefixed(&s, &m.response) || - len(m.response) == 0 || !s.Empty() { - return false - } - return true -} - -type serverHelloDoneMsg struct{} - -func (m *serverHelloDoneMsg) marshal() []byte { - x := make([]byte, 4) - x[0] = typeServerHelloDone - return x -} - -func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type clientKeyExchangeMsg struct { - raw []byte - ciphertext []byte -} - -func (m *clientKeyExchangeMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - length := len(m.ciphertext) - x := make([]byte, length+4) - x[0] = typeClientKeyExchange - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - copy(x[4:], m.ciphertext) - - m.raw = x - return x -} - -func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 4 { - return false - } - l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - if l != len(data)-4 { - return false - } - m.ciphertext = data[4:] - return true -} - -type finishedMsg struct { - raw []byte - verifyData []byte -} - -func (m *finishedMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeFinished) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.verifyData) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *finishedMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - return s.Skip(1) && - readUint24LengthPrefixed(&s, &m.verifyData) && - s.Empty() -} - -type certificateRequestMsg struct { - raw []byte - // hasSignatureAlgorithm indicates whether this message includes a list of - // supported signature algorithms. This change was introduced with TLS 1.2. - hasSignatureAlgorithm bool - - certificateTypes []byte - supportedSignatureAlgorithms []SignatureScheme - certificateAuthorities [][]byte -} - -func (m *certificateRequestMsg) marshal() (x []byte) { - if m.raw != nil { - return m.raw - } - - // See RFC 4346, Section 7.4.4. 
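
// A minimal standalone sketch of the codec pattern the marshal/unmarshal
// functions above follow: golang.org/x/crypto/cryptobyte's Builder writes
// TLS length-prefixed structures and String consumes them, reporting any
// malformed input as a bool. (Some of the older messages nearby pack the
// 3-byte handshake length by hand instead.) The toy message type and field
// names below are assumptions for illustration, not the vendored sources.
package main

import (
	"fmt"

	"golang.org/x/crypto/cryptobyte"
)

type toyMsg struct {
	ids []uint16
}

func (m *toyMsg) marshal() []byte {
	var b cryptobyte.Builder
	b.AddUint8(0x42) // toy message type
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		// A uint16 length-prefixed vector of uint16 values, as in the
		// cipher suite and signature algorithm lists above.
		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			for _, id := range m.ids {
				b.AddUint16(id)
			}
		})
	})
	return b.BytesOrPanic()
}

func (m *toyMsg) unmarshal(data []byte) bool {
	s := cryptobyte.String(data)
	var body, list cryptobyte.String
	if !s.Skip(1) || // message type
		!s.ReadUint24LengthPrefixed(&body) || !s.Empty() ||
		!body.ReadUint16LengthPrefixed(&list) || !body.Empty() {
		return false
	}
	for !list.Empty() {
		var id uint16
		if !list.ReadUint16(&id) {
			return false
		}
		m.ids = append(m.ids, id)
	}
	return true
}

func main() {
	in := toyMsg{ids: []uint16{0x1301, 0x1302}}
	var out toyMsg
	fmt.Println(out.unmarshal(in.marshal()), out.ids) // true [4865 4866]
}
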
- length := 1 + len(m.certificateTypes) + 2 - casLength := 0 - for _, ca := range m.certificateAuthorities { - casLength += 2 + len(ca) - } - length += casLength - - if m.hasSignatureAlgorithm { - length += 2 + 2*len(m.supportedSignatureAlgorithms) - } - - x = make([]byte, 4+length) - x[0] = typeCertificateRequest - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - - x[4] = uint8(len(m.certificateTypes)) - - copy(x[5:], m.certificateTypes) - y := x[5+len(m.certificateTypes):] - - if m.hasSignatureAlgorithm { - n := len(m.supportedSignatureAlgorithms) * 2 - y[0] = uint8(n >> 8) - y[1] = uint8(n) - y = y[2:] - for _, sigAlgo := range m.supportedSignatureAlgorithms { - y[0] = uint8(sigAlgo >> 8) - y[1] = uint8(sigAlgo) - y = y[2:] - } - } - - y[0] = uint8(casLength >> 8) - y[1] = uint8(casLength) - y = y[2:] - for _, ca := range m.certificateAuthorities { - y[0] = uint8(len(ca) >> 8) - y[1] = uint8(len(ca)) - y = y[2:] - copy(y, ca) - y = y[len(ca):] - } - - m.raw = x - return -} - -func (m *certificateRequestMsg) unmarshal(data []byte) bool { - m.raw = data - - if len(data) < 5 { - return false - } - - length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) - if uint32(len(data))-4 != length { - return false - } - - numCertTypes := int(data[4]) - data = data[5:] - if numCertTypes == 0 || len(data) <= numCertTypes { - return false - } - - m.certificateTypes = make([]byte, numCertTypes) - if copy(m.certificateTypes, data) != numCertTypes { - return false - } - - data = data[numCertTypes:] - - if m.hasSignatureAlgorithm { - if len(data) < 2 { - return false - } - sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - if sigAndHashLen&1 != 0 { - return false - } - if len(data) < int(sigAndHashLen) { - return false - } - numSigAlgos := sigAndHashLen / 2 - m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) - for i := range m.supportedSignatureAlgorithms { - m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) - data = data[2:] - } - } - - if len(data) < 2 { - return false - } - casLength := uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - if len(data) < int(casLength) { - return false - } - cas := make([]byte, casLength) - copy(cas, data) - data = data[casLength:] - - m.certificateAuthorities = nil - for len(cas) > 0 { - if len(cas) < 2 { - return false - } - caLen := uint16(cas[0])<<8 | uint16(cas[1]) - cas = cas[2:] - - if len(cas) < int(caLen) { - return false - } - - m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) - cas = cas[caLen:] - } - - return len(data) == 0 -} - -type certificateVerifyMsg struct { - raw []byte - hasSignatureAlgorithm bool // format change introduced in TLS 1.2 - signatureAlgorithm SignatureScheme - signature []byte -} - -func (m *certificateVerifyMsg) marshal() (x []byte) { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateVerify) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - if m.hasSignatureAlgorithm { - b.AddUint16(uint16(m.signatureAlgorithm)) - } - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.signature) - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *certificateVerifyMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - if !s.Skip(4) { // message type and uint24 length field - return false - } - if m.hasSignatureAlgorithm { - if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { - 
return false - } - } - return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() -} - -type newSessionTicketMsg struct { - raw []byte - ticket []byte -} - -func (m *newSessionTicketMsg) marshal() (x []byte) { - if m.raw != nil { - return m.raw - } - - // See RFC 5077, Section 3.3. - ticketLen := len(m.ticket) - length := 2 + 4 + ticketLen - x = make([]byte, 4+length) - x[0] = typeNewSessionTicket - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - x[8] = uint8(ticketLen >> 8) - x[9] = uint8(ticketLen) - copy(x[10:], m.ticket) - - m.raw = x - - return -} - -func (m *newSessionTicketMsg) unmarshal(data []byte) bool { - m.raw = data - - if len(data) < 10 { - return false - } - - length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) - if uint32(len(data))-4 != length { - return false - } - - ticketLen := int(data[8])<<8 + int(data[9]) - if len(data)-10 != ticketLen { - return false - } - - m.ticket = data[10:] - - return true -} - -type helloRequestMsg struct { -} - -func (*helloRequestMsg) marshal() []byte { - return []byte{typeHelloRequest, 0, 0, 0} -} - -func (*helloRequestMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server.go deleted file mode 100644 index 5d39cc8fcd4..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server.go +++ /dev/null @@ -1,878 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "sync/atomic" - "time" -) - -// serverHandshakeState contains details of a server handshake in progress. -// It's discarded once the handshake has completed. -type serverHandshakeState struct { - c *Conn - clientHello *clientHelloMsg - hello *serverHelloMsg - suite *cipherSuite - ecdheOk bool - ecSignOk bool - rsaDecryptOk bool - rsaSignOk bool - sessionState *sessionState - finishedHash finishedHash - masterSecret []byte - cert *Certificate -} - -// serverHandshake performs a TLS handshake as a server. -func (c *Conn) serverHandshake() error { - c.setAlternativeRecordLayer() - - clientHello, err := c.readClientHello() - if err != nil { - return err - } - - if c.vers == VersionTLS13 { - hs := serverHandshakeStateTLS13{ - c: c, - clientHello: clientHello, - } - return hs.handshake() - } else if c.extraConfig.usesAlternativeRecordLayer() { - // This should already have been caught by the check that the ClientHello doesn't - // offer any (supported) versions older than TLS 1.3. - // Check again to make sure we can't be tricked into using an older version. - c.sendAlert(alertProtocolVersion) - return errors.New("tls: negotiated TLS < 1.3 when using QUIC") - } - - hs := serverHandshakeState{ - c: c, - clientHello: clientHello, - } - return hs.handshake() -} - -func (hs *serverHandshakeState) handshake() error { - c := hs.c - - if err := hs.processClientHello(); err != nil { - return err - } - - // For an overview of TLS handshaking, see RFC 5246, Section 7.3. - c.buffering = true - if hs.checkForResumption() { - // The client has included a session ticket and so we do an abbreviated handshake. 
- c.didResume = true - if err := hs.doResumeHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.sendSessionTicket(); err != nil { - return err - } - if err := hs.sendFinished(c.serverFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - c.clientFinishedIsFirst = false - if err := hs.readFinished(nil); err != nil { - return err - } - } else { - // The client didn't include a session ticket, or it wasn't - // valid so we do a full handshake. - if err := hs.pickCipherSuite(); err != nil { - return err - } - if err := hs.doFullHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.readFinished(c.clientFinished[:]); err != nil { - return err - } - c.clientFinishedIsFirst = true - c.buffering = true - if err := hs.sendSessionTicket(); err != nil { - return err - } - if err := hs.sendFinished(nil); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - } - - c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) - atomic.StoreUint32(&c.handshakeStatus, 1) - - return nil -} - -// readClientHello reads a ClientHello message and selects the protocol version. -func (c *Conn) readClientHello() (*clientHelloMsg, error) { - msg, err := c.readHandshake() - if err != nil { - return nil, err - } - clientHello, ok := msg.(*clientHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return nil, unexpectedMessageError(clientHello, msg) - } - - var configForClient *config - originalConfig := c.config - if c.config.GetConfigForClient != nil { - chi := newClientHelloInfo(c, clientHello) - if cfc, err := c.config.GetConfigForClient(chi); err != nil { - c.sendAlert(alertInternalError) - return nil, err - } else if cfc != nil { - configForClient = fromConfig(cfc) - c.config = configForClient - } - } - c.ticketKeys = originalConfig.ticketKeys(configForClient) - - clientVersions := clientHello.supportedVersions - if len(clientHello.supportedVersions) == 0 { - clientVersions = supportedVersionsFromMax(clientHello.vers) - } - if c.extraConfig.usesAlternativeRecordLayer() { - // In QUIC, the client MUST NOT offer any old TLS versions. - // Here, we can only check that none of the other supported versions of this library - // (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here. - for _, ver := range clientVersions { - if ver == VersionTLS13 { - continue - } - for _, v := range supportedVersions { - if ver == v { - c.sendAlert(alertProtocolVersion) - return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver) - } - } - } - // Make the config we're using allows us to use TLS 1.3. - if c.config.maxSupportedVersion() < VersionTLS13 { - c.sendAlert(alertInternalError) - return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") - } - } - c.vers, ok = c.config.mutualVersion(clientVersions) - if !ok { - c.sendAlert(alertProtocolVersion) - return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) - } - c.haveVers = true - c.in.version = c.vers - c.out.version = c.vers - - return clientHello, nil -} - -func (hs *serverHandshakeState) processClientHello() error { - c := hs.c - - hs.hello = new(serverHelloMsg) - hs.hello.vers = c.vers - - foundCompression := false - // We only support null compression, so check that the client offered it. 
- for _, compression := range hs.clientHello.compressionMethods { - if compression == compressionNone { - foundCompression = true - break - } - } - - if !foundCompression { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: client does not support uncompressed connections") - } - - hs.hello.random = make([]byte, 32) - serverRandom := hs.hello.random - // Downgrade protection canaries. See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion() - if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { - if c.vers == VersionTLS12 { - copy(serverRandom[24:], downgradeCanaryTLS12) - } else { - copy(serverRandom[24:], downgradeCanaryTLS11) - } - serverRandom = serverRandom[:24] - } - _, err := io.ReadFull(c.config.rand(), serverRandom) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - if len(hs.clientHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: initial handshake had non-empty renegotiation extension") - } - - hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported - hs.hello.compressionMethod = compressionNone - if len(hs.clientHello.serverName) > 0 { - c.serverName = hs.clientHello.serverName - } - - if len(hs.clientHello.alpnProtocols) > 0 { - if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" { - hs.hello.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - } - } - - hs.cert, err = c.config.getCertificate(newClientHelloInfo(c, hs.clientHello)) - if err != nil { - if err == errNoCertificates { - c.sendAlert(alertUnrecognizedName) - } else { - c.sendAlert(alertInternalError) - } - return err - } - if hs.clientHello.scts { - hs.hello.scts = hs.cert.SignedCertificateTimestamps - } - - hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) - - if hs.ecdheOk { - // Although omitting the ec_point_formats extension is permitted, some - // old OpenSSL version will refuse to handshake if not present. - // - // Per RFC 4492, section 5.1.2, implementations MUST support the - // uncompressed point format. See golang.org/issue/31943. - hs.hello.supportedPoints = []uint8{pointFormatUncompressed} - } - - if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { - switch priv.Public().(type) { - case *ecdsa.PublicKey: - hs.ecSignOk = true - case ed25519.PublicKey: - hs.ecSignOk = true - case *rsa.PublicKey: - hs.rsaSignOk = true - default: - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) - } - } - if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { - switch priv.Public().(type) { - case *rsa.PublicKey: - hs.rsaDecryptOk = true - default: - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) - } - } - - return nil -} - -// supportsECDHE returns whether ECDHE key exchanges can be used with this -// pre-TLS 1.3 client. 
-func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool { - supportsCurve := false - for _, curve := range supportedCurves { - if c.supportsCurve(curve) { - supportsCurve = true - break - } - } - - supportsPointFormat := false - for _, pointFormat := range supportedPoints { - if pointFormat == pointFormatUncompressed { - supportsPointFormat = true - break - } - } - - return supportsCurve && supportsPointFormat -} - -func (hs *serverHandshakeState) pickCipherSuite() error { - c := hs.c - - var preferenceList, supportedList []uint16 - if c.config.PreferServerCipherSuites { - preferenceList = c.config.cipherSuites() - supportedList = hs.clientHello.cipherSuites - - // If the client does not seem to have hardware support for AES-GCM, - // and the application did not specify a cipher suite preference order, - // prefer other AEAD ciphers even if we prioritized AES-GCM ciphers - // by default. - if c.config.CipherSuites == nil && !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceList = deprioritizeAES(preferenceList) - } - } else { - preferenceList = hs.clientHello.cipherSuites - supportedList = c.config.cipherSuites() - - // If we don't have hardware support for AES-GCM, prefer other AEAD - // ciphers even if the client prioritized AES-GCM. - if !hasAESGCMHardwareSupport { - preferenceList = deprioritizeAES(preferenceList) - } - } - - hs.suite = selectCipherSuite(preferenceList, supportedList, hs.cipherSuiteOk) - if hs.suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no cipher suite supported by both client and server") - } - c.cipherSuite = hs.suite.id - - for _, id := range hs.clientHello.cipherSuites { - if id == TLS_FALLBACK_SCSV { - // The client is doing a fallback connection. See RFC 7507. - if hs.clientHello.vers < c.config.maxSupportedVersion() { - c.sendAlert(alertInappropriateFallback) - return errors.New("tls: client using inappropriate protocol fallback") - } - break - } - } - - return nil -} - -func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { - if c.flags&suiteECDHE != 0 { - if !hs.ecdheOk { - return false - } - if c.flags&suiteECSign != 0 { - if !hs.ecSignOk { - return false - } - } else if !hs.rsaSignOk { - return false - } - } else if !hs.rsaDecryptOk { - return false - } - if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true -} - -// checkForResumption reports whether we should perform resumption on this connection. -func (hs *serverHandshakeState) checkForResumption() bool { - c := hs.c - - if c.config.SessionTicketsDisabled { - return false - } - - plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) - if plaintext == nil { - return false - } - hs.sessionState = &sessionState{usedOldKey: usedOldKey} - ok := hs.sessionState.unmarshal(plaintext) - if !ok { - return false - } - - createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) - if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { - return false - } - - // Never resume a session for a different TLS version. - if c.vers != hs.sessionState.vers { - return false - } - - cipherSuiteOk := false - // Check that the client is still offering the ciphersuite in the session. - for _, id := range hs.clientHello.cipherSuites { - if id == hs.sessionState.cipherSuite { - cipherSuiteOk = true - break - } - } - if !cipherSuiteOk { - return false - } - - // Check that we also support the ciphersuite from the session. 
- hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, - c.config.cipherSuites(), hs.cipherSuiteOk) - if hs.suite == nil { - return false - } - - sessionHasClientCerts := len(hs.sessionState.certificates) != 0 - needClientCerts := requiresClientCert(c.config.ClientAuth) - if needClientCerts && !sessionHasClientCerts { - return false - } - if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { - return false - } - - return true -} - -func (hs *serverHandshakeState) doResumeHandshake() error { - c := hs.c - - hs.hello.cipherSuite = hs.suite.id - c.cipherSuite = hs.suite.id - // We echo the client's session ID in the ServerHello to let it know - // that we're doing a resumption. - hs.hello.sessionId = hs.clientHello.sessionId - hs.hello.ticketSupported = hs.sessionState.usedOldKey - hs.finishedHash = newFinishedHash(c.vers, hs.suite) - hs.finishedHash.discardHandshakeBuffer() - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { - return err - } - - if err := c.processCertsFromClient(Certificate{ - Certificate: hs.sessionState.certificates, - }); err != nil { - return err - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - hs.masterSecret = hs.sessionState.masterSecret - - return nil -} - -func (hs *serverHandshakeState) doFullHandshake() error { - c := hs.c - - if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { - hs.hello.ocspStapling = true - } - - hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled - hs.hello.cipherSuite = hs.suite.id - - hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) - if c.config.ClientAuth == NoClientCert { - // No need to keep a full record of the handshake if client - // certificates won't be used. 
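
// Illustrative sketch of the selection shape used by selectCipherSuite just
// above and by pickCipherSuite earlier: walk a preference-ordered list, take
// the first suite the peer also offers, and filter through a suitability
// callback. Types and names here are simplified assumptions for the example,
// not the qtls implementation.
package main

import "fmt"

type toySuite struct {
	id       uint16
	tls12AEAD bool
}

// firstMutualSuite returns the first suite in preference order that the peer
// supports and that ok() accepts (key type, TLS version, and so on).
func firstMutualSuite(preference, supported []uint16, all []toySuite, ok func(toySuite) bool) *toySuite {
	for _, want := range preference {
		for _, have := range supported {
			if want != have {
				continue
			}
			for _, s := range all {
				if s.id == want && ok(s) {
					return &s
				}
			}
		}
	}
	return nil
}

func main() {
	all := []toySuite{{id: 0xc02f, tls12AEAD: true}, {id: 0x002f}}
	got := firstMutualSuite(
		[]uint16{0xc02f, 0x002f}, // our preference order
		[]uint16{0x002f, 0xc02f}, // peer's offer
		all,
		func(s toySuite) bool { return s.tls12AEAD }, // e.g. require an AEAD suite
	)
	fmt.Printf("%#x\n", got.id) // 0xc02f
}
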
- hs.finishedHash.discardHandshakeBuffer() - } - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { - return err - } - - certMsg := new(certificateMsg) - certMsg.certificates = hs.cert.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { - return err - } - - if hs.hello.ocspStapling { - certStatus := new(certificateStatusMsg) - certStatus.response = hs.cert.OCSPStaple - hs.finishedHash.Write(certStatus.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { - return err - } - } - - keyAgreement := hs.suite.ka(c.vers) - skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) - if err != nil { - c.sendAlert(alertHandshakeFailure) - return err - } - if skx != nil { - hs.finishedHash.Write(skx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { - return err - } - } - - var certReq *certificateRequestMsg - if c.config.ClientAuth >= RequestClientCert { - // Request a client certificate - certReq = new(certificateRequestMsg) - certReq.certificateTypes = []byte{ - byte(certTypeRSASign), - byte(certTypeECDSASign), - } - if c.vers >= VersionTLS12 { - certReq.hasSignatureAlgorithm = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms - } - - // An empty list of certificateAuthorities signals to - // the client that it may send any certificate in response - // to our request. When we know the CAs we trust, then - // we can send them down, so that the client can choose - // an appropriate certificate to give to us. - if c.config.ClientCAs != nil { - certReq.certificateAuthorities = c.config.ClientCAs.Subjects() - } - hs.finishedHash.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { - return err - } - } - - helloDone := new(serverHelloDoneMsg) - hs.finishedHash.Write(helloDone.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { - return err - } - - if _, err := c.flush(); err != nil { - return err - } - - var pub crypto.PublicKey // public key for client auth, if any - - msg, err := c.readHandshake() - if err != nil { - return err - } - - // If we requested a client certificate, then the client must send a - // certificate message, even if it's empty. 
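
// The client-auth branches in this file key off whether the configured
// ClientAuthType merely requests or actually requires a client certificate
// (checkForResumption above and processCertsFromClient later call a helper
// named requiresClientCert for this). A sketch of what such a helper looks
// like, reconstructed as an assumption for illustration using the mirrored
// crypto/tls constants:
package main

import (
	"crypto/tls"
	"fmt"
)

// requiresClientCert reports whether the policy demands that the client
// present a certificate, as opposed to merely being invited to send one.
func requiresClientCert(c tls.ClientAuthType) bool {
	switch c {
	case tls.RequireAnyClientCert, tls.RequireAndVerifyClientCert:
		return true
	default:
		return false
	}
}

func main() {
	fmt.Println(requiresClientCert(tls.RequestClientCert))          // false
	fmt.Println(requiresClientCert(tls.RequireAndVerifyClientCert)) // true
}
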
- if c.config.ClientAuth >= RequestClientCert { - certMsg, ok := msg.(*certificateMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - hs.finishedHash.Write(certMsg.marshal()) - - if err := c.processCertsFromClient(Certificate{ - Certificate: certMsg.certificates, - }); err != nil { - return err - } - if len(certMsg.certificates) != 0 { - pub = c.peerCertificates[0].PublicKey - } - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - // Get client key exchange - ckx, ok := msg.(*clientKeyExchangeMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(ckx, msg) - } - hs.finishedHash.Write(ckx.marshal()) - - preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) - if err != nil { - c.sendAlert(alertHandshakeFailure) - return err - } - hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) - if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { - c.sendAlert(alertInternalError) - return err - } - - // If we received a client cert in response to our certificate request message, - // the client will send us a certificateVerifyMsg immediately after the - // clientKeyExchangeMsg. This message is a digest of all preceding - // handshake-layer messages that is signed using the private key corresponding - // to the client's certificate. This allows us to verify that the client is in - // possession of the private key of the certificate. 
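
// The comment above describes what CertificateVerify proves. A minimal
// standalone sketch of that check for one concrete case (an ECDSA P-256
// client key, signature over a SHA-256 digest standing in for the handshake
// transcript); the real code below dispatches on sigType through
// verifyHandshakeSignature instead.
package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"fmt"
)

func main() {
	// Stand-ins for the client's certified key and the handshake transcript.
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	transcript := []byte("all handshake messages up to CertificateVerify")
	digest := sha256.Sum256(transcript)

	// What the client does: sign the transcript digest with its private key.
	sig, err := ecdsa.SignASN1(rand.Reader, key, digest[:])
	if err != nil {
		panic(err)
	}

	// What the server does: verify with the public key taken from the
	// client's certificate; success shows the peer holds the private key.
	fmt.Println(ecdsa.VerifyASN1(&key.PublicKey, digest[:], sig)) // true
}
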
- if len(c.peerCertificates) > 0 { - msg, err = c.readHandshake() - if err != nil { - return err - } - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - var sigType uint8 - var sigHash crypto.Hash - if c.vers >= VersionTLS12 { - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - } - - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) - if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the client certificate: " + err.Error()) - } - - hs.finishedHash.Write(certVerify.marshal()) - } - - hs.finishedHash.discardHandshakeBuffer() - - return nil -} - -func (hs *serverHandshakeState) establishKeys() error { - c := hs.c - - clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := - keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) - - var clientCipher, serverCipher interface{} - var clientHash, serverHash hash.Hash - - if hs.suite.aead == nil { - clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) - clientHash = hs.suite.mac(clientMAC) - serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) - serverHash = hs.suite.mac(serverMAC) - } else { - clientCipher = hs.suite.aead(clientKey, clientIV) - serverCipher = hs.suite.aead(serverKey, serverIV) - } - - c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) - c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) - - return nil -} - -func (hs *serverHandshakeState) readFinished(out []byte) error { - c := hs.c - - if err := c.readChangeCipherSpec(); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - clientFinished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(clientFinished, msg) - } - - verify := hs.finishedHash.clientSum(hs.masterSecret) - if len(verify) != len(clientFinished.verifyData) || - subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: client's Finished message is incorrect") - } - - hs.finishedHash.Write(clientFinished.marshal()) - copy(out, verify) - return nil -} - -func (hs *serverHandshakeState) sendSessionTicket() error { - // ticketSupported is set in a resumption handshake if the - // ticket from the client was encrypted with an old session - // ticket key and thus a refreshed ticket should be sent. - if !hs.hello.ticketSupported { - return nil - } - - c := hs.c - m := new(newSessionTicketMsg) - - createdAt := uint64(c.config.time().Unix()) - if hs.sessionState != nil { - // If this is re-wrapping an old key, then keep - // the original time it was created. 
- createdAt = hs.sessionState.createdAt - } - - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := sessionState{ - vers: c.vers, - cipherSuite: hs.suite.id, - createdAt: createdAt, - masterSecret: hs.masterSecret, - certificates: certsFromClient, - } - var err error - m.ticket, err = c.encryptTicket(state.marshal()) - if err != nil { - return err - } - - hs.finishedHash.Write(m.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeState) sendFinished(out []byte) error { - c := hs.c - - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { - return err - } - - finished := new(finishedMsg) - finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { - return err - } - - copy(out, finished.verifyData) - - return nil -} - -// processCertsFromClient takes a chain of client certificates either from a -// Certificates message or from a sessionState and verifies them. It returns -// the public key of the leaf certificate. -func (c *Conn) processCertsFromClient(certificate Certificate) error { - certificates := certificate.Certificate - certs := make([]*x509.Certificate, len(certificates)) - var err error - for i, asn1Data := range certificates { - if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to parse client certificate: " + err.Error()) - } - } - - if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: client didn't provide a certificate") - } - - if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { - opts := x509.VerifyOptions{ - Roots: c.config.ClientCAs, - CurrentTime: c.config.time(), - Intermediates: x509.NewCertPool(), - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - } - - for _, cert := range certs[1:] { - opts.Intermediates.AddCert(cert) - } - - chains, err := certs[0].Verify(opts) - if err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to verify client certificate: " + err.Error()) - } - - c.verifiedChains = chains - } - - c.peerCertificates = certs - c.ocspResponse = certificate.OCSPStaple - c.scts = certificate.SignedCertificateTimestamps - - if len(certs) > 0 { - switch certs[0].PublicKey.(type) { - case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) - } - } - - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - return nil -} - -func newClientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { - supportedVersions := clientHello.supportedVersions - if len(clientHello.supportedVersions) == 0 { - supportedVersions = supportedVersionsFromMax(clientHello.vers) - } - - return toClientHelloInfo(&clientHelloInfo{ - CipherSuites: clientHello.cipherSuites, - ServerName: clientHello.serverName, - SupportedCurves: clientHello.supportedCurves, - SupportedPoints: clientHello.supportedPoints, - SignatureSchemes: 
clientHello.supportedSignatureAlgorithms, - SupportedProtos: clientHello.alpnProtocols, - SupportedVersions: supportedVersions, - Conn: c.conn, - config: toConfig(c.config), - }) -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server_tls13.go deleted file mode 100644 index e1ab918d3ee..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/handshake_server_tls13.go +++ /dev/null @@ -1,912 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "crypto" - "crypto/hmac" - "crypto/rsa" - "errors" - "fmt" - "hash" - "io" - "sync/atomic" - "time" -) - -// maxClientPSKIdentities is the number of client PSK identities the server will -// attempt to validate. It will ignore the rest not to let cheap ClientHello -// messages cause too much work in session ticket decryption attempts. -const maxClientPSKIdentities = 5 - -type serverHandshakeStateTLS13 struct { - c *Conn - clientHello *clientHelloMsg - hello *serverHelloMsg - encryptedExtensions *encryptedExtensionsMsg - sentDummyCCS bool - usingPSK bool - suite *cipherSuiteTLS13 - cert *Certificate - sigAlg SignatureScheme - earlySecret []byte - sharedKey []byte - handshakeSecret []byte - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 - transcript hash.Hash - clientFinished []byte -} - -func (hs *serverHandshakeStateTLS13) handshake() error { - c := hs.c - - // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. - if err := hs.processClientHello(); err != nil { - return err - } - if err := hs.checkForResumption(); err != nil { - return err - } - if err := hs.pickCertificate(); err != nil { - return err - } - c.buffering = true - if err := hs.sendServerParameters(); err != nil { - return err - } - if err := hs.sendServerCertificate(); err != nil { - return err - } - if err := hs.sendServerFinished(); err != nil { - return err - } - // Note that at this point we could start sending application data without - // waiting for the client's second flight, but the application might not - // expect the lack of replay protection of the ClientHello parameters. - if _, err := c.flush(); err != nil { - return err - } - if err := hs.readClientCertificate(); err != nil { - return err - } - if err := hs.readClientFinished(); err != nil { - return err - } - - atomic.StoreUint32(&c.handshakeStatus, 1) - - return nil -} - -func (hs *serverHandshakeStateTLS13) processClientHello() error { - c := hs.c - - hs.hello = new(serverHelloMsg) - hs.encryptedExtensions = new(encryptedExtensionsMsg) - - // TLS 1.3 froze the ServerHello.legacy_version field, and uses - // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. - hs.hello.vers = VersionTLS12 - hs.hello.supportedVersion = c.vers - - if len(hs.clientHello.supportedVersions) == 0 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") - } - - // Abort if the client is doing a fallback and landing lower than what we - // support. See RFC 7507, which however does not specify the interaction - // with supported_versions. The only difference is that with - // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] - // handshake in case TLS 1.3 is broken but 1.2 is not. 
Alas, in that case, - // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to - // TLS 1.2, because a TLS 1.3 server would abort here. The situation before - // supported_versions was not better because there was just no way to do a - // TLS 1.4 handshake without risking the server selecting TLS 1.3. - for _, id := range hs.clientHello.cipherSuites { - if id == TLS_FALLBACK_SCSV { - // Use c.vers instead of max(supported_versions) because an attacker - // could defeat this by adding an arbitrary high version otherwise. - if c.vers < c.config.maxSupportedVersion() { - c.sendAlert(alertInappropriateFallback) - return errors.New("tls: client using inappropriate protocol fallback") - } - break - } - } - - if len(hs.clientHello.compressionMethods) != 1 || - hs.clientHello.compressionMethods[0] != compressionNone { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: TLS 1.3 client supports illegal compression methods") - } - - hs.hello.random = make([]byte, 32) - if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { - c.sendAlert(alertInternalError) - return err - } - - if len(hs.clientHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: initial handshake had non-empty renegotiation extension") - } - - hs.hello.sessionId = hs.clientHello.sessionId - hs.hello.compressionMethod = compressionNone - - var preferenceList, supportedList, ourList []uint16 - var useConfiguredCipherSuites bool - for _, suiteID := range c.config.CipherSuites { - for _, suite := range cipherSuitesTLS13 { - if suite.id == suiteID { - ourList = append(ourList, suiteID) - break - } - } - } - if len(ourList) > 0 { - useConfiguredCipherSuites = true - } else { - ourList = defaultCipherSuitesTLS13() - } - if c.config.PreferServerCipherSuites { - preferenceList = ourList - supportedList = hs.clientHello.cipherSuites - - // If the client does not seem to have hardware support for AES-GCM, - // prefer other AEAD ciphers even if we prioritized AES-GCM ciphers - // by default. - if !useConfiguredCipherSuites && !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceList = deprioritizeAES(preferenceList) - } - } else { - preferenceList = hs.clientHello.cipherSuites - supportedList = ourList - - // If we don't have hardware support for AES-GCM, prefer other AEAD - // ciphers even if the client prioritized AES-GCM. - if !hasAESGCMHardwareSupport { - preferenceList = deprioritizeAES(preferenceList) - } - } - for _, suiteID := range preferenceList { - hs.suite = mutualCipherSuiteTLS13(supportedList, suiteID) - if hs.suite != nil { - break - } - } - if hs.suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no cipher suite supported by both client and server") - } - c.cipherSuite = hs.suite.id - hs.hello.cipherSuite = hs.suite.id - hs.transcript = hs.suite.hash.New() - - // Pick the ECDHE group in server preference order, but give priority to - // groups with a key share, to avoid a HelloRetryRequest round-trip. 
- var selectedGroup CurveID - var clientKeyShare *keyShare -GroupSelection: - for _, preferredGroup := range c.config.curvePreferences() { - for _, ks := range hs.clientHello.keyShares { - if ks.group == preferredGroup { - selectedGroup = ks.group - clientKeyShare = &ks - break GroupSelection - } - } - if selectedGroup != 0 { - continue - } - for _, group := range hs.clientHello.supportedCurves { - if group == preferredGroup { - selectedGroup = group - break - } - } - } - if selectedGroup == 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no ECDHE curve supported by both client and server") - } - if clientKeyShare == nil { - if err := hs.doHelloRetryRequest(selectedGroup); err != nil { - return err - } - clientKeyShare = &hs.clientHello.keyShares[0] - } - - if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err := generateECDHEParameters(c.config.rand(), selectedGroup) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} - hs.sharedKey = params.SharedKey(clientKeyShare.data) - if hs.sharedKey == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid client key share") - } - - c.serverName = hs.clientHello.serverName - - if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil { - c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions) - } - - if len(hs.clientHello.alpnProtocols) > 0 { - if selectedProto := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); selectedProto != "" { - hs.encryptedExtensions.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - } - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) checkForResumption() error { - c := hs.c - - if c.config.SessionTicketsDisabled { - return nil - } - - modeOK := false - for _, mode := range hs.clientHello.pskModes { - if mode == pskModeDHE { - modeOK = true - break - } - } - if !modeOK { - return nil - } - - if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid or missing PSK binders") - } - if len(hs.clientHello.pskIdentities) == 0 { - return nil - } - - for i, identity := range hs.clientHello.pskIdentities { - if i >= maxClientPSKIdentities { - break - } - - plaintext, _ := c.decryptTicket(identity.label) - if plaintext == nil { - continue - } - sessionState := new(sessionStateTLS13) - if ok := sessionState.unmarshal(plaintext); !ok { - continue - } - - if hs.clientHello.earlyData { - if sessionState.maxEarlyData == 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: client sent unexpected early data") - } - - if sessionState.alpn == c.clientProtocol && - c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 && - c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) { - hs.encryptedExtensions.earlyData = true - c.used0RTT = true - } - } - - createdAt := time.Unix(int64(sessionState.createdAt), 0) - if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { - continue - } - - // We don't check the obfuscated ticket age because it's affected by - // clock skew and it's only a freshness signal useful for shrinking the - // window for replay attacks, which don't affect us as we don't do 0-RTT. 
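// Editor's aside (illustrative sketch, not part of this patch): for the
// X25519 group, the params.SharedKey(clientKeyShare.data) call above reduces
// to a plain curve25519 Diffie-Hellman. A self-contained demonstration using
// golang.org/x/crypto/curve25519, the same dependency the deleted code
// vendors:
package main

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"

	"golang.org/x/crypto/curve25519"
)

func main() {
	// Each side generates a random 32-byte scalar and publishes
	// scalar * basepoint as its key_share.
	clientPriv := make([]byte, curve25519.ScalarSize)
	serverPriv := make([]byte, curve25519.ScalarSize)
	if _, err := io.ReadFull(rand.Reader, clientPriv); err != nil {
		panic(err)
	}
	if _, err := io.ReadFull(rand.Reader, serverPriv); err != nil {
		panic(err)
	}
	clientPub, err := curve25519.X25519(clientPriv, curve25519.Basepoint)
	if err != nil {
		panic(err)
	}
	serverPub, err := curve25519.X25519(serverPriv, curve25519.Basepoint)
	if err != nil {
		panic(err)
	}

	// Both sides derive the same shared secret from the peer's public share.
	s1, err := curve25519.X25519(clientPriv, serverPub)
	if err != nil {
		panic(err)
	}
	s2, err := curve25519.X25519(serverPriv, clientPub)
	if err != nil {
		panic(err)
	}
	fmt.Println("shared secrets equal:", bytes.Equal(s1, s2))
}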
- - pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) - if pskSuite == nil || pskSuite.hash != hs.suite.hash { - continue - } - - // PSK connections don't re-establish client certificates, but carry - // them over in the session ticket. Ensure the presence of client certs - // in the ticket is consistent with the configured requirements. - sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 - needClientCerts := requiresClientCert(c.config.ClientAuth) - if needClientCerts && !sessionHasClientCerts { - continue - } - if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { - continue - } - - psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", - nil, hs.suite.hash.Size()) - hs.earlySecret = hs.suite.extract(psk, nil) - binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) - // Clone the transcript in case a HelloRetryRequest was recorded. - transcript := cloneHash(hs.transcript, hs.suite.hash) - if transcript == nil { - c.sendAlert(alertInternalError) - return errors.New("tls: internal error: failed to clone hash") - } - transcript.Write(hs.clientHello.marshalWithoutBinders()) - pskBinder := hs.suite.finishedHash(binderKey, transcript) - if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid PSK binder") - } - - c.didResume = true - if err := c.processCertsFromClient(sessionState.certificate); err != nil { - return err - } - - h := cloneHash(hs.transcript, hs.suite.hash) - h.Write(hs.clientHello.marshal()) - if hs.encryptedExtensions.earlyData { - clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h) - c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret) - if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil { - c.sendAlert(alertInternalError) - return err - } - } - - hs.hello.selectedIdentityPresent = true - hs.hello.selectedIdentity = uint16(i) - hs.usingPSK = true - return nil - } - - return nil -} - -// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler -// interfaces implemented by standard library hashes to clone the state of in -// to a new instance of h. It returns nil if the operation fails. -func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { - // Recreate the interface to avoid importing encoding. - type binaryMarshaler interface { - MarshalBinary() (data []byte, err error) - UnmarshalBinary(data []byte) error - } - marshaler, ok := in.(binaryMarshaler) - if !ok { - return nil - } - state, err := marshaler.MarshalBinary() - if err != nil { - return nil - } - out := h.New() - unmarshaler, ok := out.(binaryMarshaler) - if !ok { - return nil - } - if err := unmarshaler.UnmarshalBinary(state); err != nil { - return nil - } - return out -} - -func (hs *serverHandshakeStateTLS13) pickCertificate() error { - c := hs.c - - // Only one of PSK and certificates are used at a time. - if hs.usingPSK { - return nil - } - - // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. 
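// Editor's aside (illustrative sketch, not part of this patch): the cloneHash
// helper above works because the standard library hash implementations
// (crypto/sha256 and friends) satisfy encoding.BinaryMarshaler and
// encoding.BinaryUnmarshaler, so their internal state can be snapshotted and
// restored. A minimal, self-contained demonstration of that property:
package main

import (
	"bytes"
	"crypto/sha256"
	"encoding"
	"fmt"
)

func main() {
	h := sha256.New()
	h.Write([]byte("ClientHello"))

	// Snapshot the transcript state before more messages are written.
	state, err := h.(encoding.BinaryMarshaler).MarshalBinary()
	if err != nil {
		panic(err)
	}

	// Restore the snapshot into a fresh hash and continue independently.
	clone := sha256.New()
	if err := clone.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil {
		panic(err)
	}

	h.Write([]byte("ServerHello"))
	clone.Write([]byte("ServerHello"))
	fmt.Println("clone tracks original:", bytes.Equal(h.Sum(nil), clone.Sum(nil)))
}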
- if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { - return c.sendAlert(alertMissingExtension) - } - - certificate, err := c.config.getCertificate(newClientHelloInfo(c, hs.clientHello)) - if err != nil { - if err == errNoCertificates { - c.sendAlert(alertUnrecognizedName) - } else { - c.sendAlert(alertInternalError) - } - return err - } - hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) - if err != nil { - // getCertificate returned a certificate that is unsupported or - // incompatible with the client's signature algorithms. - c.sendAlert(alertHandshakeFailure) - return err - } - hs.cert = certificate - - return nil -} - -// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility -// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. -func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { - if hs.sentDummyCCS { - return nil - } - hs.sentDummyCCS = true - - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err -} - -func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { - c := hs.c - - // The first ClientHello gets double-hashed into the transcript upon a - // HelloRetryRequest. See RFC 8446, Section 4.4.1. - hs.transcript.Write(hs.clientHello.marshal()) - chHash := hs.transcript.Sum(nil) - hs.transcript.Reset() - hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - hs.transcript.Write(chHash) - - helloRetryRequest := &serverHelloMsg{ - vers: hs.hello.vers, - random: helloRetryRequestRandom, - sessionId: hs.hello.sessionId, - cipherSuite: hs.hello.cipherSuite, - compressionMethod: hs.hello.compressionMethod, - supportedVersion: hs.hello.supportedVersion, - selectedGroup: selectedGroup, - } - - hs.transcript.Write(helloRetryRequest.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { - return err - } - - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - clientHello, ok := msg.(*clientHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(clientHello, msg) - } - - if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client sent invalid key share in second ClientHello") - } - - if clientHello.earlyData { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client indicated early data in second ClientHello") - } - - if illegalClientHelloChange(clientHello, hs.clientHello) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client illegally modified second ClientHello") - } - - if clientHello.earlyData { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client offered 0-RTT data in second ClientHello") - } - - hs.clientHello = clientHello - return nil -} - -// illegalClientHelloChange reports whether the two ClientHello messages are -// different, with the exception of the changes allowed before and after a -// HelloRetryRequest. See RFC 8446, Section 4.1.2. 
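// Editor's aside (illustrative sketch, not part of this patch): the
// doHelloRetryRequest code above replaces the first ClientHello in the
// transcript with a synthetic message_hash handshake message, as described in
// RFC 8446, Section 4.4.1. A standalone illustration with SHA-256; the value
// 254 is the message_hash handshake type, and the byte-slice arguments stand
// in for the real serialized messages:
package main

import (
	"crypto/sha256"
	"fmt"
)

func transcriptAfterHRR(clientHello1, helloRetryRequest, clientHello2 []byte) []byte {
	chHash := sha256.Sum256(clientHello1)

	transcript := sha256.New()
	// Synthetic message: type message_hash (254), 24-bit length, Hash(CH1).
	transcript.Write([]byte{254, 0, 0, byte(len(chHash))})
	transcript.Write(chHash[:])

	// The HelloRetryRequest and second ClientHello are hashed as usual.
	transcript.Write(helloRetryRequest)
	transcript.Write(clientHello2)
	return transcript.Sum(nil)
}

func main() {
	fmt.Printf("%x\n", transcriptAfterHRR([]byte("CH1"), []byte("HRR"), []byte("CH2")))
}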
-func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { - if len(ch.supportedVersions) != len(ch1.supportedVersions) || - len(ch.cipherSuites) != len(ch1.cipherSuites) || - len(ch.supportedCurves) != len(ch1.supportedCurves) || - len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || - len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || - len(ch.alpnProtocols) != len(ch1.alpnProtocols) { - return true - } - for i := range ch.supportedVersions { - if ch.supportedVersions[i] != ch1.supportedVersions[i] { - return true - } - } - for i := range ch.cipherSuites { - if ch.cipherSuites[i] != ch1.cipherSuites[i] { - return true - } - } - for i := range ch.supportedCurves { - if ch.supportedCurves[i] != ch1.supportedCurves[i] { - return true - } - } - for i := range ch.supportedSignatureAlgorithms { - if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { - return true - } - } - for i := range ch.supportedSignatureAlgorithmsCert { - if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { - return true - } - } - for i := range ch.alpnProtocols { - if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { - return true - } - } - return ch.vers != ch1.vers || - !bytes.Equal(ch.random, ch1.random) || - !bytes.Equal(ch.sessionId, ch1.sessionId) || - !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || - ch.serverName != ch1.serverName || - ch.ocspStapling != ch1.ocspStapling || - !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || - ch.ticketSupported != ch1.ticketSupported || - !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || - ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || - !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || - ch.scts != ch1.scts || - !bytes.Equal(ch.cookie, ch1.cookie) || - !bytes.Equal(ch.pskModes, ch1.pskModes) -} - -func (hs *serverHandshakeStateTLS13) sendServerParameters() error { - c := hs.c - - if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection && len(c.clientProtocol) == 0 { - c.sendAlert(alertNoApplicationProtocol) - return fmt.Errorf("ALPN negotiation failed. 
Client offered: %q", hs.clientHello.alpnProtocols) - } - - hs.transcript.Write(hs.clientHello.marshal()) - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { - return err - } - - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - - earlySecret := hs.earlySecret - if earlySecret == nil { - earlySecret = hs.suite.extract(nil, nil) - } - hs.handshakeSecret = hs.suite.extract(hs.sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) - - clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) - c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret) - c.in.setTrafficSecret(hs.suite, clientSecret) - serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) - c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret) - c.out.setTrafficSecret(hs.suite, serverSecret) - - err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil { - hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions) - } - - hs.transcript.Write(hs.encryptedExtensions.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.encryptedExtensions.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) requestClientCert() bool { - return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK -} - -func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { - c := hs.c - - // Only one of PSK and certificates are used at a time. 
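// Editor's aside (illustrative sketch, not part of this patch): the
// writeKeyLog calls above emit secrets in the NSS key log format (the
// SSLKEYLOGFILE format consumed by Wireshark and similar tools). As far as I
// understand the format, each entry is one line of the form
// "<label> <hex client_random> <hex secret>". A minimal helper using the two
// TLS 1.3 handshake labels referenced above; the values in main are
// placeholders only:
package main

import (
	"encoding/hex"
	"fmt"
	"io"
	"os"
)

func writeKeyLogLine(w io.Writer, label string, clientRandom, secret []byte) error {
	_, err := fmt.Fprintf(w, "%s %s %s\n",
		label, hex.EncodeToString(clientRandom), hex.EncodeToString(secret))
	return err
}

func main() {
	clientRandom := make([]byte, 32) // placeholder client random
	secret := make([]byte, 32)       // placeholder traffic secret
	writeKeyLogLine(os.Stdout, "CLIENT_HANDSHAKE_TRAFFIC_SECRET", clientRandom, secret)
	writeKeyLogLine(os.Stdout, "SERVER_HANDSHAKE_TRAFFIC_SECRET", clientRandom, secret)
}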
- if hs.usingPSK { - return nil - } - - if hs.requestClientCert() { - // Request a client certificate - certReq := new(certificateRequestMsgTLS13) - certReq.ocspStapling = true - certReq.scts = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms - if c.config.ClientCAs != nil { - certReq.certificateAuthorities = c.config.ClientCAs.Subjects() - } - - hs.transcript.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { - return err - } - } - - certMsg := new(certificateMsgTLS13) - - certMsg.certificate = *hs.cert - certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 - certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 - - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { - return err - } - - certVerifyMsg := new(certificateVerifyMsg) - certVerifyMsg.hasSignatureAlgorithm = true - certVerifyMsg.signatureAlgorithm = hs.sigAlg - - sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) - if err != nil { - return c.sendAlert(alertInternalError) - } - - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) - if err != nil { - public := hs.cert.PrivateKey.(crypto.Signer).Public() - if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && - rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS - c.sendAlert(alertHandshakeFailure) - } else { - c.sendAlert(alertInternalError) - } - return errors.New("tls: failed to sign handshake: " + err.Error()) - } - certVerifyMsg.signature = sig - - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) sendServerFinished() error { - c := hs.c - - finished := &finishedMsg{ - verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), - } - - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { - return err - } - - // Derive secrets that take context through the server Finished. - - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) - - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) - c.out.exportKey(EncryptionApplication, hs.suite, serverSecret) - c.out.setTrafficSecret(hs.suite, serverSecret) - - err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) - - // If we did not request client certificates, at this point we can - // precompute the client finished and roll the transcript forward to send - // session tickets in our first flight. 
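// Editor's aside (illustrative sketch, not part of this patch): the
// CertificateVerify built above signs a digest of 64 bytes of 0x20 padding, a
// context string, and the transcript hash (RFC 8446, Section 4.4.3); for RSA
// keys it uses RSA-PSS with the salt length equal to the digest size. A
// standalone sign/verify round trip over a made-up transcript:
package main

import (
	"bytes"
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"fmt"
)

func main() {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}

	transcriptHash := sha256.Sum256([]byte("handshake messages so far"))

	// Content to sign: padding, NUL-terminated context string, transcript
	// hash, all run through the signature hash.
	h := sha256.New()
	h.Write(bytes.Repeat([]byte{0x20}, 64))
	h.Write([]byte("TLS 1.3, server CertificateVerify\x00"))
	h.Write(transcriptHash[:])
	digest := h.Sum(nil)

	opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA256}
	sig, err := rsa.SignPSS(rand.Reader, key, crypto.SHA256, digest, opts)
	if err != nil {
		panic(err)
	}
	err = rsa.VerifyPSS(&key.PublicKey, crypto.SHA256, digest, sig, opts)
	fmt.Println("signature valid:", err == nil)
}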
- if !hs.requestClientCert() { - if err := hs.sendSessionTickets(); err != nil { - return err - } - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { - if hs.c.config.SessionTicketsDisabled { - return false - } - - // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. - for _, pskMode := range hs.clientHello.pskModes { - if pskMode == pskModeDHE { - return true - } - } - return false -} - -func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { - c := hs.c - - hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) - finishedMsg := &finishedMsg{ - verifyData: hs.clientFinished, - } - hs.transcript.Write(finishedMsg.marshal()) - - if !hs.shouldSendSessionTickets() { - return nil - } - - c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) - - // Don't send session tickets when the alternative record layer is set. - // Instead, save the resumption secret on the Conn. - // Session tickets can then be generated by calling Conn.GetSessionTicket(). - if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil { - return nil - } - - m, err := hs.c.getSessionTicketMsg(nil) - if err != nil { - return err - } - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) readClientCertificate() error { - c := hs.c - - if !hs.requestClientCert() { - // Make sure the connection is still being verified whether or not - // the server requested a client certificate. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - return nil - } - - // If we requested a client certificate, then the client must send a - // certificate message. If it's empty, no CertificateVerify is sent. - - msg, err := c.readHandshake() - if err != nil { - return err - } - - certMsg, ok := msg.(*certificateMsgTLS13) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - hs.transcript.Write(certMsg.marshal()) - - if err := c.processCertsFromClient(certMsg.certificate); err != nil { - return err - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - if len(certMsg.certificate.Certificate) != 0 { - msg, err = c.readHandshake() - if err != nil { - return err - } - - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - // See RFC 8446, Section 4.4.3. 
- if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) - if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, - sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the client certificate: " + err.Error()) - } - - hs.transcript.Write(certVerify.marshal()) - } - - // If we waited until the client certificates to send session tickets, we - // are ready to do it now. - if err := hs.sendSessionTickets(); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) readClientFinished() error { - c := hs.c - - msg, err := c.readHandshake() - if err != nil { - return err - } - - finished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(finished, msg) - } - - if !hmac.Equal(hs.clientFinished, finished.verifyData) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid client finished hash") - } - - c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) - c.in.setTrafficSecret(hs.suite, hs.trafficSecret) - - return nil -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/key_agreement.go b/vendor/github.com/marten-seemann/qtls-go1-16/key_agreement.go deleted file mode 100644 index d8f5d4690fe..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/key_agreement.go +++ /dev/null @@ -1,338 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/md5" - "crypto/rsa" - "crypto/sha1" - "crypto/x509" - "errors" - "fmt" - "io" -) - -var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") -var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") - -// rsaKeyAgreement implements the standard TLS key agreement where the client -// encrypts the pre-master secret to the server's public key. 
-type rsaKeyAgreement struct{} - -func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { - return nil, nil -} - -func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { - if len(ckx.ciphertext) < 2 { - return nil, errClientKeyExchange - } - ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) - if ciphertextLen != len(ckx.ciphertext)-2 { - return nil, errClientKeyExchange - } - ciphertext := ckx.ciphertext[2:] - - priv, ok := cert.PrivateKey.(crypto.Decrypter) - if !ok { - return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") - } - // Perform constant time RSA PKCS #1 v1.5 decryption - preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) - if err != nil { - return nil, err - } - // We don't check the version number in the premaster secret. For one, - // by checking it, we would leak information about the validity of the - // encrypted pre-master secret. Secondly, it provides only a small - // benefit against a downgrade attack and some implementations send the - // wrong version anyway. See the discussion at the end of section - // 7.4.7.1 of RFC 4346. - return preMasterSecret, nil -} - -func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { - return errors.New("tls: unexpected ServerKeyExchange") -} - -func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { - preMasterSecret := make([]byte, 48) - preMasterSecret[0] = byte(clientHello.vers >> 8) - preMasterSecret[1] = byte(clientHello.vers) - _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) - if err != nil { - return nil, nil, err - } - - rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) - if !ok { - return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") - } - encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) - if err != nil { - return nil, nil, err - } - ckx := new(clientKeyExchangeMsg) - ckx.ciphertext = make([]byte, len(encrypted)+2) - ckx.ciphertext[0] = byte(len(encrypted) >> 8) - ckx.ciphertext[1] = byte(len(encrypted)) - copy(ckx.ciphertext[2:], encrypted) - return preMasterSecret, ckx, nil -} - -// sha1Hash calculates a SHA1 hash over the given byte slices. -func sha1Hash(slices [][]byte) []byte { - hsha1 := sha1.New() - for _, slice := range slices { - hsha1.Write(slice) - } - return hsha1.Sum(nil) -} - -// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the -// concatenation of an MD5 and SHA1 hash. -func md5SHA1Hash(slices [][]byte) []byte { - md5sha1 := make([]byte, md5.Size+sha1.Size) - hmd5 := md5.New() - for _, slice := range slices { - hmd5.Write(slice) - } - copy(md5sha1, hmd5.Sum(nil)) - copy(md5sha1[md5.Size:], sha1Hash(slices)) - return md5sha1 -} - -// hashForServerKeyExchange hashes the given slices and returns their digest -// using the given hash function (for >= TLS 1.2) or using a default based on -// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't -// do pre-hashing, it returns the concatenation of the slices. 
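// Editor's aside (illustrative sketch, not part of this patch): the RSA key
// exchange above decrypts the pre-master secret through crypto.Decrypter with
// rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}. With a non-zero
// SessionKeyLen, a padding failure returns a random value of that length
// instead of an error, so decryption failures cannot be used as a
// Bleichenbacher-style padding oracle. A standalone demonstration:
package main

import (
	"crypto/rand"
	"crypto/rsa"
	"fmt"
)

func main() {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}

	preMaster := make([]byte, 48)
	if _, err := rand.Read(preMaster); err != nil {
		panic(err)
	}
	ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, &key.PublicKey, preMaster)
	if err != nil {
		panic(err)
	}

	opts := &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}

	// Valid ciphertext: the original 48-byte pre-master secret comes back.
	plaintext, err := key.Decrypt(rand.Reader, ciphertext, opts)
	fmt.Println(err == nil, len(plaintext) == 48)

	// Corrupted ciphertext: still no error, just an unpredictable 48-byte
	// value, which keeps the error behaviour uniform for an attacker.
	ciphertext[len(ciphertext)-1] ^= 0x01
	garbled, err := key.Decrypt(rand.Reader, ciphertext, opts)
	fmt.Println(err == nil, len(garbled) == 48)
}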
-func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { - if sigType == signatureEd25519 { - var signed []byte - for _, slice := range slices { - signed = append(signed, slice...) - } - return signed - } - if version >= VersionTLS12 { - h := hashFunc.New() - for _, slice := range slices { - h.Write(slice) - } - digest := h.Sum(nil) - return digest - } - if sigType == signatureECDSA { - return sha1Hash(slices) - } - return md5SHA1Hash(slices) -} - -// ecdheKeyAgreement implements a TLS key agreement where the server -// generates an ephemeral EC public/private key pair and signs it. The -// pre-master secret is then calculated using ECDH. The signature may -// be ECDSA, Ed25519 or RSA. -type ecdheKeyAgreement struct { - version uint16 - isRSA bool - params ecdheParameters - - // ckx and preMasterSecret are generated in processServerKeyExchange - // and returned in generateClientKeyExchange. - ckx *clientKeyExchangeMsg - preMasterSecret []byte -} - -func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { - var curveID CurveID - for _, c := range clientHello.supportedCurves { - if config.supportsCurve(c) { - curveID = c - break - } - } - - if curveID == 0 { - return nil, errors.New("tls: no supported elliptic curves offered") - } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - return nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - - params, err := generateECDHEParameters(config.rand(), curveID) - if err != nil { - return nil, err - } - ka.params = params - - // See RFC 4492, Section 5.4. - ecdhePublic := params.PublicKey() - serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) - serverECDHEParams[0] = 3 // named curve - serverECDHEParams[1] = byte(curveID >> 8) - serverECDHEParams[2] = byte(curveID) - serverECDHEParams[3] = byte(len(ecdhePublic)) - copy(serverECDHEParams[4:], ecdhePublic) - - priv, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) - } - - var signatureAlgorithm SignatureScheme - var sigType uint8 - var sigHash crypto.Hash - if ka.version >= VersionTLS12 { - signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) - if err != nil { - return nil, err - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return nil, err - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) - if err != nil { - return nil, err - } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") - } - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) - - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := priv.Sign(config.rand(), signed, signOpts) - if err != nil { - return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) - } - - skx := new(serverKeyExchangeMsg) - sigAndHashLen := 0 - if ka.version >= VersionTLS12 { - sigAndHashLen = 2 - } - skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) 
- copy(skx.key, serverECDHEParams) - k := skx.key[len(serverECDHEParams):] - if ka.version >= VersionTLS12 { - k[0] = byte(signatureAlgorithm >> 8) - k[1] = byte(signatureAlgorithm) - k = k[2:] - } - k[0] = byte(len(sig) >> 8) - k[1] = byte(len(sig)) - copy(k[2:], sig) - - return skx, nil -} - -func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { - if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { - return nil, errClientKeyExchange - } - - preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) - if preMasterSecret == nil { - return nil, errClientKeyExchange - } - - return preMasterSecret, nil -} - -func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { - if len(skx.key) < 4 { - return errServerKeyExchange - } - if skx.key[0] != 3 { // named curve - return errors.New("tls: server selected unsupported curve") - } - curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) - - publicLen := int(skx.key[3]) - if publicLen+4 > len(skx.key) { - return errServerKeyExchange - } - serverECDHEParams := skx.key[:4+publicLen] - publicKey := serverECDHEParams[4:] - - sig := skx.key[4+publicLen:] - if len(sig) < 2 { - return errServerKeyExchange - } - - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - return errors.New("tls: server selected unsupported curve") - } - - params, err := generateECDHEParameters(config.rand(), curveID) - if err != nil { - return err - } - ka.params = params - - ka.preMasterSecret = params.SharedKey(publicKey) - if ka.preMasterSecret == nil { - return errServerKeyExchange - } - - ourPublicKey := params.PublicKey() - ka.ckx = new(clientKeyExchangeMsg) - ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) - ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) - copy(ka.ckx.ciphertext[1:], ourPublicKey) - - var sigType uint8 - var sigHash crypto.Hash - if ka.version >= VersionTLS12 { - signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) - sig = sig[2:] - if len(sig) < 2 { - return errServerKeyExchange - } - - if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { - return errors.New("tls: certificate used with invalid signature algorithm") - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return err - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) - if err != nil { - return err - } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return errServerKeyExchange - } - - sigLen := int(sig[0])<<8 | int(sig[1]) - if sigLen+2 != len(sig) { - return errServerKeyExchange - } - sig = sig[2:] - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) - if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { - return errors.New("tls: invalid signature by the server certificate: " + err.Error()) - } - return nil -} - -func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { - if ka.ckx == nil { - return nil, nil, errors.New("tls: missing ServerKeyExchange message") - } - - return ka.preMasterSecret, ka.ckx, nil -} diff 
--git a/vendor/github.com/marten-seemann/qtls-go1-16/key_schedule.go b/vendor/github.com/marten-seemann/qtls-go1-16/key_schedule.go deleted file mode 100644 index da13904a6e8..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/key_schedule.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto/elliptic" - "crypto/hmac" - "errors" - "hash" - "io" - "math/big" - - "golang.org/x/crypto/cryptobyte" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/hkdf" -) - -// This file contains the functions necessary to compute the TLS 1.3 key -// schedule. See RFC 8446, Section 7. - -const ( - resumptionBinderLabel = "res binder" - clientHandshakeTrafficLabel = "c hs traffic" - serverHandshakeTrafficLabel = "s hs traffic" - clientApplicationTrafficLabel = "c ap traffic" - serverApplicationTrafficLabel = "s ap traffic" - exporterLabel = "exp master" - resumptionLabel = "res master" - trafficUpdateLabel = "traffic upd" -) - -// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { - var hkdfLabel cryptobyte.Builder - hkdfLabel.AddUint16(uint16(length)) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte("tls13 ")) - b.AddBytes([]byte(label)) - }) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(context) - }) - out := make([]byte, length) - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) - if err != nil || n != length { - panic("tls: HKDF-Expand-Label invocation failed unexpectedly") - } - return out -} - -// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { - if transcript == nil { - transcript = c.hash.New() - } - return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) -} - -// extract implements HKDF-Extract with the cipher suite hash. -func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { - if newSecret == nil { - newSecret = make([]byte, c.hash.Size()) - } - return hkdf.Extract(c.hash.New, newSecret, currentSecret) -} - -// nextTrafficSecret generates the next traffic secret, given the current one, -// according to RFC 8446, Section 7.2. -func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { - return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) -} - -// trafficKey generates traffic keys according to RFC 8446, Section 7.3. -func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { - key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) - iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) - return -} - -// finishedHash generates the Finished verify_data or PskBinderEntry according -// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey -// selection. -func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { - finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) - verifyData := hmac.New(c.hash.New, finishedKey) - verifyData.Write(transcript.Sum(nil)) - return verifyData.Sum(nil) -} - -// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to -// RFC 8446, Section 7.5. 
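// Editor's aside (illustrative sketch, not part of this patch): the key
// schedule above is built from HKDF-Extract and HKDF-Expand-Label (RFC 8446,
// Section 7.1). The HkdfLabel info blob is a uint16 output length, an 8-bit
// length-prefixed label carrying the "tls13 " prefix, and an 8-bit
// length-prefixed context. A self-contained sketch of one schedule step with
// golang.org/x/crypto/hkdf and SHA-256; the shared secret is a placeholder:
package main

import (
	"crypto/sha256"
	"fmt"
	"io"

	"golang.org/x/crypto/hkdf"
)

func hkdfExpandLabel(secret []byte, label string, context []byte, length int) []byte {
	var info []byte
	info = append(info, byte(length>>8), byte(length))
	info = append(info, byte(len("tls13 ")+len(label)))
	info = append(info, "tls13 "...)
	info = append(info, label...)
	info = append(info, byte(len(context)))
	info = append(info, context...)

	out := make([]byte, length)
	if _, err := io.ReadFull(hkdf.Expand(sha256.New, secret, info), out); err != nil {
		panic(err)
	}
	return out
}

func main() {
	// Early Secret = HKDF-Extract(salt=0, IKM=0) when no PSK is in play.
	zeros := make([]byte, sha256.Size)
	earlySecret := hkdf.Extract(sha256.New, zeros, nil)

	// Derive-Secret(Early Secret, "derived", "") salts the next Extract;
	// the context is the hash of an empty transcript.
	emptyHash := sha256.Sum256(nil)
	derived := hkdfExpandLabel(earlySecret, "derived", emptyHash[:], sha256.Size)

	// Handshake Secret = HKDF-Extract(salt=derived, IKM=(EC)DHE shared secret).
	sharedSecret := make([]byte, 32) // placeholder for the ECDHE output
	handshakeSecret := hkdf.Extract(sha256.New, sharedSecret, derived)
	fmt.Printf("handshake secret: %x\n", handshakeSecret)
}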
-func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { - expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) - return func(label string, context []byte, length int) ([]byte, error) { - secret := c.deriveSecret(expMasterSecret, label, nil) - h := c.hash.New() - h.Write(context) - return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil - } -} - -// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, -// according to RFC 8446, Section 4.2.8.2. -type ecdheParameters interface { - CurveID() CurveID - PublicKey() []byte - SharedKey(peerPublicKey []byte) []byte -} - -func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { - if curveID == X25519 { - privateKey := make([]byte, curve25519.ScalarSize) - if _, err := io.ReadFull(rand, privateKey); err != nil { - return nil, err - } - publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) - if err != nil { - return nil, err - } - return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil - } - - curve, ok := curveForCurveID(curveID) - if !ok { - return nil, errors.New("tls: internal error: unsupported curve") - } - - p := &nistParameters{curveID: curveID} - var err error - p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) - if err != nil { - return nil, err - } - return p, nil -} - -func curveForCurveID(id CurveID) (elliptic.Curve, bool) { - switch id { - case CurveP256: - return elliptic.P256(), true - case CurveP384: - return elliptic.P384(), true - case CurveP521: - return elliptic.P521(), true - default: - return nil, false - } -} - -type nistParameters struct { - privateKey []byte - x, y *big.Int // public key - curveID CurveID -} - -func (p *nistParameters) CurveID() CurveID { - return p.curveID -} - -func (p *nistParameters) PublicKey() []byte { - curve, _ := curveForCurveID(p.curveID) - return elliptic.Marshal(curve, p.x, p.y) -} - -func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { - curve, _ := curveForCurveID(p.curveID) - // Unmarshal also checks whether the given point is on the curve. - x, y := elliptic.Unmarshal(curve, peerPublicKey) - if x == nil { - return nil - } - - xShared, _ := curve.ScalarMult(x, y, p.privateKey) - sharedKey := make([]byte, (curve.Params().BitSize+7)/8) - return xShared.FillBytes(sharedKey) -} - -type x25519Parameters struct { - privateKey []byte - publicKey []byte -} - -func (p *x25519Parameters) CurveID() CurveID { - return X25519 -} - -func (p *x25519Parameters) PublicKey() []byte { - return p.publicKey[:] -} - -func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { - sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) - if err != nil { - return nil - } - return sharedKey -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/ticket.go b/vendor/github.com/marten-seemann/qtls-go1-16/ticket.go deleted file mode 100644 index 006b8c1de9f..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/ticket.go +++ /dev/null @@ -1,259 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package qtls - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/sha256" - "crypto/subtle" - "errors" - "io" - "time" - - "golang.org/x/crypto/cryptobyte" -) - -// sessionState contains the information that is serialized into a session -// ticket in order to later resume a connection. -type sessionState struct { - vers uint16 - cipherSuite uint16 - createdAt uint64 - masterSecret []byte // opaque master_secret<1..2^16-1>; - // struct { opaque certificate<1..2^24-1> } Certificate; - certificates [][]byte // Certificate certificate_list<0..2^24-1>; - - // usedOldKey is true if the ticket from which this session came from - // was encrypted with an older key and thus should be refreshed. - usedOldKey bool -} - -func (m *sessionState) marshal() []byte { - var b cryptobyte.Builder - b.AddUint16(m.vers) - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.masterSecret) - }) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - for _, cert := range m.certificates { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(cert) - }) - } - }) - return b.BytesOrPanic() -} - -func (m *sessionState) unmarshal(data []byte) bool { - *m = sessionState{usedOldKey: m.usedOldKey} - s := cryptobyte.String(data) - if ok := s.ReadUint16(&m.vers) && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint16LengthPrefixed(&s, &m.masterSecret) && - len(m.masterSecret) != 0; !ok { - return false - } - var certList cryptobyte.String - if !s.ReadUint24LengthPrefixed(&certList) { - return false - } - for !certList.Empty() { - var cert []byte - if !readUint24LengthPrefixed(&certList, &cert) { - return false - } - m.certificates = append(m.certificates, cert) - } - return s.Empty() -} - -// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first -// version (revision = 0) doesn't carry any of the information needed for 0-RTT -// validation and the nonce is always empty. -// version (revision = 1) carries the max_early_data_size sent in the ticket. -// version (revision = 2) carries the ALPN sent in the ticket. 
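// Editor's aside (illustrative sketch, not part of this patch): the session
// ticket state above is serialized with golang.org/x/crypto/cryptobyte, which
// pairs a Builder (writing length-prefixed fields) with a String (reading
// them back). A minimal round trip in the same style, with made-up fields:
package main

import (
	"fmt"

	"golang.org/x/crypto/cryptobyte"
)

func main() {
	var b cryptobyte.Builder
	b.AddUint16(0x0304) // a TLS-1.3-style version field
	b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddBytes([]byte("resumption secret"))
	})
	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddBytes([]byte("application data"))
	})
	encoded := b.BytesOrPanic()

	var (
		s       = cryptobyte.String(encoded)
		version uint16
		secret  cryptobyte.String
		appData cryptobyte.String
	)
	ok := s.ReadUint16(&version) &&
		s.ReadUint8LengthPrefixed(&secret) &&
		s.ReadUint16LengthPrefixed(&appData) &&
		s.Empty()
	fmt.Println(ok, version == 0x0304, string(secret), string(appData))
}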
-type sessionStateTLS13 struct { - // uint8 version = 0x0304; - // uint8 revision = 2; - cipherSuite uint16 - createdAt uint64 - resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; - certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; - maxEarlyData uint32 - alpn string - - appData []byte -} - -func (m *sessionStateTLS13) marshal() []byte { - var b cryptobyte.Builder - b.AddUint16(VersionTLS13) - b.AddUint8(2) // revision - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.resumptionSecret) - }) - marshalCertificate(&b, m.certificate) - b.AddUint32(m.maxEarlyData) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpn)) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.appData) - }) - return b.BytesOrPanic() -} - -func (m *sessionStateTLS13) unmarshal(data []byte) bool { - *m = sessionStateTLS13{} - s := cryptobyte.String(data) - var version uint16 - var revision uint8 - var alpn []byte - ret := s.ReadUint16(&version) && - version == VersionTLS13 && - s.ReadUint8(&revision) && - revision == 2 && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint8LengthPrefixed(&s, &m.resumptionSecret) && - len(m.resumptionSecret) != 0 && - unmarshalCertificate(&s, &m.certificate) && - s.ReadUint32(&m.maxEarlyData) && - readUint8LengthPrefixed(&s, &alpn) && - readUint16LengthPrefixed(&s, &m.appData) && - s.Empty() - m.alpn = string(alpn) - return ret -} - -func (c *Conn) encryptTicket(state []byte) ([]byte, error) { - if len(c.ticketKeys) == 0 { - return nil, errors.New("tls: internal error: session ticket keys unavailable") - } - - encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) - keyName := encrypted[:ticketKeyNameLen] - iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] - macBytes := encrypted[len(encrypted)-sha256.Size:] - - if _, err := io.ReadFull(c.config.rand(), iv); err != nil { - return nil, err - } - key := c.ticketKeys[0] - copy(keyName, key.keyName[:]) - block, err := aes.NewCipher(key.aesKey[:]) - if err != nil { - return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) - } - cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) - - mac := hmac.New(sha256.New, key.hmacKey[:]) - mac.Write(encrypted[:len(encrypted)-sha256.Size]) - mac.Sum(macBytes[:0]) - - return encrypted, nil -} - -func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { - if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { - return nil, false - } - - keyName := encrypted[:ticketKeyNameLen] - iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] - macBytes := encrypted[len(encrypted)-sha256.Size:] - ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] - - keyIndex := -1 - for i, candidateKey := range c.ticketKeys { - if bytes.Equal(keyName, candidateKey.keyName[:]) { - keyIndex = i - break - } - } - if keyIndex == -1 { - return nil, false - } - key := &c.ticketKeys[keyIndex] - - mac := hmac.New(sha256.New, key.hmacKey[:]) - mac.Write(encrypted[:len(encrypted)-sha256.Size]) - expected := mac.Sum(nil) - - if subtle.ConstantTimeCompare(macBytes, expected) != 1 { - return nil, false - } - - block, err := aes.NewCipher(key.aesKey[:]) - if err != nil { - return nil, false - } - plaintext = make([]byte, len(ciphertext)) - 
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) - - return plaintext, keyIndex > 0 -} - -func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) { - m := new(newSessionTicketMsgTLS13) - - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := sessionStateTLS13{ - cipherSuite: c.cipherSuite, - createdAt: uint64(c.config.time().Unix()), - resumptionSecret: c.resumptionSecret, - certificate: Certificate{ - Certificate: certsFromClient, - OCSPStaple: c.ocspResponse, - SignedCertificateTimestamps: c.scts, - }, - appData: appData, - alpn: c.clientProtocol, - } - if c.extraConfig != nil { - state.maxEarlyData = c.extraConfig.MaxEarlyData - } - var err error - m.label, err = c.encryptTicket(state.marshal()) - if err != nil { - return nil, err - } - m.lifetime = uint32(maxSessionTicketLifetime / time.Second) - if c.extraConfig != nil { - m.maxEarlyData = c.extraConfig.MaxEarlyData - } - return m, nil -} - -// GetSessionTicket generates a new session ticket. -// It should only be called after the handshake completes. -// It can only be used for servers, and only if the alternative record layer is set. -// The ticket may be nil if config.SessionTicketsDisabled is set, -// or if the client isn't able to receive session tickets. -func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) { - if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { - return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.") - } - if c.config.SessionTicketsDisabled { - return nil, nil - } - - m, err := c.getSessionTicketMsg(appData) - if err != nil { - return nil, err - } - return m.marshal(), nil -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/tls.go b/vendor/github.com/marten-seemann/qtls-go1-16/tls.go deleted file mode 100644 index 10cbae03b4f..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-16/tls.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// package qtls partially implements TLS 1.2, as specified in RFC 5246, -// and TLS 1.3, as specified in RFC 8446. -package qtls - -// BUG(agl): The crypto/tls package only implements some countermeasures -// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 -// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and -// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. - -import ( - "bytes" - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" - "net" - "os" - "strings" - "time" -) - -// Server returns a new TLS server side connection -// using conn as the underlying transport. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - c := &Conn{ - conn: conn, - config: fromConfig(config), - extraConfig: extraConfig, - } - c.handshakeFn = c.serverHandshake - return c -} - -// Client returns a new TLS client side connection -// using conn as the underlying transport. -// The config cannot be nil: users must set either ServerName or -// InsecureSkipVerify in the config. 
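// Editor's aside (illustrative sketch, not part of this patch): the
// encryptTicket/decryptTicket pair above is an encrypt-then-MAC construction:
// a random IV, AES-CTR encryption of the serialized state, and an HMAC-SHA256
// over everything preceding the MAC, checked in constant time before
// decryption. A self-contained seal/open sketch that omits the key-name
// lookup used for key rotation:
package main

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
)

func seal(aesKey, macKey, state []byte) ([]byte, error) {
	out := make([]byte, aes.BlockSize+len(state)+sha256.Size)
	iv := out[:aes.BlockSize]
	ciphertext := out[aes.BlockSize : len(out)-sha256.Size]
	macBytes := out[len(out)-sha256.Size:]

	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		return nil, err
	}
	block, err := aes.NewCipher(aesKey)
	if err != nil {
		return nil, err
	}
	cipher.NewCTR(block, iv).XORKeyStream(ciphertext, state)

	mac := hmac.New(sha256.New, macKey)
	mac.Write(out[:len(out)-sha256.Size])
	mac.Sum(macBytes[:0])
	return out, nil
}

func open(aesKey, macKey, sealed []byte) ([]byte, error) {
	if len(sealed) < aes.BlockSize+sha256.Size {
		return nil, errors.New("sealed ticket too short")
	}
	macBytes := sealed[len(sealed)-sha256.Size:]
	mac := hmac.New(sha256.New, macKey)
	mac.Write(sealed[:len(sealed)-sha256.Size])
	if subtle.ConstantTimeCompare(macBytes, mac.Sum(nil)) != 1 {
		return nil, errors.New("ticket MAC mismatch")
	}

	block, err := aes.NewCipher(aesKey)
	if err != nil {
		return nil, err
	}
	plaintext := make([]byte, len(sealed)-aes.BlockSize-sha256.Size)
	cipher.NewCTR(block, sealed[:aes.BlockSize]).
		XORKeyStream(plaintext, sealed[aes.BlockSize:len(sealed)-sha256.Size])
	return plaintext, nil
}

func main() {
	aesKey := make([]byte, 16)
	macKey := make([]byte, 32)
	if _, err := rand.Read(aesKey); err != nil {
		panic(err)
	}
	if _, err := rand.Read(macKey); err != nil {
		panic(err)
	}

	sealed, err := seal(aesKey, macKey, []byte("serialized session state"))
	if err != nil {
		panic(err)
	}
	plaintext, err := open(aesKey, macKey, sealed)
	fmt.Println(string(plaintext), err)
}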
-func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - c := &Conn{ - conn: conn, - config: fromConfig(config), - extraConfig: extraConfig, - isClient: true, - } - c.handshakeFn = c.clientHandshake - return c -} - -// A listener implements a network listener (net.Listener) for TLS connections. -type listener struct { - net.Listener - config *Config - extraConfig *ExtraConfig -} - -// Accept waits for and returns the next incoming TLS connection. -// The returned connection is of type *Conn. -func (l *listener) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - return Server(c, l.config, l.extraConfig), nil -} - -// NewListener creates a Listener which accepts connections from an inner -// Listener and wraps each connection with Server. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener { - l := new(listener) - l.Listener = inner - l.config = config - l.extraConfig = extraConfig - return l -} - -// Listen creates a TLS listener accepting connections on the -// given network address using net.Listen. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) { - if config == nil || len(config.Certificates) == 0 && - config.GetCertificate == nil && config.GetConfigForClient == nil { - return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") - } - l, err := net.Listen(network, laddr) - if err != nil { - return nil, err - } - return NewListener(l, config, extraConfig), nil -} - -type timeoutError struct{} - -func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } -func (timeoutError) Timeout() bool { return true } -func (timeoutError) Temporary() bool { return true } - -// DialWithDialer connects to the given network address using dialer.Dial and -// then initiates a TLS handshake, returning the resulting TLS connection. Any -// timeout or deadline given in the dialer apply to connection and TLS -// handshake as a whole. -// -// DialWithDialer interprets a nil configuration as equivalent to the zero -// configuration; see the documentation of Config for the defaults. -func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { - return dial(context.Background(), dialer, network, addr, config, extraConfig) -} - -func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { - // We want the Timeout and Deadline values from dialer to cover the - // whole process: TCP connection and TLS handshake. This means that we - // also need to start our own timers now. - timeout := netDialer.Timeout - - if !netDialer.Deadline.IsZero() { - deadlineTimeout := time.Until(netDialer.Deadline) - if timeout == 0 || deadlineTimeout < timeout { - timeout = deadlineTimeout - } - } - - // hsErrCh is non-nil if we might not wait for Handshake to complete. 
- var hsErrCh chan error - if timeout != 0 || ctx.Done() != nil { - hsErrCh = make(chan error, 2) - } - if timeout != 0 { - timer := time.AfterFunc(timeout, func() { - hsErrCh <- timeoutError{} - }) - defer timer.Stop() - } - - rawConn, err := netDialer.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - - colonPos := strings.LastIndex(addr, ":") - if colonPos == -1 { - colonPos = len(addr) - } - hostname := addr[:colonPos] - - if config == nil { - config = defaultConfig() - } - // If no ServerName is set, infer the ServerName - // from the hostname we're connecting to. - if config.ServerName == "" { - // Make a copy to avoid polluting argument or default. - c := config.Clone() - c.ServerName = hostname - config = c - } - - conn := Client(rawConn, config, extraConfig) - - if hsErrCh == nil { - err = conn.Handshake() - } else { - go func() { - hsErrCh <- conn.Handshake() - }() - - select { - case <-ctx.Done(): - err = ctx.Err() - case err = <-hsErrCh: - if err != nil { - // If the error was due to the context - // closing, prefer the context's error, rather - // than some random network teardown error. - if e := ctx.Err(); e != nil { - err = e - } - } - } - } - - if err != nil { - rawConn.Close() - return nil, err - } - - return conn, nil -} - -// Dial connects to the given network address using net.Dial -// and then initiates a TLS handshake, returning the resulting -// TLS connection. -// Dial interprets a nil configuration as equivalent to -// the zero configuration; see the documentation of Config -// for the defaults. -func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { - return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig) -} - -// Dialer dials TLS connections given a configuration and a Dialer for the -// underlying connection. -type Dialer struct { - // NetDialer is the optional dialer to use for the TLS connections' - // underlying TCP connections. - // A nil NetDialer is equivalent to the net.Dialer zero value. - NetDialer *net.Dialer - - // Config is the TLS configuration to use for new connections. - // A nil configuration is equivalent to the zero - // configuration; see the documentation of Config for the - // defaults. - Config *Config - - ExtraConfig *ExtraConfig -} - -// Dial connects to the given network address and initiates a TLS -// handshake, returning the resulting TLS connection. -// -// The returned Conn, if any, will always be of type *Conn. -func (d *Dialer) Dial(network, addr string) (net.Conn, error) { - return d.DialContext(context.Background(), network, addr) -} - -func (d *Dialer) netDialer() *net.Dialer { - if d.NetDialer != nil { - return d.NetDialer - } - return new(net.Dialer) -} - -// DialContext connects to the given network address and initiates a TLS -// handshake, returning the resulting TLS connection. -// -// The provided Context must be non-nil. If the context expires before -// the connection is complete, an error is returned. Once successfully -// connected, any expiration of the context will not affect the -// connection. -// -// The returned Conn, if any, will always be of type *Conn. -func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig) - if err != nil { - // Don't return c (a typed nil) in an interface. - return nil, err - } - return c, nil -} - -// LoadX509KeyPair reads and parses a public/private key pair from a pair -// of files. 
The files must contain PEM encoded data. The certificate file -// may contain intermediate certificates following the leaf certificate to -// form a certificate chain. On successful return, Certificate.Leaf will -// be nil because the parsed form of the certificate is not retained. -func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { - certPEMBlock, err := os.ReadFile(certFile) - if err != nil { - return Certificate{}, err - } - keyPEMBlock, err := os.ReadFile(keyFile) - if err != nil { - return Certificate{}, err - } - return X509KeyPair(certPEMBlock, keyPEMBlock) -} - -// X509KeyPair parses a public/private key pair from a pair of -// PEM encoded data. On successful return, Certificate.Leaf will be nil because -// the parsed form of the certificate is not retained. -func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { - fail := func(err error) (Certificate, error) { return Certificate{}, err } - - var cert Certificate - var skippedBlockTypes []string - for { - var certDERBlock *pem.Block - certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) - if certDERBlock == nil { - break - } - if certDERBlock.Type == "CERTIFICATE" { - cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) - } else { - skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) - } - } - - if len(cert.Certificate) == 0 { - if len(skippedBlockTypes) == 0 { - return fail(errors.New("tls: failed to find any PEM data in certificate input")) - } - if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { - return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) - } - return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) - } - - skippedBlockTypes = skippedBlockTypes[:0] - var keyDERBlock *pem.Block - for { - keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) - if keyDERBlock == nil { - if len(skippedBlockTypes) == 0 { - return fail(errors.New("tls: failed to find any PEM data in key input")) - } - if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { - return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) - } - return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) - } - if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { - break - } - skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) - } - - // We don't need to parse the public key for TLS, but we so do anyway - // to check that it looks sane and matches the private key. 
- x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - return fail(err) - } - - cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) - if err != nil { - return fail(err) - } - - switch pub := x509Cert.PublicKey.(type) { - case *rsa.PublicKey: - priv, ok := cert.PrivateKey.(*rsa.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if pub.N.Cmp(priv.N) != 0 { - return fail(errors.New("tls: private key does not match public key")) - } - case *ecdsa.PublicKey: - priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { - return fail(errors.New("tls: private key does not match public key")) - } - case ed25519.PublicKey: - priv, ok := cert.PrivateKey.(ed25519.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { - return fail(errors.New("tls: private key does not match public key")) - } - default: - return fail(errors.New("tls: unknown public key algorithm")) - } - - return cert, nil -} - -// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates -// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. -// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. -func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { - if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { - return key, nil - } - if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { - switch key := key.(type) { - case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: - return key, nil - default: - return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") - } - } - if key, err := x509.ParseECPrivateKey(der); err == nil { - return key, nil - } - - return nil, errors.New("tls: failed to parse private key") -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/README.md b/vendor/github.com/marten-seemann/qtls-go1-17/README.md deleted file mode 100644 index 3e902212721..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# qtls - -[![Go Reference](https://pkg.go.dev/badge/github.com/marten-seemann/qtls-go1-17.svg)](https://pkg.go.dev/github.com/marten-seemann/qtls-go1-17) -[![.github/workflows/go-test.yml](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml/badge.svg)](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml) - -This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go). diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/auth.go b/vendor/github.com/marten-seemann/qtls-go1-17/auth.go deleted file mode 100644 index 1ef675fd37c..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/auth.go +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
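[Editorial note: the deleted tls.go above mirrors the standard library's crypto/tls API (Dialer, DialContext, LoadX509KeyPair, X509KeyPair), with the fork adding an *ExtraConfig parameter. A minimal sketch of typical usage, written against the standard-library equivalents and omitting ExtraConfig; the file names and "example.com" address are placeholders.]

package main

import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"time"
)

func main() {
	// Load a PEM-encoded certificate chain and private key, as
	// LoadX509KeyPair/X509KeyPair above do.
	cert, err := tls.LoadX509KeyPair("client.crt", "client.key") // placeholder paths
	if err != nil {
		panic(err)
	}

	// Dial with the handshake bounded by both a net.Dialer timeout and a
	// context, mirroring Dialer.DialContext above.
	dialer := &tls.Dialer{
		NetDialer: &net.Dialer{Timeout: 5 * time.Second},
		Config: &tls.Config{
			Certificates: []tls.Certificate{cert},
			ServerName:   "example.com", // optional; inferred from addr when empty
		},
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	conn, err := dialer.DialContext(ctx, "tcp", "example.com:443")
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// DialContext always returns a *tls.Conn on success.
	state := conn.(*tls.Conn).ConnectionState()
	fmt.Println("negotiated suite:", tls.CipherSuiteName(state.CipherSuite))
}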
- -package qtls - -import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rsa" - "errors" - "fmt" - "hash" - "io" -) - -// verifyHandshakeSignature verifies a signature against pre-hashed -// (if required) handshake contents. -func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { - switch sigType { - case signatureECDSA: - pubKey, ok := pubkey.(*ecdsa.PublicKey) - if !ok { - return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) - } - if !ecdsa.VerifyASN1(pubKey, signed, sig) { - return errors.New("ECDSA verification failure") - } - case signatureEd25519: - pubKey, ok := pubkey.(ed25519.PublicKey) - if !ok { - return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey) - } - if !ed25519.Verify(pubKey, signed, sig) { - return errors.New("Ed25519 verification failure") - } - case signaturePKCS1v15: - pubKey, ok := pubkey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("expected an RSA public key, got %T", pubkey) - } - if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil { - return err - } - case signatureRSAPSS: - pubKey, ok := pubkey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("expected an RSA public key, got %T", pubkey) - } - signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} - if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil { - return err - } - default: - return errors.New("internal error: unknown signature type") - } - return nil -} - -const ( - serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" - clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" -) - -var signaturePadding = []byte{ - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, - 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, -} - -// signedMessage returns the pre-hashed (if necessary) message to be signed by -// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. -func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { - if sigHash == directSigning { - b := &bytes.Buffer{} - b.Write(signaturePadding) - io.WriteString(b, context) - b.Write(transcript.Sum(nil)) - return b.Bytes() - } - h := sigHash.New() - h.Write(signaturePadding) - io.WriteString(h, context) - h.Write(transcript.Sum(nil)) - return h.Sum(nil) -} - -// typeAndHashFromSignatureScheme returns the corresponding signature type and -// crypto.Hash for a given TLS SignatureScheme. 
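[Editorial note: a sketch of the TLS 1.3 CertificateVerify input that signedMessage above constructs (RFC 8446, Section 4.4.3): 64 bytes of 0x20 padding, a context string, then the transcript hash. The transcript value here is a stand-in; in a real handshake it is the running hash of the handshake messages.]

package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
)

func main() {
	transcript := sha256.Sum256([]byte("handshake messages so far")) // placeholder transcript

	var b bytes.Buffer
	b.Write(bytes.Repeat([]byte{0x20}, 64))                // signaturePadding above
	b.WriteString("TLS 1.3, server CertificateVerify\x00") // serverSignatureContext above
	b.Write(transcript[:])

	// For sigHash != directSigning, signedMessage hashes these bytes once more
	// with the negotiated hash before signing; Ed25519 signs them directly.
	fmt.Printf("%d bytes to be signed\n", b.Len())
}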
-func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) { - switch signatureAlgorithm { - case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512: - sigType = signaturePKCS1v15 - case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512: - sigType = signatureRSAPSS - case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512: - sigType = signatureECDSA - case Ed25519: - sigType = signatureEd25519 - default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) - } - switch signatureAlgorithm { - case PKCS1WithSHA1, ECDSAWithSHA1: - hash = crypto.SHA1 - case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256: - hash = crypto.SHA256 - case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384: - hash = crypto.SHA384 - case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512: - hash = crypto.SHA512 - case Ed25519: - hash = directSigning - default: - return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm) - } - return sigType, hash, nil -} - -// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for -// a given public key used with TLS 1.0 and 1.1, before the introduction of -// signature algorithm negotiation. -func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) { - switch pub.(type) { - case *rsa.PublicKey: - return signaturePKCS1v15, crypto.MD5SHA1, nil - case *ecdsa.PublicKey: - return signatureECDSA, crypto.SHA1, nil - case ed25519.PublicKey: - // RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1, - // but it requires holding on to a handshake transcript to do a - // full signature, and not even OpenSSL bothers with the - // complexity, so we can't even test it properly. - return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2") - default: - return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub) - } -} - -var rsaSignatureSchemes = []struct { - scheme SignatureScheme - minModulusBytes int - maxVersion uint16 -}{ - // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires - // emLen >= hLen + sLen + 2 - {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, - // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires - // emLen >= len(prefix) + hLen + 11 - // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. - {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, - {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, - {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, - {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, -} - -// signatureSchemesForCertificate returns the list of supported SignatureSchemes -// for a given certificate, based on the public key and the protocol version, -// and optionally filtered by its explicit SupportedSignatureAlgorithms. -// -// This function must be kept in sync with supportedSignatureAlgorithms. 
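[Editorial note: worked numbers for the rsaSignatureSchemes table above, computed from the same formulas, to make the minimum-modulus constraints concrete. Output values only illustrate the arithmetic; they are not part of the deleted code.]

package main

import (
	"crypto"
	"fmt"
)

func main() {
	// RSA-PSS with salt length equal to the hash length needs
	// emLen >= hLen + sLen + 2 bytes.
	for _, h := range []crypto.Hash{crypto.SHA256, crypto.SHA384, crypto.SHA512} {
		min := h.Size()*2 + 2
		fmt.Printf("PSS with %v: modulus >= %d bytes (%d bits)\n", h, min, min*8)
	}
	// PKCS #1 v1.5 needs emLen >= len(prefix) + hLen + 11; the SHA-2
	// DigestInfo prefix used here is 19 bytes.
	fmt.Printf("PKCS1v15 with SHA-256: modulus >= %d bytes\n", 19+crypto.SHA256.Size()+11)
}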
-func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { - priv, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil - } - - var sigAlgs []SignatureScheme - switch pub := priv.Public().(type) { - case *ecdsa.PublicKey: - if version != VersionTLS13 { - // In TLS 1.2 and earlier, ECDSA algorithms are not - // constrained to a single curve. - sigAlgs = []SignatureScheme{ - ECDSAWithP256AndSHA256, - ECDSAWithP384AndSHA384, - ECDSAWithP521AndSHA512, - ECDSAWithSHA1, - } - break - } - switch pub.Curve { - case elliptic.P256(): - sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} - case elliptic.P384(): - sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} - case elliptic.P521(): - sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} - default: - return nil - } - case *rsa.PublicKey: - size := pub.Size() - sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) - for _, candidate := range rsaSignatureSchemes { - if size >= candidate.minModulusBytes && version <= candidate.maxVersion { - sigAlgs = append(sigAlgs, candidate.scheme) - } - } - case ed25519.PublicKey: - sigAlgs = []SignatureScheme{Ed25519} - default: - return nil - } - - if cert.SupportedSignatureAlgorithms != nil { - var filteredSigAlgs []SignatureScheme - for _, sigAlg := range sigAlgs { - if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) { - filteredSigAlgs = append(filteredSigAlgs, sigAlg) - } - } - return filteredSigAlgs - } - return sigAlgs -} - -// selectSignatureScheme picks a SignatureScheme from the peer's preference list -// that works with the selected certificate. It's only called for protocol -// versions that support signature algorithms, so TLS 1.2 and 1.3. -func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { - supportedAlgs := signatureSchemesForCertificate(vers, c) - if len(supportedAlgs) == 0 { - return 0, unsupportedCertificateError(c) - } - if len(peerAlgs) == 0 && vers == VersionTLS12 { - // For TLS 1.2, if the client didn't send signature_algorithms then we - // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1. - peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1} - } - // Pick signature scheme in the peer's preference order, as our - // preference order is not configurable. - for _, preferredAlg := range peerAlgs { - if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { - return preferredAlg, nil - } - } - return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms") -} - -// unsupportedCertificateError returns a helpful error for certificates with -// an unsupported private key. 
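[Editorial note: a sketch of the selection rule selectSignatureScheme above implements: walk the peer's preference list and take the first scheme the local certificate supports. The pickScheme helper is hypothetical, and the supported list is hard-coded to what signatureSchemesForCertificate would return for an ECDSA P-256 certificate at TLS 1.3.]

package main

import (
	"crypto/tls"
	"fmt"
)

// pickScheme returns the first peer-preferred scheme that is also supported locally.
func pickScheme(peer, supported []tls.SignatureScheme) (tls.SignatureScheme, bool) {
	for _, p := range peer {
		for _, s := range supported {
			if p == s {
				return p, true
			}
		}
	}
	return 0, false
}

func main() {
	supported := []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256} // P-256 cert, TLS 1.3
	peer := []tls.SignatureScheme{tls.Ed25519, tls.PSSWithSHA256, tls.ECDSAWithP256AndSHA256}
	if s, ok := pickScheme(peer, supported); ok {
		fmt.Println("selected:", s) // ECDSAWithP256AndSHA256
	}
}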
-func unsupportedCertificateError(cert *Certificate) error { - switch cert.PrivateKey.(type) { - case rsa.PrivateKey, ecdsa.PrivateKey: - return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T", - cert.PrivateKey, cert.PrivateKey) - case *ed25519.PrivateKey: - return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey") - } - - signer, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer", - cert.PrivateKey) - } - - switch pub := signer.Public().(type) { - case *ecdsa.PublicKey: - switch pub.Curve { - case elliptic.P256(): - case elliptic.P384(): - case elliptic.P521(): - default: - return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name) - } - case *rsa.PublicKey: - return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms") - case ed25519.PublicKey: - default: - return fmt.Errorf("tls: unsupported certificate key (%T)", pub) - } - - if cert.SupportedSignatureAlgorithms != nil { - return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms") - } - - return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey) -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/cipher_suites.go b/vendor/github.com/marten-seemann/qtls-go1-17/cipher_suites.go deleted file mode 100644 index 53a3956aaeb..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/cipher_suites.go +++ /dev/null @@ -1,691 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/aes" - "crypto/cipher" - "crypto/des" - "crypto/hmac" - "crypto/rc4" - "crypto/sha1" - "crypto/sha256" - "fmt" - "hash" - - "golang.org/x/crypto/chacha20poly1305" -) - -// CipherSuite is a TLS cipher suite. Note that most functions in this package -// accept and expose cipher suite IDs instead of this type. -type CipherSuite struct { - ID uint16 - Name string - - // Supported versions is the list of TLS protocol versions that can - // negotiate this cipher suite. - SupportedVersions []uint16 - - // Insecure is true if the cipher suite has known security issues - // due to its primitives, design, or implementation. - Insecure bool -} - -var ( - supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12} - supportedOnlyTLS12 = []uint16{VersionTLS12} - supportedOnlyTLS13 = []uint16{VersionTLS13} -) - -// CipherSuites returns a list of cipher suites currently implemented by this -// package, excluding those with security issues, which are returned by -// InsecureCipherSuites. -// -// The list is sorted by ID. Note that the default cipher suites selected by -// this package might depend on logic that can't be captured by a static list, -// and might not match those returned by this function. 
-func CipherSuites() []*CipherSuite { - return []*CipherSuite{ - {TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - - {TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false}, - {TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false}, - {TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false}, - - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false}, - {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, - {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false}, - } -} - -// InsecureCipherSuites returns a list of cipher suites currently implemented by -// this package and which have security issues. -// -// Most applications should not use the cipher suites in this list, and should -// only use those returned by CipherSuites. -func InsecureCipherSuites() []*CipherSuite { - // This list includes RC4, CBC_SHA256, and 3DES cipher suites. See - // cipherSuitesPreferenceOrder for details. - return []*CipherSuite{ - {TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, - {TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true}, - } -} - -// CipherSuiteName returns the standard name for the passed cipher suite ID -// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation -// of the ID value if the cipher suite is not implemented by this package. 
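[Editorial note: CipherSuites, InsecureCipherSuites, and CipherSuiteName above exist with the same shape as exported functions in crypto/tls; a quick sketch of listing the implemented suites and resolving an ID to a name.]

package main

import (
	"crypto/tls"
	"fmt"
)

func main() {
	for _, s := range tls.CipherSuites() {
		fmt.Printf("%#04x %s (versions: %v)\n", s.ID, s.Name, s.SupportedVersions)
	}
	fmt.Println(tls.CipherSuiteName(0x1301)) // TLS_AES_128_GCM_SHA256
	// Unknown or unimplemented IDs fall back to a hex representation.
	fmt.Println(tls.CipherSuiteName(0x00ff)) // 0x00FF
}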
-func CipherSuiteName(id uint16) string { - for _, c := range CipherSuites() { - if c.ID == id { - return c.Name - } - } - for _, c := range InsecureCipherSuites() { - if c.ID == id { - return c.Name - } - } - return fmt.Sprintf("0x%04X", id) -} - -const ( - // suiteECDHE indicates that the cipher suite involves elliptic curve - // Diffie-Hellman. This means that it should only be selected when the - // client indicates that it supports ECC with a curve and point format - // that we're happy with. - suiteECDHE = 1 << iota - // suiteECSign indicates that the cipher suite involves an ECDSA or - // EdDSA signature and therefore may only be selected when the server's - // certificate is ECDSA or EdDSA. If this is not set then the cipher suite - // is RSA based. - suiteECSign - // suiteTLS12 indicates that the cipher suite should only be advertised - // and accepted when using TLS 1.2. - suiteTLS12 - // suiteSHA384 indicates that the cipher suite uses SHA384 as the - // handshake hash. - suiteSHA384 -) - -// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange -// mechanism, as well as the cipher+MAC pair or the AEAD. -type cipherSuite struct { - id uint16 - // the lengths, in bytes, of the key material needed for each component. - keyLen int - macLen int - ivLen int - ka func(version uint16) keyAgreement - // flags is a bitmask of the suite* values, above. - flags int - cipher func(key, iv []byte, isRead bool) interface{} - mac func(key []byte) hash.Hash - aead func(key, fixedNonce []byte) aead -} - -var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter. - {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, - {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, - {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil}, - {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil}, - {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, - {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, - {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil}, - {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, - {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, 
nil}, - {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, - {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, - {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil}, - {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil}, -} - -// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which -// is also in supportedIDs and passes the ok filter. -func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite { - for _, id := range ids { - candidate := cipherSuiteByID(id) - if candidate == nil || !ok(candidate) { - continue - } - - for _, suppID := range supportedIDs { - if id == suppID { - return candidate - } - } - } - return nil -} - -// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash -// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. -type cipherSuiteTLS13 struct { - id uint16 - keyLen int - aead func(key, fixedNonce []byte) aead - hash crypto.Hash -} - -type CipherSuiteTLS13 struct { - ID uint16 - KeyLen int - Hash crypto.Hash - AEAD func(key, fixedNonce []byte) cipher.AEAD -} - -func (c *CipherSuiteTLS13) IVLen() int { - return aeadNonceLength -} - -var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map. - {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, - {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, - {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, -} - -// cipherSuitesPreferenceOrder is the order in which we'll select (on the -// server) or advertise (on the client) TLS 1.0–1.2 cipher suites. -// -// Cipher suites are filtered but not reordered based on the application and -// peer's preferences, meaning we'll never select a suite lower in this list if -// any higher one is available. This makes it more defensible to keep weaker -// cipher suites enabled, especially on the server side where we get the last -// word, since there are no known downgrade attacks on cipher suites selection. -// -// The list is sorted by applying the following priority rules, stopping at the -// first (most important) applicable one: -// -// - Anything else comes before RC4 -// -// RC4 has practically exploitable biases. See https://www.rc4nomore.com. -// -// - Anything else comes before CBC_SHA256 -// -// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13 -// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and -// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. -// -// - Anything else comes before 3DES -// -// 3DES has 64-bit blocks, which makes it fundamentally susceptible to -// birthday attacks. See https://sweet32.info. -// -// - ECDHE comes before anything else -// -// Once we got the broken stuff out of the way, the most important -// property a cipher suite can have is forward secrecy. We don't -// implement FFDHE, so that means ECDHE. -// -// - AEADs come before CBC ciphers -// -// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites -// are fundamentally fragile, and suffered from an endless sequence of -// padding oracle attacks. See https://eprint.iacr.org/2015/1129, -// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and -// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/. 
-// -// - AES comes before ChaCha20 -// -// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster -// than ChaCha20Poly1305. -// -// When AES hardware is not available, AES-128-GCM is one or more of: much -// slower, way more complex, and less safe (because not constant time) -// than ChaCha20Poly1305. -// -// We use this list if we think both peers have AES hardware, and -// cipherSuitesPreferenceOrderNoAES otherwise. -// -// - AES-128 comes before AES-256 -// -// The only potential advantages of AES-256 are better multi-target -// margins, and hypothetical post-quantum properties. Neither apply to -// TLS, and AES-256 is slower due to its four extra rounds (which don't -// contribute to the advantages above). -// -// - ECDSA comes before RSA -// -// The relative order of ECDSA and RSA cipher suites doesn't matter, -// as they depend on the certificate. Pick one to get a stable order. -// -var cipherSuitesPreferenceOrder = []uint16{ - // AEADs w/ ECDHE - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - - // CBC w/ ECDHE - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - - // AEADs w/o ECDHE - TLS_RSA_WITH_AES_128_GCM_SHA256, - TLS_RSA_WITH_AES_256_GCM_SHA384, - - // CBC w/o ECDHE - TLS_RSA_WITH_AES_128_CBC_SHA, - TLS_RSA_WITH_AES_256_CBC_SHA, - - // 3DES - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - TLS_RSA_WITH_3DES_EDE_CBC_SHA, - - // CBC_SHA256 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - TLS_RSA_WITH_AES_128_CBC_SHA256, - - // RC4 - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, - TLS_RSA_WITH_RC4_128_SHA, -} - -var cipherSuitesPreferenceOrderNoAES = []uint16{ - // ChaCha20Poly1305 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - - // AES-GCM w/ ECDHE - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - - // The rest of cipherSuitesPreferenceOrder. - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - TLS_RSA_WITH_AES_128_GCM_SHA256, - TLS_RSA_WITH_AES_256_GCM_SHA384, - TLS_RSA_WITH_AES_128_CBC_SHA, - TLS_RSA_WITH_AES_256_CBC_SHA, - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - TLS_RSA_WITH_3DES_EDE_CBC_SHA, - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - TLS_RSA_WITH_AES_128_CBC_SHA256, - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, - TLS_RSA_WITH_RC4_128_SHA, -} - -// disabledCipherSuites are not used unless explicitly listed in -// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder. 
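[Editorial note: the preference order above is fixed; applications can only filter the TLS 1.0-1.2 list via Config.CipherSuites, and TLS 1.3 suites are not configurable. A minimal sketch of a server config limited to ECDHE AEAD suites; loading a real certificate is left out.]

package main

import "crypto/tls"

func newServerConfig(cert tls.Certificate) *tls.Config {
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		MinVersion:   tls.VersionTLS12,
		// Filtered, not reordered: selection still follows the package's
		// internal preference order shown above.
		CipherSuites: []uint16{
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
			tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
		},
	}
}

func main() {
	cfg := newServerConfig(tls.Certificate{}) // real code would load a certificate
	_ = cfg
}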
-var disabledCipherSuites = []uint16{ - // CBC_SHA256 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, - TLS_RSA_WITH_AES_128_CBC_SHA256, - - // RC4 - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, - TLS_RSA_WITH_RC4_128_SHA, -} - -var ( - defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites) - defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen] -) - -// defaultCipherSuitesTLS13 is also the preference order, since there are no -// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as -// cipherSuitesPreferenceOrder applies. -var defaultCipherSuitesTLS13 = []uint16{ - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, - TLS_CHACHA20_POLY1305_SHA256, -} - -var defaultCipherSuitesTLS13NoAES = []uint16{ - TLS_CHACHA20_POLY1305_SHA256, - TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, -} - -var aesgcmCiphers = map[uint16]bool{ - // TLS 1.2 - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true, - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true, - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true, - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true, - // TLS 1.3 - TLS_AES_128_GCM_SHA256: true, - TLS_AES_256_GCM_SHA384: true, -} - -var nonAESGCMAEADCiphers = map[uint16]bool{ - // TLS 1.2 - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true, - // TLS 1.3 - TLS_CHACHA20_POLY1305_SHA256: true, -} - -// aesgcmPreferred returns whether the first known cipher in the preference list -// is an AES-GCM cipher, implying the peer has hardware support for it. -func aesgcmPreferred(ciphers []uint16) bool { - for _, cID := range ciphers { - if c := cipherSuiteByID(cID); c != nil { - return aesgcmCiphers[cID] - } - if c := cipherSuiteTLS13ByID(cID); c != nil { - return aesgcmCiphers[cID] - } - } - return false -} - -func cipherRC4(key, iv []byte, isRead bool) interface{} { - cipher, _ := rc4.NewCipher(key) - return cipher -} - -func cipher3DES(key, iv []byte, isRead bool) interface{} { - block, _ := des.NewTripleDESCipher(key) - if isRead { - return cipher.NewCBCDecrypter(block, iv) - } - return cipher.NewCBCEncrypter(block, iv) -} - -func cipherAES(key, iv []byte, isRead bool) interface{} { - block, _ := aes.NewCipher(key) - if isRead { - return cipher.NewCBCDecrypter(block, iv) - } - return cipher.NewCBCEncrypter(block, iv) -} - -// macSHA1 returns a SHA-1 based constant time MAC. -func macSHA1(key []byte) hash.Hash { - return hmac.New(newConstantTimeHash(sha1.New), key) -} - -// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and -// is currently only used in disabled-by-default cipher suites. -func macSHA256(key []byte) hash.Hash { - return hmac.New(sha256.New, key) -} - -type aead interface { - cipher.AEAD - - // explicitNonceLen returns the number of bytes of explicit nonce - // included in each record. This is eight for older AEADs and - // zero for modern ones. - explicitNonceLen() int -} - -const ( - aeadNonceLength = 12 - noncePrefixLength = 4 -) - -// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to -// each call. -type prefixNonceAEAD struct { - // nonce contains the fixed part of the nonce in the first four bytes. 
- nonce [aeadNonceLength]byte - aead cipher.AEAD -} - -func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } -func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } - -func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { - copy(f.nonce[4:], nonce) - return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) -} - -func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { - copy(f.nonce[4:], nonce) - return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) -} - -// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce -// before each call. -type xorNonceAEAD struct { - nonceMask [aeadNonceLength]byte - aead cipher.AEAD -} - -func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number -func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } - -func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - - return result -} - -func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) - for i, b := range nonce { - f.nonceMask[4+i] ^= b - } - - return result, err -} - -func aeadAESGCM(key, noncePrefix []byte) aead { - if len(noncePrefix) != noncePrefixLength { - panic("tls: internal error: wrong nonce length") - } - aes, err := aes.NewCipher(key) - if err != nil { - panic(err) - } - aead, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } - - ret := &prefixNonceAEAD{aead: aead} - copy(ret.nonce[:], noncePrefix) - return ret -} - -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return aeadAESGCMTLS13(key, fixedNonce) -} - -func aeadAESGCMTLS13(key, nonceMask []byte) aead { - if len(nonceMask) != aeadNonceLength { - panic("tls: internal error: wrong nonce length") - } - aes, err := aes.NewCipher(key) - if err != nil { - panic(err) - } - aead, err := cipher.NewGCM(aes) - if err != nil { - panic(err) - } - - ret := &xorNonceAEAD{aead: aead} - copy(ret.nonceMask[:], nonceMask) - return ret -} - -func aeadChaCha20Poly1305(key, nonceMask []byte) aead { - if len(nonceMask) != aeadNonceLength { - panic("tls: internal error: wrong nonce length") - } - aead, err := chacha20poly1305.New(key) - if err != nil { - panic(err) - } - - ret := &xorNonceAEAD{aead: aead} - copy(ret.nonceMask[:], nonceMask) - return ret -} - -type constantTimeHash interface { - hash.Hash - ConstantTimeSum(b []byte) []byte -} - -// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces -// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. 
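[Editorial note: a sketch of the per-record nonce construction that xorNonceAEAD above implements for TLS 1.3 (and the TLS 1.2 ChaCha20 suites): the 64-bit record sequence number, left-padded to 12 bytes, is XORed into the static write IV (RFC 8446, Section 5.3). Key, IV, and plaintext are dummy values, and the additional data is omitted; real TLS 1.3 records use the record header as AAD.]

package main

import (
	"crypto/aes"
	"crypto/cipher"
	"encoding/binary"
	"fmt"
)

func main() {
	key := make([]byte, 16) // dummy key
	iv := make([]byte, 12)  // dummy write IV (the "nonce mask")

	block, _ := aes.NewCipher(key)
	aead, _ := cipher.NewGCM(block)

	var seq uint64 = 3 // record sequence number
	nonce := make([]byte, 12)
	copy(nonce, iv)
	var seqBytes [8]byte
	binary.BigEndian.PutUint64(seqBytes[:], seq)
	for i, b := range seqBytes {
		nonce[4+i] ^= b // same offset the wrapper above uses
	}

	ct := aead.Seal(nil, nonce, []byte("record plaintext"), nil)
	fmt.Printf("sealed %d bytes with nonce %x\n", len(ct), nonce)
}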
-type cthWrapper struct { - h constantTimeHash -} - -func (c *cthWrapper) Size() int { return c.h.Size() } -func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } -func (c *cthWrapper) Reset() { c.h.Reset() } -func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } -func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } - -func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { - return func() hash.Hash { - return &cthWrapper{h().(constantTimeHash)} - } -} - -// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3. -func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte { - h.Reset() - h.Write(seq) - h.Write(header) - h.Write(data) - res := h.Sum(out) - if extra != nil { - h.Write(extra) - } - return res -} - -func rsaKA(version uint16) keyAgreement { - return rsaKeyAgreement{} -} - -func ecdheECDSAKA(version uint16) keyAgreement { - return &ecdheKeyAgreement{ - isRSA: false, - version: version, - } -} - -func ecdheRSAKA(version uint16) keyAgreement { - return &ecdheKeyAgreement{ - isRSA: true, - version: version, - } -} - -// mutualCipherSuite returns a cipherSuite given a list of supported -// ciphersuites and the id requested by the peer. -func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { - for _, id := range have { - if id == want { - return cipherSuiteByID(id) - } - } - return nil -} - -func cipherSuiteByID(id uint16) *cipherSuite { - for _, cipherSuite := range cipherSuites { - if cipherSuite.id == id { - return cipherSuite - } - } - return nil -} - -func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 { - for _, id := range have { - if id == want { - return cipherSuiteTLS13ByID(id) - } - } - return nil -} - -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 { - for _, cipherSuite := range cipherSuitesTLS13 { - if cipherSuite.id == id { - return cipherSuite - } - } - return nil -} - -// A list of cipher suite IDs that are, or have been, implemented by this -// package. -// -// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml -const ( - // TLS 1.0 - 1.2 cipher suites. - TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 - TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a - TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f - TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 - TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c - TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c - TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d - TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 - TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a - TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 - TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 - TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 - TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 - TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 - TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b - TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 - TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9 - - // TLS 1.3 cipher suites. 
- TLS_AES_128_GCM_SHA256 uint16 = 0x1301 - TLS_AES_256_GCM_SHA384 uint16 = 0x1302 - TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 - - // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator - // that the client is doing version fallback. See RFC 7507. - TLS_FALLBACK_SCSV uint16 = 0x5600 - - // Legacy names for the corresponding cipher suites with the correct _SHA256 - // suffix, retained for backward compatibility. - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 -) diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/common.go b/vendor/github.com/marten-seemann/qtls-go1-17/common.go deleted file mode 100644 index 8a9b6804867..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/common.go +++ /dev/null @@ -1,1485 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "container/list" - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/sha512" - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "io" - "net" - "strings" - "sync" - "time" -) - -const ( - VersionTLS10 = 0x0301 - VersionTLS11 = 0x0302 - VersionTLS12 = 0x0303 - VersionTLS13 = 0x0304 - - // Deprecated: SSLv3 is cryptographically broken, and is no longer - // supported by this package. See golang.org/issue/32716. - VersionSSL30 = 0x0300 -) - -const ( - maxPlaintext = 16384 // maximum plaintext payload length - maxCiphertext = 16384 + 2048 // maximum ciphertext payload length - maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 - recordHeaderLen = 5 // record header length - maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) - maxUselessRecords = 16 // maximum number of consecutive non-advancing records -) - -// TLS record types. -type recordType uint8 - -const ( - recordTypeChangeCipherSpec recordType = 20 - recordTypeAlert recordType = 21 - recordTypeHandshake recordType = 22 - recordTypeApplicationData recordType = 23 -) - -// TLS handshake message types. -const ( - typeHelloRequest uint8 = 0 - typeClientHello uint8 = 1 - typeServerHello uint8 = 2 - typeNewSessionTicket uint8 = 4 - typeEndOfEarlyData uint8 = 5 - typeEncryptedExtensions uint8 = 8 - typeCertificate uint8 = 11 - typeServerKeyExchange uint8 = 12 - typeCertificateRequest uint8 = 13 - typeServerHelloDone uint8 = 14 - typeCertificateVerify uint8 = 15 - typeClientKeyExchange uint8 = 16 - typeFinished uint8 = 20 - typeCertificateStatus uint8 = 22 - typeKeyUpdate uint8 = 24 - typeNextProtocol uint8 = 67 // Not IANA assigned - typeMessageHash uint8 = 254 // synthetic message -) - -// TLS compression types. 
-const ( - compressionNone uint8 = 0 -) - -type Extension struct { - Type uint16 - Data []byte -} - -// TLS extension numbers -const ( - extensionServerName uint16 = 0 - extensionStatusRequest uint16 = 5 - extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 - extensionSupportedPoints uint16 = 11 - extensionSignatureAlgorithms uint16 = 13 - extensionALPN uint16 = 16 - extensionSCT uint16 = 18 - extensionSessionTicket uint16 = 35 - extensionPreSharedKey uint16 = 41 - extensionEarlyData uint16 = 42 - extensionSupportedVersions uint16 = 43 - extensionCookie uint16 = 44 - extensionPSKModes uint16 = 45 - extensionCertificateAuthorities uint16 = 47 - extensionSignatureAlgorithmsCert uint16 = 50 - extensionKeyShare uint16 = 51 - extensionRenegotiationInfo uint16 = 0xff01 -) - -// TLS signaling cipher suite values -const ( - scsvRenegotiation uint16 = 0x00ff -) - -type EncryptionLevel uint8 - -const ( - EncryptionHandshake EncryptionLevel = iota - Encryption0RTT - EncryptionApplication -) - -// CurveID is a tls.CurveID -type CurveID = tls.CurveID - -const ( - CurveP256 CurveID = 23 - CurveP384 CurveID = 24 - CurveP521 CurveID = 25 - X25519 CurveID = 29 -) - -// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. -type keyShare struct { - group CurveID - data []byte -} - -// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. -const ( - pskModePlain uint8 = 0 - pskModeDHE uint8 = 1 -) - -// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved -// session. See RFC 8446, Section 4.2.11. -type pskIdentity struct { - label []byte - obfuscatedTicketAge uint32 -} - -// TLS Elliptic Curve Point Formats -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 -const ( - pointFormatUncompressed uint8 = 0 -) - -// TLS CertificateStatusType (RFC 3546) -const ( - statusTypeOCSP uint8 = 1 -) - -// Certificate types (for certificateRequestMsg) -const ( - certTypeRSASign = 1 - certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3. -) - -// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with -// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do. -const ( - signaturePKCS1v15 uint8 = iota + 225 - signatureRSAPSS - signatureECDSA - signatureEd25519 -) - -// directSigning is a standard Hash value that signals that no pre-hashing -// should be performed, and that the input should be signed directly. It is the -// hash function associated with the Ed25519 signature scheme. -var directSigning crypto.Hash = 0 - -// supportedSignatureAlgorithms contains the signature and hash algorithms that -// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ -// CertificateRequest. The two fields are merged to match with TLS 1.3. -// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. -var supportedSignatureAlgorithms = []SignatureScheme{ - PSSWithSHA256, - ECDSAWithP256AndSHA256, - Ed25519, - PSSWithSHA384, - PSSWithSHA512, - PKCS1WithSHA256, - PKCS1WithSHA384, - PKCS1WithSHA512, - ECDSAWithP384AndSHA384, - ECDSAWithP521AndSHA512, - PKCS1WithSHA1, - ECDSAWithSHA1, -} - -// helloRetryRequestRandom is set as the Random value of a ServerHello -// to signal that the message is actually a HelloRetryRequest. -var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. 
- 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, - 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, - 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -} - -const ( - // downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server - // random as a downgrade protection if the server would be capable of - // negotiating a higher version. See RFC 8446, Section 4.1.3. - downgradeCanaryTLS12 = "DOWNGRD\x01" - downgradeCanaryTLS11 = "DOWNGRD\x00" -) - -// testingOnlyForceDowngradeCanary is set in tests to force the server side to -// include downgrade canaries even if it's using its highers supported version. -var testingOnlyForceDowngradeCanary bool - -type ConnectionState = tls.ConnectionState - -// ConnectionState records basic TLS details about the connection. -type connectionState struct { - // Version is the TLS version used by the connection (e.g. VersionTLS12). - Version uint16 - - // HandshakeComplete is true if the handshake has concluded. - HandshakeComplete bool - - // DidResume is true if this connection was successfully resumed from a - // previous session with a session ticket or similar mechanism. - DidResume bool - - // CipherSuite is the cipher suite negotiated for the connection (e.g. - // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256). - CipherSuite uint16 - - // NegotiatedProtocol is the application protocol negotiated with ALPN. - NegotiatedProtocol string - - // NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation. - // - // Deprecated: this value is always true. - NegotiatedProtocolIsMutual bool - - // ServerName is the value of the Server Name Indication extension sent by - // the client. It's available both on the server and on the client side. - ServerName string - - // PeerCertificates are the parsed certificates sent by the peer, in the - // order in which they were sent. The first element is the leaf certificate - // that the connection is verified against. - // - // On the client side, it can't be empty. On the server side, it can be - // empty if Config.ClientAuth is not RequireAnyClientCert or - // RequireAndVerifyClientCert. - PeerCertificates []*x509.Certificate - - // VerifiedChains is a list of one or more chains where the first element is - // PeerCertificates[0] and the last element is from Config.RootCAs (on the - // client side) or Config.ClientCAs (on the server side). - // - // On the client side, it's set if Config.InsecureSkipVerify is false. On - // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven - // (and the peer provided a certificate) or RequireAndVerifyClientCert. - VerifiedChains [][]*x509.Certificate - - // SignedCertificateTimestamps is a list of SCTs provided by the peer - // through the TLS handshake for the leaf certificate, if any. - SignedCertificateTimestamps [][]byte - - // OCSPResponse is a stapled Online Certificate Status Protocol (OCSP) - // response provided by the peer for the leaf certificate, if any. - OCSPResponse []byte - - // TLSUnique contains the "tls-unique" channel binding value (see RFC 5929, - // Section 3). This value will be nil for TLS 1.3 connections and for all - // resumed connections. - // - // Deprecated: there are conditions in which this value might not be unique - // to a connection. See the Security Considerations sections of RFC 5705 and - // RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings. - TLSUnique []byte - - // ekm is a closure exposed via ExportKeyingMaterial. 
- ekm func(label string, context []byte, length int) ([]byte, error) -} - -type ConnectionStateWith0RTT struct { - ConnectionState - - Used0RTT bool // true if 0-RTT was both offered and accepted -} - -// ClientAuthType is tls.ClientAuthType -type ClientAuthType = tls.ClientAuthType - -const ( - NoClientCert = tls.NoClientCert - RequestClientCert = tls.RequestClientCert - RequireAnyClientCert = tls.RequireAnyClientCert - VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven - RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert -) - -// requiresClientCert reports whether the ClientAuthType requires a client -// certificate to be provided. -func requiresClientCert(c ClientAuthType) bool { - switch c { - case RequireAnyClientCert, RequireAndVerifyClientCert: - return true - default: - return false - } -} - -// ClientSessionState contains the state needed by clients to resume TLS -// sessions. -type ClientSessionState = tls.ClientSessionState - -type clientSessionState struct { - sessionTicket []uint8 // Encrypted ticket used for session resumption with server - vers uint16 // TLS version negotiated for the session - cipherSuite uint16 // Ciphersuite negotiated for the session - masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret - serverCertificates []*x509.Certificate // Certificate chain presented by the server - verifiedChains [][]*x509.Certificate // Certificate chains we built for verification - receivedAt time.Time // When the session ticket was received from the server - ocspResponse []byte // Stapled OCSP response presented by the server - scts [][]byte // SCTs presented by the server - - // TLS 1.3 fields. - nonce []byte // Ticket nonce sent by the server, to derive PSK - useBy time.Time // Expiration of the ticket lifetime as set by the server - ageAdd uint32 // Random obfuscation factor for sending the ticket age -} - -// ClientSessionCache is a cache of ClientSessionState objects that can be used -// by a client to resume a TLS session with a given server. ClientSessionCache -// implementations should expect to be called concurrently from different -// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not -// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which -// are supported via this interface. -//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/marten-seemann/qtls-go1-17 ClientSessionCache" -type ClientSessionCache = tls.ClientSessionCache - -// SignatureScheme is a tls.SignatureScheme -type SignatureScheme = tls.SignatureScheme - -const ( - // RSASSA-PKCS1-v1_5 algorithms. - PKCS1WithSHA256 SignatureScheme = 0x0401 - PKCS1WithSHA384 SignatureScheme = 0x0501 - PKCS1WithSHA512 SignatureScheme = 0x0601 - - // RSASSA-PSS algorithms with public key OID rsaEncryption. - PSSWithSHA256 SignatureScheme = 0x0804 - PSSWithSHA384 SignatureScheme = 0x0805 - PSSWithSHA512 SignatureScheme = 0x0806 - - // ECDSA algorithms. Only constrained to a specific curve in TLS 1.3. - ECDSAWithP256AndSHA256 SignatureScheme = 0x0403 - ECDSAWithP384AndSHA384 SignatureScheme = 0x0503 - ECDSAWithP521AndSHA512 SignatureScheme = 0x0603 - - // EdDSA algorithms. - Ed25519 SignatureScheme = 0x0807 - - // Legacy signature and hash algorithms for TLS 1.2. 
- PKCS1WithSHA1 SignatureScheme = 0x0201 - ECDSAWithSHA1 SignatureScheme = 0x0203 -) - -// ClientHelloInfo contains information from a ClientHello message in order to -// guide application logic in the GetCertificate and GetConfigForClient callbacks. -type ClientHelloInfo = tls.ClientHelloInfo - -type clientHelloInfo struct { - // CipherSuites lists the CipherSuites supported by the client (e.g. - // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). - CipherSuites []uint16 - - // ServerName indicates the name of the server requested by the client - // in order to support virtual hosting. ServerName is only set if the - // client is using SNI (see RFC 4366, Section 3.1). - ServerName string - - // SupportedCurves lists the elliptic curves supported by the client. - // SupportedCurves is set only if the Supported Elliptic Curves - // Extension is being used (see RFC 4492, Section 5.1.1). - SupportedCurves []CurveID - - // SupportedPoints lists the point formats supported by the client. - // SupportedPoints is set only if the Supported Point Formats Extension - // is being used (see RFC 4492, Section 5.1.2). - SupportedPoints []uint8 - - // SignatureSchemes lists the signature and hash schemes that the client - // is willing to verify. SignatureSchemes is set only if the Signature - // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1). - SignatureSchemes []SignatureScheme - - // SupportedProtos lists the application protocols supported by the client. - // SupportedProtos is set only if the Application-Layer Protocol - // Negotiation Extension is being used (see RFC 7301, Section 3.1). - // - // Servers can select a protocol by setting Config.NextProtos in a - // GetConfigForClient return value. - SupportedProtos []string - - // SupportedVersions lists the TLS versions supported by the client. - // For TLS versions less than 1.3, this is extrapolated from the max - // version advertised by the client, so values other than the greatest - // might be rejected if used. - SupportedVersions []uint16 - - // Conn is the underlying net.Conn for the connection. Do not read - // from, or write to, this connection; that will cause the TLS - // connection to fail. - Conn net.Conn - - // config is embedded by the GetCertificate or GetConfigForClient caller, - // for use with SupportsCertificate. - config *Config - - // ctx is the context of the handshake that is in progress. - ctx context.Context -} - -// Context returns the context of the handshake that is in progress. -// This context is a child of the context passed to HandshakeContext, -// if any, and is canceled when the handshake concludes. -func (c *clientHelloInfo) Context() context.Context { - return c.ctx -} - -// CertificateRequestInfo contains information from a server's -// CertificateRequest message, which is used to demand a certificate and proof -// of control from a client. -type CertificateRequestInfo = tls.CertificateRequestInfo - -type certificateRequestInfo struct { - // AcceptableCAs contains zero or more, DER-encoded, X.501 - // Distinguished Names. These are the names of root or intermediate CAs - // that the server wishes the returned certificate to be signed by. An - // empty slice indicates that the server has no preference. - AcceptableCAs [][]byte - - // SignatureSchemes lists the signature schemes that the server is - // willing to verify. - SignatureSchemes []SignatureScheme - - // Version is the TLS version that was negotiated for this connection. 
- Version uint16 - - // ctx is the context of the handshake that is in progress. - ctx context.Context -} - -// Context returns the context of the handshake that is in progress. -// This context is a child of the context passed to HandshakeContext, -// if any, and is canceled when the handshake concludes. -func (c *certificateRequestInfo) Context() context.Context { - return c.ctx -} - -// RenegotiationSupport enumerates the different levels of support for TLS -// renegotiation. TLS renegotiation is the act of performing subsequent -// handshakes on a connection after the first. This significantly complicates -// the state machine and has been the source of numerous, subtle security -// issues. Initiating a renegotiation is not supported, but support for -// accepting renegotiation requests may be enabled. -// -// Even when enabled, the server may not change its identity between handshakes -// (i.e. the leaf certificate must be the same). Additionally, concurrent -// handshake and application data flow is not permitted so renegotiation can -// only be used with protocols that synchronise with the renegotiation, such as -// HTTPS. -// -// Renegotiation is not defined in TLS 1.3. -type RenegotiationSupport = tls.RenegotiationSupport - -const ( - // RenegotiateNever disables renegotiation. - RenegotiateNever = tls.RenegotiateNever - - // RenegotiateOnceAsClient allows a remote server to request - // renegotiation once per connection. - RenegotiateOnceAsClient = tls.RenegotiateOnceAsClient - - // RenegotiateFreelyAsClient allows a remote server to repeatedly - // request renegotiation. - RenegotiateFreelyAsClient = tls.RenegotiateFreelyAsClient -) - -// A Config structure is used to configure a TLS client or server. -// After one has been passed to a TLS function it must not be -// modified. A Config may be reused; the tls package will also not -// modify it. -type Config = tls.Config - -type config struct { - // Rand provides the source of entropy for nonces and RSA blinding. - // If Rand is nil, TLS uses the cryptographic random reader in package - // crypto/rand. - // The Reader must be safe for use by multiple goroutines. - Rand io.Reader - - // Time returns the current time as the number of seconds since the epoch. - // If Time is nil, TLS uses time.Now. - Time func() time.Time - - // Certificates contains one or more certificate chains to present to the - // other side of the connection. The first certificate compatible with the - // peer's requirements is selected automatically. - // - // Server configurations must set one of Certificates, GetCertificate or - // GetConfigForClient. Clients doing client-authentication may set either - // Certificates or GetClientCertificate. - // - // Note: if there are multiple Certificates, and they don't have the - // optional field Leaf set, certificate selection will incur a significant - // per-handshake performance cost. - Certificates []Certificate - - // NameToCertificate maps from a certificate name to an element of - // Certificates. Note that a certificate name can be of the form - // '*.example.com' and so doesn't have to be a domain name as such. - // - // Deprecated: NameToCertificate only allows associating a single - // certificate with a given name. Leave this field nil to let the library - // select the first compatible chain from Certificates. - NameToCertificate map[string]*Certificate - - // GetCertificate returns a Certificate based on the given - // ClientHelloInfo. 
It will only be called if the client supplies SNI - // information or if Certificates is empty. - // - // If GetCertificate is nil or returns nil, then the certificate is - // retrieved from NameToCertificate. If NameToCertificate is nil, the - // best element of Certificates will be used. - GetCertificate func(*ClientHelloInfo) (*Certificate, error) - - // GetClientCertificate, if not nil, is called when a server requests a - // certificate from a client. If set, the contents of Certificates will - // be ignored. - // - // If GetClientCertificate returns an error, the handshake will be - // aborted and that error will be returned. Otherwise - // GetClientCertificate must return a non-nil Certificate. If - // Certificate.Certificate is empty then no certificate will be sent to - // the server. If this is unacceptable to the server then it may abort - // the handshake. - // - // GetClientCertificate may be called multiple times for the same - // connection if renegotiation occurs or if TLS 1.3 is in use. - GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) - - // GetConfigForClient, if not nil, is called after a ClientHello is - // received from a client. It may return a non-nil Config in order to - // change the Config that will be used to handle this connection. If - // the returned Config is nil, the original Config will be used. The - // Config returned by this callback may not be subsequently modified. - // - // If GetConfigForClient is nil, the Config passed to Server() will be - // used for all connections. - // - // If SessionTicketKey was explicitly set on the returned Config, or if - // SetSessionTicketKeys was called on the returned Config, those keys will - // be used. Otherwise, the original Config keys will be used (and possibly - // rotated if they are automatically managed). - GetConfigForClient func(*ClientHelloInfo) (*Config, error) - - // VerifyPeerCertificate, if not nil, is called after normal - // certificate verification by either a TLS client or server. It - // receives the raw ASN.1 certificates provided by the peer and also - // any verified chains that normal processing found. If it returns a - // non-nil error, the handshake is aborted and that error results. - // - // If normal verification fails then the handshake will abort before - // considering this callback. If normal verification is disabled by - // setting InsecureSkipVerify, or (for a server) when ClientAuth is - // RequestClientCert or RequireAnyClientCert, then this callback will - // be considered but the verifiedChains argument will always be nil. - VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - - // VerifyConnection, if not nil, is called after normal certificate - // verification and after VerifyPeerCertificate by either a TLS client - // or server. If it returns a non-nil error, the handshake is aborted - // and that error results. - // - // If normal verification fails then the handshake will abort before - // considering this callback. This callback will run for all connections - // regardless of InsecureSkipVerify or ClientAuth settings. - VerifyConnection func(ConnectionState) error - - // RootCAs defines the set of root certificate authorities - // that clients use when verifying server certificates. - // If RootCAs is nil, TLS uses the host's root CA set. - RootCAs *x509.CertPool - - // NextProtos is a list of supported application level protocols, in - // order of preference. 
If both peers support ALPN, the selected
- // protocol will be one from this list, and the connection will fail
- // if there is no mutually supported protocol. If NextProtos is empty
- // or the peer doesn't support ALPN, the connection will succeed and
- // ConnectionState.NegotiatedProtocol will be empty.
- NextProtos []string
-
- // ServerName is used to verify the hostname on the returned
- // certificates unless InsecureSkipVerify is given. It is also included
- // in the client's handshake to support virtual hosting unless it is
- // an IP address.
- ServerName string
-
- // ClientAuth determines the server's policy for
- // TLS Client Authentication. The default is NoClientCert.
- ClientAuth ClientAuthType
-
- // ClientCAs defines the set of root certificate authorities
- // that servers use if required to verify a client certificate
- // by the policy in ClientAuth.
- ClientCAs *x509.CertPool
-
- // InsecureSkipVerify controls whether a client verifies the server's
- // certificate chain and host name. If InsecureSkipVerify is true, crypto/tls
- // accepts any certificate presented by the server and any host name in that
- // certificate. In this mode, TLS is susceptible to machine-in-the-middle
- // attacks unless custom verification is used. This should be used only for
- // testing or in combination with VerifyConnection or VerifyPeerCertificate.
- InsecureSkipVerify bool
-
- // CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of
- // the list is ignored. Note that TLS 1.3 ciphersuites are not configurable.
- //
- // If CipherSuites is nil, a safe default list is used. The default cipher
- // suites might change over time.
- CipherSuites []uint16
-
- // PreferServerCipherSuites is a legacy field and has no effect.
- //
- // It used to control whether the server would follow the client's or the
- // server's preference. Servers now select the best mutually supported
- // cipher suite based on logic that takes into account inferred client
- // hardware, server hardware, and security.
- //
- // Deprecated: PreferServerCipherSuites is ignored.
- PreferServerCipherSuites bool
-
- // SessionTicketsDisabled may be set to true to disable session ticket and
- // PSK (resumption) support. Note that on clients, session ticket support is
- // also disabled if ClientSessionCache is nil.
- SessionTicketsDisabled bool
-
- // SessionTicketKey is used by TLS servers to provide session resumption.
- // See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled
- // with random data before the first server handshake.
- //
- // Deprecated: if this field is left at zero, session ticket keys will be
- // automatically rotated every day and dropped after seven days. For
- // customizing the rotation schedule or synchronizing servers that are
- // terminating connections for the same host, use SetSessionTicketKeys.
- SessionTicketKey [32]byte
-
- // ClientSessionCache is a cache of ClientSessionState entries for TLS
- // session resumption. It is only used by clients.
- ClientSessionCache ClientSessionCache
-
- // MinVersion contains the minimum TLS version that is acceptable.
- // If zero, TLS 1.0 is currently taken as the minimum.
- MinVersion uint16
-
- // MaxVersion contains the maximum TLS version that is acceptable.
- // If zero, the maximum version supported by this package is used,
- // which is currently TLS 1.3.
- MaxVersion uint16
-
- // CurvePreferences contains the elliptic curves that will be used in
- // an ECDHE handshake, in preference order.
If empty, the default will
- // be used. The client will use the first preference as the type for
- // its key share in TLS 1.3. This may change in the future.
- CurvePreferences []CurveID
-
- // DynamicRecordSizingDisabled disables adaptive sizing of TLS records.
- // When true, the largest possible TLS record size is always used. When
- // false, the size of TLS records may be adjusted in an attempt to
- // improve latency.
- DynamicRecordSizingDisabled bool
-
- // Renegotiation controls what types of renegotiation are supported.
- // The default, none, is correct for the vast majority of applications.
- Renegotiation RenegotiationSupport
-
- // KeyLogWriter optionally specifies a destination for TLS master secrets
- // in NSS key log format that can be used to allow external programs
- // such as Wireshark to decrypt TLS connections.
- // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format.
- // Use of KeyLogWriter compromises security and should only be
- // used for debugging.
- KeyLogWriter io.Writer
-
- // mutex protects sessionTicketKeys and autoSessionTicketKeys.
- mutex sync.RWMutex
- // sessionTicketKeys contains zero or more ticket keys. If set, it means
- // the keys were set with SessionTicketKey or SetSessionTicketKeys. The
- // first key is used for new tickets and any subsequent keys can be used to
- // decrypt old tickets. The slice contents are not protected by the mutex
- // and are immutable.
- sessionTicketKeys []ticketKey
- // autoSessionTicketKeys is like sessionTicketKeys but is owned by the
- // auto-rotation logic. See Config.ticketKeys.
- autoSessionTicketKeys []ticketKey
-}
-
-// A RecordLayer handles encrypting and decrypting of TLS messages.
-type RecordLayer interface {
- SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte)
- SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte)
- ReadHandshakeMessage() ([]byte, error)
- WriteRecord([]byte) (int, error)
- SendAlert(uint8)
-}
-
-type ExtraConfig struct {
- // GetExtensions, if not nil, is called before a message that allows
- // sending of extensions is sent.
- // Currently only implemented for the ClientHello message (for the client)
- // and for the EncryptedExtensions message (for the server).
- // Only valid for TLS 1.3.
- GetExtensions func(handshakeMessageType uint8) []Extension
-
- // ReceivedExtensions, if not nil, is called when a message that allows the
- // inclusion of extensions is received.
- // It is called with an empty slice of extensions, if the message didn't
- // contain any extensions.
- // Currently only implemented for the ClientHello message (sent by the
- // client) and for the EncryptedExtensions message (sent by the server).
- // Only valid for TLS 1.3.
- ReceivedExtensions func(handshakeMessageType uint8, exts []Extension)
-
- // AlternativeRecordLayer is used by QUIC
- AlternativeRecordLayer RecordLayer
-
- // Enforce the selection of a supported application protocol.
- // Only works for TLS 1.3.
- // If enabled, client and server have to agree on an application protocol.
- // Otherwise, connection establishment fails.
- EnforceNextProtoSelection bool
-
- // If MaxEarlyData is greater than 0, the client will be allowed to send early
- // data when resuming a session.
- // Requires the AlternativeRecordLayer to be set.
- //
- // It has no meaning on the client.
- MaxEarlyData uint32
-
- // The Accept0RTT callback is called when the client offers 0-RTT.
- // The server then has to decide if it wants to accept or reject 0-RTT.
- // It is only used for servers.
- Accept0RTT func(appData []byte) bool
-
- // Rejected0RTT is called when the server rejects 0-RTT.
- // It is only used for clients.
- Rejected0RTT func()
-
- // If set, the client will export the 0-RTT key when resuming a session that
- // allows sending of early data.
- // Requires the AlternativeRecordLayer to be set.
- //
- // It has no meaning to the server.
- Enable0RTT bool
-
- // Is called when the client saves a session ticket.
- // This gives the application the opportunity to save some data along with the ticket,
- // which can be restored when the session ticket is used.
- GetAppDataForSessionState func() []byte
-
- // Is called when the client uses a session ticket.
- // Restores the application data that was saved earlier by GetAppDataForSessionState.
- SetAppDataFromSessionState func([]byte)
-}
-
-// Clone returns a shallow copy of the ExtraConfig.
-func (c *ExtraConfig) Clone() *ExtraConfig {
- return &ExtraConfig{
- GetExtensions: c.GetExtensions,
- ReceivedExtensions: c.ReceivedExtensions,
- AlternativeRecordLayer: c.AlternativeRecordLayer,
- EnforceNextProtoSelection: c.EnforceNextProtoSelection,
- MaxEarlyData: c.MaxEarlyData,
- Enable0RTT: c.Enable0RTT,
- Accept0RTT: c.Accept0RTT,
- Rejected0RTT: c.Rejected0RTT,
- GetAppDataForSessionState: c.GetAppDataForSessionState,
- SetAppDataFromSessionState: c.SetAppDataFromSessionState,
- }
-}
-
-func (c *ExtraConfig) usesAlternativeRecordLayer() bool {
- return c != nil && c.AlternativeRecordLayer != nil
-}
-
-const (
- // ticketKeyNameLen is the number of bytes of identifier that is prepended to
- // an encrypted session ticket in order to identify the key used to encrypt it.
- ticketKeyNameLen = 16
-
- // ticketKeyLifetime is how long a ticket key remains valid and can be used to
- // resume a client connection.
- ticketKeyLifetime = 7 * 24 * time.Hour // 7 days
-
- // ticketKeyRotation is how often the server should rotate the session ticket key
- // that is used for new tickets.
- ticketKeyRotation = 24 * time.Hour
-)
-
-// ticketKey is the internal representation of a session ticket key.
-type ticketKey struct {
- // keyName is an opaque byte string that serves to identify the session
- // ticket key. It's exposed as plaintext in every session ticket.
- keyName [ticketKeyNameLen]byte
- aesKey [16]byte
- hmacKey [16]byte
- // created is the time at which this ticket key was created. See Config.ticketKeys.
- created time.Time
-}
-
-// ticketKeyFromBytes converts from the external representation of a session
-// ticket key to a ticketKey. Externally, session ticket keys are 32 random
-// bytes and this function expands that into sufficient name and key material.
-func (c *config) ticketKeyFromBytes(b [32]byte) (key ticketKey) {
- hashed := sha512.Sum512(b[:])
- copy(key.keyName[:], hashed[:ticketKeyNameLen])
- copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16])
- copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32])
- key.created = c.time()
- return key
-}
-
-// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session
-// ticket, and the lifetime we set for tickets we send.
-const maxSessionTicketLifetime = 7 * 24 * time.Hour
-
-// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is
-// being used concurrently by a TLS client or server.
-func (c *config) Clone() *config { - if c == nil { - return nil - } - c.mutex.RLock() - defer c.mutex.RUnlock() - return &config{ - Rand: c.Rand, - Time: c.Time, - Certificates: c.Certificates, - NameToCertificate: c.NameToCertificate, - GetCertificate: c.GetCertificate, - GetClientCertificate: c.GetClientCertificate, - GetConfigForClient: c.GetConfigForClient, - VerifyPeerCertificate: c.VerifyPeerCertificate, - VerifyConnection: c.VerifyConnection, - RootCAs: c.RootCAs, - NextProtos: c.NextProtos, - ServerName: c.ServerName, - ClientAuth: c.ClientAuth, - ClientCAs: c.ClientCAs, - InsecureSkipVerify: c.InsecureSkipVerify, - CipherSuites: c.CipherSuites, - PreferServerCipherSuites: c.PreferServerCipherSuites, - SessionTicketsDisabled: c.SessionTicketsDisabled, - SessionTicketKey: c.SessionTicketKey, - ClientSessionCache: c.ClientSessionCache, - MinVersion: c.MinVersion, - MaxVersion: c.MaxVersion, - CurvePreferences: c.CurvePreferences, - DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, - Renegotiation: c.Renegotiation, - KeyLogWriter: c.KeyLogWriter, - sessionTicketKeys: c.sessionTicketKeys, - autoSessionTicketKeys: c.autoSessionTicketKeys, - } -} - -// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was -// randomized for backwards compatibility but is not in use. -var deprecatedSessionTicketKey = []byte("DEPRECATED") - -// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is -// randomized if empty, and that sessionTicketKeys is populated from it otherwise. -func (c *config) initLegacySessionTicketKeyRLocked() { - // Don't write if SessionTicketKey is already defined as our deprecated string, - // or if it is defined by the user but sessionTicketKeys is already set. - if c.SessionTicketKey != [32]byte{} && - (bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) { - return - } - - // We need to write some data, so get an exclusive lock and re-check any conditions. - c.mutex.RUnlock() - defer c.mutex.RLock() - c.mutex.Lock() - defer c.mutex.Unlock() - if c.SessionTicketKey == [32]byte{} { - if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { - panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err)) - } - // Write the deprecated prefix at the beginning so we know we created - // it. This key with the DEPRECATED prefix isn't used as an actual - // session ticket key, and is only randomized in case the application - // reuses it for some reason. - copy(c.SessionTicketKey[:], deprecatedSessionTicketKey) - } else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 { - c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)} - } - -} - -// ticketKeys returns the ticketKeys for this connection. -// If configForClient has explicitly set keys, those will -// be returned. Otherwise, the keys on c will be used and -// may be rotated if auto-managed. -// During rotation, any expired session ticket keys are deleted from -// c.sessionTicketKeys. If the session ticket key that is currently -// encrypting tickets (ie. the first ticketKey in c.sessionTicketKeys) -// is not fresh, then a new session ticket key will be -// created and prepended to c.sessionTicketKeys. -func (c *config) ticketKeys(configForClient *config) []ticketKey { - // If the ConfigForClient callback returned a Config with explicitly set - // keys, use those, otherwise just use the original Config. 
- if configForClient != nil {
- configForClient.mutex.RLock()
- if configForClient.SessionTicketsDisabled {
- return nil
- }
- configForClient.initLegacySessionTicketKeyRLocked()
- if len(configForClient.sessionTicketKeys) != 0 {
- ret := configForClient.sessionTicketKeys
- configForClient.mutex.RUnlock()
- return ret
- }
- configForClient.mutex.RUnlock()
- }
-
- c.mutex.RLock()
- defer c.mutex.RUnlock()
- if c.SessionTicketsDisabled {
- return nil
- }
- c.initLegacySessionTicketKeyRLocked()
- if len(c.sessionTicketKeys) != 0 {
- return c.sessionTicketKeys
- }
- // Fast path for the common case where the key is fresh enough.
- if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation {
- return c.autoSessionTicketKeys
- }
-
- // autoSessionTicketKeys are managed by auto-rotation.
- c.mutex.RUnlock()
- defer c.mutex.RLock()
- c.mutex.Lock()
- defer c.mutex.Unlock()
- // Re-check the condition in case it changed since obtaining the new lock.
- if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation {
- var newKey [32]byte
- if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil {
- panic(fmt.Sprintf("unable to generate random session ticket key: %v", err))
- }
- valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1)
- valid = append(valid, c.ticketKeyFromBytes(newKey))
- for _, k := range c.autoSessionTicketKeys {
- // While rotating the current key, also remove any expired ones.
- if c.time().Sub(k.created) < ticketKeyLifetime {
- valid = append(valid, k)
- }
- }
- c.autoSessionTicketKeys = valid
- }
- return c.autoSessionTicketKeys
-}
-
-// SetSessionTicketKeys updates the session ticket keys for a server.
-//
-// The first key will be used when creating new tickets, while all keys can be
-// used for decrypting tickets. It is safe to call this function while the
-// server is running in order to rotate the session ticket keys. The function
-// will panic if keys is empty.
-//
-// Calling this function will turn off automatic session ticket key rotation.
-//
-// If multiple servers are terminating connections for the same host they should
-// all have the same session ticket keys. If the session ticket keys leak,
-// previously recorded and future TLS connections using those keys might be
-// compromised.
-func (c *config) SetSessionTicketKeys(keys [][32]byte) { - if len(keys) == 0 { - panic("tls: keys must have at least one key") - } - - newKeys := make([]ticketKey, len(keys)) - for i, bytes := range keys { - newKeys[i] = c.ticketKeyFromBytes(bytes) - } - - c.mutex.Lock() - c.sessionTicketKeys = newKeys - c.mutex.Unlock() -} - -func (c *config) rand() io.Reader { - r := c.Rand - if r == nil { - return rand.Reader - } - return r -} - -func (c *config) time() time.Time { - t := c.Time - if t == nil { - t = time.Now - } - return t() -} - -func (c *config) cipherSuites() []uint16 { - if c.CipherSuites != nil { - return c.CipherSuites - } - return defaultCipherSuites -} - -var supportedVersions = []uint16{ - VersionTLS13, - VersionTLS12, - VersionTLS11, - VersionTLS10, -} - -func (c *config) supportedVersions() []uint16 { - versions := make([]uint16, 0, len(supportedVersions)) - for _, v := range supportedVersions { - if c != nil && c.MinVersion != 0 && v < c.MinVersion { - continue - } - if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { - continue - } - versions = append(versions, v) - } - return versions -} - -func (c *config) maxSupportedVersion() uint16 { - supportedVersions := c.supportedVersions() - if len(supportedVersions) == 0 { - return 0 - } - return supportedVersions[0] -} - -// supportedVersionsFromMax returns a list of supported versions derived from a -// legacy maximum version value. Note that only versions supported by this -// library are returned. Any newer peer will use supportedVersions anyway. -func supportedVersionsFromMax(maxVersion uint16) []uint16 { - versions := make([]uint16, 0, len(supportedVersions)) - for _, v := range supportedVersions { - if v > maxVersion { - continue - } - versions = append(versions, v) - } - return versions -} - -var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} - -func (c *config) curvePreferences() []CurveID { - if c == nil || len(c.CurvePreferences) == 0 { - return defaultCurvePreferences - } - return c.CurvePreferences -} - -func (c *config) supportsCurve(curve CurveID) bool { - for _, cc := range c.curvePreferences() { - if cc == curve { - return true - } - } - return false -} - -// mutualVersion returns the protocol version to use given the advertised -// versions of the peer. Priority is given to the peer preference order. -func (c *config) mutualVersion(peerVersions []uint16) (uint16, bool) { - supportedVersions := c.supportedVersions() - for _, peerVersion := range peerVersions { - for _, v := range supportedVersions { - if v == peerVersion { - return v, true - } - } - } - return 0, false -} - -var errNoCertificates = errors.New("tls: no certificates configured") - -// getCertificate returns the best certificate for the given ClientHelloInfo, -// defaulting to the first element of c.Certificates. -func (c *config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { - if c.GetCertificate != nil && - (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { - cert, err := c.GetCertificate(clientHello) - if cert != nil || err != nil { - return cert, err - } - } - - if len(c.Certificates) == 0 { - return nil, errNoCertificates - } - - if len(c.Certificates) == 1 { - // There's only one choice, so no point doing any work. 
- return &c.Certificates[0], nil - } - - if c.NameToCertificate != nil { - name := strings.ToLower(clientHello.ServerName) - if cert, ok := c.NameToCertificate[name]; ok { - return cert, nil - } - if len(name) > 0 { - labels := strings.Split(name, ".") - labels[0] = "*" - wildcardName := strings.Join(labels, ".") - if cert, ok := c.NameToCertificate[wildcardName]; ok { - return cert, nil - } - } - } - - for _, cert := range c.Certificates { - if err := clientHello.SupportsCertificate(&cert); err == nil { - return &cert, nil - } - } - - // If nothing matches, return the first certificate. - return &c.Certificates[0], nil -} - -// SupportsCertificate returns nil if the provided certificate is supported by -// the client that sent the ClientHello. Otherwise, it returns an error -// describing the reason for the incompatibility. -// -// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate -// callback, this method will take into account the associated Config. Note that -// if GetConfigForClient returns a different Config, the change can't be -// accounted for by this method. -// -// This function will call x509.ParseCertificate unless c.Leaf is set, which can -// incur a significant performance cost. -func (chi *clientHelloInfo) SupportsCertificate(c *Certificate) error { - // Note we don't currently support certificate_authorities nor - // signature_algorithms_cert, and don't check the algorithms of the - // signatures on the chain (which anyway are a SHOULD, see RFC 8446, - // Section 4.4.2.2). - - config := chi.config - if config == nil { - config = &Config{} - } - conf := fromConfig(config) - vers, ok := conf.mutualVersion(chi.SupportedVersions) - if !ok { - return errors.New("no mutually supported protocol versions") - } - - // If the client specified the name they are trying to connect to, the - // certificate needs to be valid for it. - if chi.ServerName != "" { - x509Cert, err := leafCertificate(c) - if err != nil { - return fmt.Errorf("failed to parse certificate: %w", err) - } - if err := x509Cert.VerifyHostname(chi.ServerName); err != nil { - return fmt.Errorf("certificate is not valid for requested server name: %w", err) - } - } - - // supportsRSAFallback returns nil if the certificate and connection support - // the static RSA key exchange, and unsupported otherwise. The logic for - // supporting static RSA is completely disjoint from the logic for - // supporting signed key exchanges, so we just check it as a fallback. - supportsRSAFallback := func(unsupported error) error { - // TLS 1.3 dropped support for the static RSA key exchange. - if vers == VersionTLS13 { - return unsupported - } - // The static RSA key exchange works by decrypting a challenge with the - // RSA private key, not by signing, so check the PrivateKey implements - // crypto.Decrypter, like *rsa.PrivateKey does. - if priv, ok := c.PrivateKey.(crypto.Decrypter); ok { - if _, ok := priv.Public().(*rsa.PublicKey); !ok { - return unsupported - } - } else { - return unsupported - } - // Finally, there needs to be a mutual cipher suite that uses the static - // RSA key exchange instead of ECDHE. 
- rsaCipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { - if c.flags&suiteECDHE != 0 { - return false - } - if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true - }) - if rsaCipherSuite == nil { - return unsupported - } - return nil - } - - // If the client sent the signature_algorithms extension, ensure it supports - // schemes we can use with this certificate and TLS version. - if len(chi.SignatureSchemes) > 0 { - if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil { - return supportsRSAFallback(err) - } - } - - // In TLS 1.3 we are done because supported_groups is only relevant to the - // ECDHE computation, point format negotiation is removed, cipher suites are - // only relevant to the AEAD choice, and static RSA does not exist. - if vers == VersionTLS13 { - return nil - } - - // The only signed key exchange we support is ECDHE. - if !supportsECDHE(conf, chi.SupportedCurves, chi.SupportedPoints) { - return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange")) - } - - var ecdsaCipherSuite bool - if priv, ok := c.PrivateKey.(crypto.Signer); ok { - switch pub := priv.Public().(type) { - case *ecdsa.PublicKey: - var curve CurveID - switch pub.Curve { - case elliptic.P256(): - curve = CurveP256 - case elliptic.P384(): - curve = CurveP384 - case elliptic.P521(): - curve = CurveP521 - default: - return supportsRSAFallback(unsupportedCertificateError(c)) - } - var curveOk bool - for _, c := range chi.SupportedCurves { - if c == curve && conf.supportsCurve(c) { - curveOk = true - break - } - } - if !curveOk { - return errors.New("client doesn't support certificate curve") - } - ecdsaCipherSuite = true - case ed25519.PublicKey: - if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 { - return errors.New("connection doesn't support Ed25519") - } - ecdsaCipherSuite = true - case *rsa.PublicKey: - default: - return supportsRSAFallback(unsupportedCertificateError(c)) - } - } else { - return supportsRSAFallback(unsupportedCertificateError(c)) - } - - // Make sure that there is a mutually supported cipher suite that works with - // this certificate. Cipher suite selection will then apply the logic in - // reverse to pick it. See also serverHandshakeState.cipherSuiteOk. - cipherSuite := selectCipherSuite(chi.CipherSuites, conf.cipherSuites(), func(c *cipherSuite) bool { - if c.flags&suiteECDHE == 0 { - return false - } - if c.flags&suiteECSign != 0 { - if !ecdsaCipherSuite { - return false - } - } else { - if ecdsaCipherSuite { - return false - } - } - if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true - }) - if cipherSuite == nil { - return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate")) - } - - return nil -} - -// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate -// from the CommonName and SubjectAlternateName fields of each of the leaf -// certificates. -// -// Deprecated: NameToCertificate only allows associating a single certificate -// with a given name. Leave that field nil to let the library select the first -// compatible chain from Certificates. 
-func (c *config) BuildNameToCertificate() { - c.NameToCertificate = make(map[string]*Certificate) - for i := range c.Certificates { - cert := &c.Certificates[i] - x509Cert, err := leafCertificate(cert) - if err != nil { - continue - } - // If SANs are *not* present, some clients will consider the certificate - // valid for the name in the Common Name. - if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 { - c.NameToCertificate[x509Cert.Subject.CommonName] = cert - } - for _, san := range x509Cert.DNSNames { - c.NameToCertificate[san] = cert - } - } -} - -const ( - keyLogLabelTLS12 = "CLIENT_RANDOM" - keyLogLabelEarlyTraffic = "CLIENT_EARLY_TRAFFIC_SECRET" - keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET" - keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET" - keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0" - keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0" -) - -func (c *config) writeKeyLog(label string, clientRandom, secret []byte) error { - if c.KeyLogWriter == nil { - return nil - } - - logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) - - writerMutex.Lock() - _, err := c.KeyLogWriter.Write(logLine) - writerMutex.Unlock() - - return err -} - -// writerMutex protects all KeyLogWriters globally. It is rarely enabled, -// and is only for debugging, so a global mutex saves space. -var writerMutex sync.Mutex - -// A Certificate is a chain of one or more certificates, leaf first. -type Certificate = tls.Certificate - -// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing -// the corresponding c.Certificate[0]. -func leafCertificate(c *Certificate) (*x509.Certificate, error) { - if c.Leaf != nil { - return c.Leaf, nil - } - return x509.ParseCertificate(c.Certificate[0]) -} - -type handshakeMessage interface { - marshal() []byte - unmarshal([]byte) bool -} - -// lruSessionCache is a ClientSessionCache implementation that uses an LRU -// caching strategy. -type lruSessionCache struct { - sync.Mutex - - m map[string]*list.Element - q *list.List - capacity int -} - -type lruSessionCacheEntry struct { - sessionKey string - state *ClientSessionState -} - -// NewLRUClientSessionCache returns a ClientSessionCache with the given -// capacity that uses an LRU strategy. If capacity is < 1, a default capacity -// is used instead. -func NewLRUClientSessionCache(capacity int) ClientSessionCache { - const defaultSessionCacheCapacity = 64 - - if capacity < 1 { - capacity = defaultSessionCacheCapacity - } - return &lruSessionCache{ - m: make(map[string]*list.Element), - q: list.New(), - capacity: capacity, - } -} - -// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry -// corresponding to sessionKey is removed from the cache instead. -func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) { - c.Lock() - defer c.Unlock() - - if elem, ok := c.m[sessionKey]; ok { - if cs == nil { - c.q.Remove(elem) - delete(c.m, sessionKey) - } else { - entry := elem.Value.(*lruSessionCacheEntry) - entry.state = cs - c.q.MoveToFront(elem) - } - return - } - - if c.q.Len() < c.capacity { - entry := &lruSessionCacheEntry{sessionKey, cs} - c.m[sessionKey] = c.q.PushFront(entry) - return - } - - elem := c.q.Back() - entry := elem.Value.(*lruSessionCacheEntry) - delete(c.m, entry.sessionKey) - entry.sessionKey = sessionKey - entry.state = cs - c.q.MoveToFront(elem) - c.m[sessionKey] = elem -} - -// Get returns the ClientSessionState value associated with a given key. 
It -// returns (nil, false) if no value is found. -func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { - c.Lock() - defer c.Unlock() - - if elem, ok := c.m[sessionKey]; ok { - c.q.MoveToFront(elem) - return elem.Value.(*lruSessionCacheEntry).state, true - } - return nil, false -} - -var emptyConfig Config - -func defaultConfig() *Config { - return &emptyConfig -} - -func unexpectedMessageError(wanted, got interface{}) error { - return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) -} - -func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { - for _, s := range supportedSignatureAlgorithms { - if s == sigAlg { - return true - } - } - return false -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/conn.go b/vendor/github.com/marten-seemann/qtls-go1-17/conn.go deleted file mode 100644 index 70fad4652d1..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/conn.go +++ /dev/null @@ -1,1601 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// TLS low level connection and record layer - -package qtls - -import ( - "bytes" - "context" - "crypto/cipher" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "net" - "sync" - "sync/atomic" - "time" -) - -// A Conn represents a secured connection. -// It implements the net.Conn interface. -type Conn struct { - // constant - conn net.Conn - isClient bool - handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - - // handshakeStatus is 1 if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - // handshakeStatus == 1 implies handshakeErr == nil. - // This field is only to be accessed with sync/atomic. - handshakeStatus uint32 - // constant after handshake; protected by handshakeMutex - handshakeMutex sync.Mutex - handshakeErr error // error resulting from handshake - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *config // configuration passed to constructor - // handshakes counts the number of handshakes performed on the - // connection so far. If renegotiation is disabled then this is either - // zero or one. - extraConfig *ExtraConfig - - handshakes int - didResume bool // whether this connection was a session resumption - cipherSuite uint16 - ocspResponse []byte // stapled OCSP response - scts [][]byte // signed certificate timestamps from server - peerCertificates []*x509.Certificate - // verifiedChains contains the certificate chains that we built, as - // opposed to the ones presented by the server. - verifiedChains [][]*x509.Certificate - // serverName contains the server name indicated by the client, if any. - serverName string - // secureRenegotiation is true if the server echoed the secure - // renegotiation extension. (This is meaningless as a server because - // renegotiation is not supported in that case.) - secureRenegotiation bool - // ekm is a closure for exporting keying material. - ekm func(label string, context []byte, length int) ([]byte, error) - // For the client: - // resumptionSecret is the resumption_master_secret for handling - // NewSessionTicket messages. nil if config.SessionTicketsDisabled. 
- // For the server: - // resumptionSecret is the resumption_master_secret for generating - // NewSessionTicket messages. Only used when the alternative record - // layer is set. nil if config.SessionTicketsDisabled. - resumptionSecret []byte - - // ticketKeys is the set of active session ticket keys for this - // connection. The first one is used to encrypt new tickets and - // all are tried to decrypt tickets. - ticketKeys []ticketKey - - // clientFinishedIsFirst is true if the client sent the first Finished - // message during the most recent handshake. This is recorded because - // the first transmitted Finished message is the tls-unique - // channel-binding value. - clientFinishedIsFirst bool - - // closeNotifyErr is any error from sending the alertCloseNotify record. - closeNotifyErr error - // closeNotifySent is true if the Conn attempted to send an - // alertCloseNotify record. - closeNotifySent bool - - // clientFinished and serverFinished contain the Finished message sent - // by the client or server in the most recent handshake. This is - // retained to support the renegotiation extension and tls-unique - // channel-binding. - clientFinished [12]byte - serverFinished [12]byte - - // clientProtocol is the negotiated ALPN protocol. - clientProtocol string - - // input/output - in, out halfConn - rawInput bytes.Buffer // raw input, starting with a record header - input bytes.Reader // application data waiting to be read, from rawInput.Next - hand bytes.Buffer // handshake data waiting to be read - buffering bool // whether records are buffered in sendBuf - sendBuf []byte // a buffer of records waiting to be sent - - // bytesSent counts the bytes of application data sent. - // packetsSent counts packets. - bytesSent int64 - packetsSent int64 - - // retryCount counts the number of consecutive non-advancing records - // received by Conn.readRecord. That is, records that neither advance the - // handshake, nor deliver application data. Protected by in.Mutex. - retryCount int - - // activeCall is an atomic int32; the low bit is whether Close has - // been called. the rest of the bits are the number of goroutines - // in Conn.Write. - activeCall int32 - - used0RTT bool - - tmp [16]byte -} - -// Access to net.Conn methods. -// Cannot just embed net.Conn because that would -// export the struct field too. - -// LocalAddr returns the local network address. -func (c *Conn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -// RemoteAddr returns the remote network address. -func (c *Conn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline sets the read and write deadlines associated with the connection. -// A zero value for t means Read and Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. -func (c *Conn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -// SetReadDeadline sets the read deadline on the underlying connection. -// A zero value for t means Read will not time out. -func (c *Conn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -// SetWriteDeadline sets the write deadline on the underlying connection. -// A zero value for t means Write will not time out. -// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. 
-func (c *Conn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} - -// A halfConn represents one direction of the record layer -// connection, either sending or receiving. -type halfConn struct { - sync.Mutex - - err error // first permanent error - version uint16 // protocol version - cipher interface{} // cipher algorithm - mac hash.Hash - seq [8]byte // 64-bit sequence number - - scratchBuf [13]byte // to avoid allocs; interface method args escape - - nextCipher interface{} // next encryption state - nextMac hash.Hash // next MAC algorithm - - trafficSecret []byte // current TLS 1.3 traffic secret - - setKeyCallback func(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) -} - -type permanentError struct { - err net.Error -} - -func (e *permanentError) Error() string { return e.err.Error() } -func (e *permanentError) Unwrap() error { return e.err } -func (e *permanentError) Timeout() bool { return e.err.Timeout() } -func (e *permanentError) Temporary() bool { return false } - -func (hc *halfConn) setErrorLocked(err error) error { - if e, ok := err.(net.Error); ok { - hc.err = &permanentError{err: e} - } else { - hc.err = err - } - return hc.err -} - -// prepareCipherSpec sets the encryption and MAC states -// that a subsequent changeCipherSpec will use. -func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) { - hc.version = version - hc.nextCipher = cipher - hc.nextMac = mac -} - -// changeCipherSpec changes the encryption and MAC states -// to the ones previously passed to prepareCipherSpec. -func (hc *halfConn) changeCipherSpec() error { - if hc.nextCipher == nil || hc.version == VersionTLS13 { - return alertInternalError - } - hc.cipher = hc.nextCipher - hc.mac = hc.nextMac - hc.nextCipher = nil - hc.nextMac = nil - for i := range hc.seq { - hc.seq[i] = 0 - } - return nil -} - -func (hc *halfConn) exportKey(encLevel EncryptionLevel, suite *cipherSuiteTLS13, trafficSecret []byte) { - if hc.setKeyCallback != nil { - s := &CipherSuiteTLS13{ - ID: suite.id, - KeyLen: suite.keyLen, - Hash: suite.hash, - AEAD: func(key, fixedNonce []byte) cipher.AEAD { return suite.aead(key, fixedNonce) }, - } - hc.setKeyCallback(encLevel, s, trafficSecret) - } -} - -func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) { - hc.trafficSecret = secret - key, iv := suite.trafficKey(secret) - hc.cipher = suite.aead(key, iv) - for i := range hc.seq { - hc.seq[i] = 0 - } -} - -// incSeq increments the sequence number. -func (hc *halfConn) incSeq() { - for i := 7; i >= 0; i-- { - hc.seq[i]++ - if hc.seq[i] != 0 { - return - } - } - - // Not allowed to let sequence number wrap. - // Instead, must renegotiate before it does. - // Not likely enough to bother. - panic("TLS: sequence number wraparound") -} - -// explicitNonceLen returns the number of bytes of explicit nonce or IV included -// in each record. Explicit nonces are present only in CBC modes after TLS 1.0 -// and in certain AEAD modes in TLS 1.2. -func (hc *halfConn) explicitNonceLen() int { - if hc.cipher == nil { - return 0 - } - - switch c := hc.cipher.(type) { - case cipher.Stream: - return 0 - case aead: - return c.explicitNonceLen() - case cbcMode: - // TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack. 
- if hc.version >= VersionTLS11 { - return c.BlockSize() - } - return 0 - default: - panic("unknown cipher type") - } -} - -// extractPadding returns, in constant time, the length of the padding to remove -// from the end of payload. It also returns a byte which is equal to 255 if the -// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. -func extractPadding(payload []byte) (toRemove int, good byte) { - if len(payload) < 1 { - return 0, 0 - } - - paddingLen := payload[len(payload)-1] - t := uint(len(payload)-1) - uint(paddingLen) - // if len(payload) >= (paddingLen - 1) then the MSB of t is zero - good = byte(int32(^t) >> 31) - - // The maximum possible padding length plus the actual length field - toCheck := 256 - // The length of the padded data is public, so we can use an if here - if toCheck > len(payload) { - toCheck = len(payload) - } - - for i := 0; i < toCheck; i++ { - t := uint(paddingLen) - uint(i) - // if i <= paddingLen then the MSB of t is zero - mask := byte(int32(^t) >> 31) - b := payload[len(payload)-1-i] - good &^= mask&paddingLen ^ mask&b - } - - // We AND together the bits of good and replicate the result across - // all the bits. - good &= good << 4 - good &= good << 2 - good &= good << 1 - good = uint8(int8(good) >> 7) - - // Zero the padding length on error. This ensures any unchecked bytes - // are included in the MAC. Otherwise, an attacker that could - // distinguish MAC failures from padding failures could mount an attack - // similar to POODLE in SSL 3.0: given a good ciphertext that uses a - // full block's worth of padding, replace the final block with another - // block. If the MAC check passed but the padding check failed, the - // last byte of that block decrypted to the block size. - // - // See also macAndPaddingGood logic below. - paddingLen &= good - - toRemove = int(paddingLen) + 1 - return -} - -func roundUp(a, b int) int { - return a + (b-a%b)%b -} - -// cbcMode is an interface for block ciphers using cipher block chaining. -type cbcMode interface { - cipher.BlockMode - SetIV([]byte) -} - -// decrypt authenticates and decrypts the record if protection is active at -// this stage. The returned plaintext might overlap with the input. -func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { - var plaintext []byte - typ := recordType(record[0]) - payload := record[recordHeaderLen:] - - // In TLS 1.3, change_cipher_spec messages are to be ignored without being - // decrypted. See RFC 8446, Appendix D.4. - if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { - return payload, typ, nil - } - - paddingGood := byte(255) - paddingLen := 0 - - explicitNonceLen := hc.explicitNonceLen() - - if hc.cipher != nil { - switch c := hc.cipher.(type) { - case cipher.Stream: - c.XORKeyStream(payload, payload) - case aead: - if len(payload) < explicitNonceLen { - return nil, 0, alertBadRecordMAC - } - nonce := payload[:explicitNonceLen] - if len(nonce) == 0 { - nonce = hc.seq[:] - } - payload = payload[explicitNonceLen:] - - var additionalData []byte - if hc.version == VersionTLS13 { - additionalData = record[:recordHeaderLen] - } else { - additionalData = append(hc.scratchBuf[:0], hc.seq[:]...) - additionalData = append(additionalData, record[:3]...) 
- n := len(payload) - c.Overhead() - additionalData = append(additionalData, byte(n>>8), byte(n)) - } - - var err error - plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) - if err != nil { - return nil, 0, alertBadRecordMAC - } - case cbcMode: - blockSize := c.BlockSize() - minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) - if len(payload)%blockSize != 0 || len(payload) < minPayload { - return nil, 0, alertBadRecordMAC - } - - if explicitNonceLen > 0 { - c.SetIV(payload[:explicitNonceLen]) - payload = payload[explicitNonceLen:] - } - c.CryptBlocks(payload, payload) - - // In a limited attempt to protect against CBC padding oracles like - // Lucky13, the data past paddingLen (which is secret) is passed to - // the MAC function as extra data, to be fed into the HMAC after - // computing the digest. This makes the MAC roughly constant time as - // long as the digest computation is constant time and does not - // affect the subsequent write, modulo cache effects. - paddingLen, paddingGood = extractPadding(payload) - default: - panic("unknown cipher type") - } - - if hc.version == VersionTLS13 { - if typ != recordTypeApplicationData { - return nil, 0, alertUnexpectedMessage - } - if len(plaintext) > maxPlaintext+1 { - return nil, 0, alertRecordOverflow - } - // Remove padding and find the ContentType scanning from the end. - for i := len(plaintext) - 1; i >= 0; i-- { - if plaintext[i] != 0 { - typ = recordType(plaintext[i]) - plaintext = plaintext[:i] - break - } - if i == 0 { - return nil, 0, alertUnexpectedMessage - } - } - } - } else { - plaintext = payload - } - - if hc.mac != nil { - macSize := hc.mac.Size() - if len(payload) < macSize { - return nil, 0, alertBadRecordMAC - } - - n := len(payload) - macSize - paddingLen - n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } - record[3] = byte(n >> 8) - record[4] = byte(n) - remoteMAC := payload[n : n+macSize] - localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) - - // This is equivalent to checking the MACs and paddingGood - // separately, but in constant-time to prevent distinguishing - // padding failures from MAC failures. Depending on what value - // of paddingLen was returned on bad padding, distinguishing - // bad MAC from bad padding can lead to an attack. - // - // See also the logic at the end of extractPadding. - macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood) - if macAndPaddingGood != 1 { - return nil, 0, alertBadRecordMAC - } - - plaintext = payload[:n] - } - - hc.incSeq() - return plaintext, typ, nil -} - -func (c *Conn) setAlternativeRecordLayer() { - if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - c.in.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetReadKey - c.out.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetWriteKey - } -} - -// sliceForAppend extends the input slice by n bytes. head is the full extended -// slice, while tail is the appended part. If the original slice has sufficient -// capacity no allocation is performed. -func sliceForAppend(in []byte, n int) (head, tail []byte) { - if total := len(in) + n; cap(in) >= total { - head = in[:total] - } else { - head = make([]byte, total) - copy(head, in) - } - tail = head[len(in):] - return -} - -// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and -// appends it to record, which must already contain the record header. 
-func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) { - if hc.cipher == nil { - return append(record, payload...), nil - } - - var explicitNonce []byte - if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 { - record, explicitNonce = sliceForAppend(record, explicitNonceLen) - if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 { - // The AES-GCM construction in TLS has an explicit nonce so that the - // nonce can be random. However, the nonce is only 8 bytes which is - // too small for a secure, random nonce. Therefore we use the - // sequence number as the nonce. The 3DES-CBC construction also has - // an 8 bytes nonce but its nonces must be unpredictable (see RFC - // 5246, Appendix F.3), forcing us to use randomness. That's not - // 3DES' biggest problem anyway because the birthday bound on block - // collision is reached first due to its similarly small block size - // (see the Sweet32 attack). - copy(explicitNonce, hc.seq[:]) - } else { - if _, err := io.ReadFull(rand, explicitNonce); err != nil { - return nil, err - } - } - } - - var dst []byte - switch c := hc.cipher.(type) { - case cipher.Stream: - mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) - record, dst = sliceForAppend(record, len(payload)+len(mac)) - c.XORKeyStream(dst[:len(payload)], payload) - c.XORKeyStream(dst[len(payload):], mac) - case aead: - nonce := explicitNonce - if len(nonce) == 0 { - nonce = hc.seq[:] - } - - if hc.version == VersionTLS13 { - record = append(record, payload...) - - // Encrypt the actual ContentType and replace the plaintext one. - record = append(record, record[0]) - record[0] = byte(recordTypeApplicationData) - - n := len(payload) + 1 + c.Overhead() - record[3] = byte(n >> 8) - record[4] = byte(n) - - record = c.Seal(record[:recordHeaderLen], - nonce, record[recordHeaderLen:], record[:recordHeaderLen]) - } else { - additionalData := append(hc.scratchBuf[:0], hc.seq[:]...) - additionalData = append(additionalData, record[:recordHeaderLen]...) - record = c.Seal(record, nonce, payload, additionalData) - } - case cbcMode: - mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil) - blockSize := c.BlockSize() - plaintextLen := len(payload) + len(mac) - paddingLen := blockSize - plaintextLen%blockSize - record, dst = sliceForAppend(record, plaintextLen+paddingLen) - copy(dst, payload) - copy(dst[len(payload):], mac) - for i := plaintextLen; i < len(dst); i++ { - dst[i] = byte(paddingLen - 1) - } - if len(explicitNonce) > 0 { - c.SetIV(explicitNonce) - } - c.CryptBlocks(dst, dst) - default: - panic("unknown cipher type") - } - - // Update length to include nonce, MAC and any block padding needed. - n := len(record) - recordHeaderLen - record[3] = byte(n >> 8) - record[4] = byte(n) - hc.incSeq() - - return record, nil -} - -// RecordHeaderError is returned when a TLS record header is invalid. -type RecordHeaderError struct { - // Msg contains a human readable string that describes the error. - Msg string - // RecordHeader contains the five bytes of TLS record header that - // triggered the error. - RecordHeader [5]byte - // Conn provides the underlying net.Conn in the case that a client - // sent an initial handshake that didn't look like TLS. - // It is nil if there's already been a handshake or a TLS alert has - // been written to the connection. 
- Conn net.Conn -} - -func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } - -func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { - err.Msg = msg - err.Conn = conn - copy(err.RecordHeader[:], c.rawInput.Bytes()) - return err -} - -func (c *Conn) readRecord() error { - return c.readRecordOrCCS(false) -} - -func (c *Conn) readChangeCipherSpec() error { - return c.readRecordOrCCS(true) -} - -// readRecordOrCCS reads one or more TLS records from the connection and -// updates the record layer state. Some invariants: -// * c.in must be locked -// * c.input must be empty -// During the handshake one and only one of the following will happen: -// - c.hand grows -// - c.in.changeCipherSpec is called -// - an error is returned -// After the handshake one and only one of the following will happen: -// - c.hand grows -// - c.input is set -// - an error is returned -func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { - if c.in.err != nil { - return c.in.err - } - handshakeComplete := c.handshakeComplete() - - // This function modifies c.rawInput, which owns the c.input memory. - if c.input.Len() != 0 { - return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data")) - } - c.input.Reset(nil) - - // Read header, payload. - if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil { - // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify - // is an error, but popular web sites seem to do this, so we accept it - // if and only if at the record boundary. - if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 { - err = io.EOF - } - if e, ok := err.(net.Error); !ok || !e.Temporary() { - c.in.setErrorLocked(err) - } - return err - } - hdr := c.rawInput.Bytes()[:recordHeaderLen] - typ := recordType(hdr[0]) - - // No valid TLS record has a type of 0x80, however SSLv2 handshakes - // start with a uint16 length where the MSB is set and the first record - // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests - // an SSLv2 client. - if !handshakeComplete && typ == 0x80 { - c.sendAlert(alertProtocolVersion) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) - } - - vers := uint16(hdr[1])<<8 | uint16(hdr[2]) - n := int(hdr[3])<<8 | int(hdr[4]) - if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { - c.sendAlert(alertProtocolVersion) - msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) - } - if !c.haveVers { - // First message, be extra suspicious: this might not be a TLS - // client. Bail out before reading a full 'body', if possible. - // The current max version is 3.3 so if the version is >= 16.0, - // it's probably not real. - if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 { - return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) - } - } - if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { - c.sendAlert(alertRecordOverflow) - msg := fmt.Sprintf("oversized record received with length %d", n) - return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) - } - if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { - if e, ok := err.(net.Error); !ok || !e.Temporary() { - c.in.setErrorLocked(err) - } - return err - } - - // Process message. 
- record := c.rawInput.Next(recordHeaderLen + n) - data, typ, err := c.in.decrypt(record) - if err != nil { - return c.in.setErrorLocked(c.sendAlert(err.(alert))) - } - if len(data) > maxPlaintext { - return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) - } - - // Application Data messages are always protected. - if c.in.cipher == nil && typ == recordTypeApplicationData { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { - // This is a state-advancing message: reset the retry count. - c.retryCount = 0 - } - - // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. - if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - switch typ { - default: - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - - case recordTypeAlert: - if len(data) != 2 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - if alert(data[1]) == alertCloseNotify { - return c.in.setErrorLocked(io.EOF) - } - if c.vers == VersionTLS13 { - return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) - } - switch data[0] { - case alertLevelWarning: - // Drop the record on the floor and retry. - return c.retryReadRecord(expectChangeCipherSpec) - case alertLevelError: - return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) - default: - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - case recordTypeChangeCipherSpec: - if len(data) != 1 || data[0] != 1 { - return c.in.setErrorLocked(c.sendAlert(alertDecodeError)) - } - // Handshake messages are not allowed to fragment across the CCS. - if c.hand.Len() > 0 { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - // In TLS 1.3, change_cipher_spec records are ignored until the - // Finished. See RFC 8446, Appendix D.4. Note that according to Section - // 5, a server can send a ChangeCipherSpec before its ServerHello, when - // c.vers is still unset. That's not useful though and suspicious if the - // server then selects a lower protocol version, so don't allow that. - if c.vers == VersionTLS13 { - return c.retryReadRecord(expectChangeCipherSpec) - } - if !expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - if err := c.in.changeCipherSpec(); err != nil { - return c.in.setErrorLocked(c.sendAlert(err.(alert))) - } - - case recordTypeApplicationData: - if !handshakeComplete || expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - // Some OpenSSL servers send empty records in order to randomize the - // CBC IV. Ignore a limited number of empty records. - if len(data) == 0 { - return c.retryReadRecord(expectChangeCipherSpec) - } - // Note that data is owned by c.rawInput, following the Next call above, - // to avoid copying the plaintext. This is safe because c.rawInput is - // not read from or written to until c.input is drained. - c.input.Reset(data) - - case recordTypeHandshake: - if len(data) == 0 || expectChangeCipherSpec { - return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - c.hand.Write(data) - } - - return nil -} - -// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like -// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. 
-func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { - c.retryCount++ - if c.retryCount > maxUselessRecords { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: too many ignored records")) - } - return c.readRecordOrCCS(expectChangeCipherSpec) -} - -// atLeastReader reads from R, stopping with EOF once at least N bytes have been -// read. It is different from an io.LimitedReader in that it doesn't cut short -// the last Read call, and in that it considers an early EOF an error. -type atLeastReader struct { - R io.Reader - N int64 -} - -func (r *atLeastReader) Read(p []byte) (int, error) { - if r.N <= 0 { - return 0, io.EOF - } - n, err := r.R.Read(p) - r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 - if r.N > 0 && err == io.EOF { - return n, io.ErrUnexpectedEOF - } - if r.N <= 0 && err == nil { - return n, io.EOF - } - return n, err -} - -// readFromUntil reads from r into c.rawInput until c.rawInput contains -// at least n bytes or else returns an error. -func (c *Conn) readFromUntil(r io.Reader, n int) error { - if c.rawInput.Len() >= n { - return nil - } - needs := n - c.rawInput.Len() - // There might be extra input waiting on the wire. Make a best effort - // attempt to fetch it so that it can be used in (*Conn).Read to - // "predict" closeNotify alerts. - c.rawInput.Grow(needs + bytes.MinRead) - _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)}) - return err -} - -// sendAlert sends a TLS alert message. -func (c *Conn) sendAlertLocked(err alert) error { - switch err { - case alertNoRenegotiation, alertCloseNotify: - c.tmp[0] = alertLevelWarning - default: - c.tmp[0] = alertLevelError - } - c.tmp[1] = byte(err) - - _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2]) - if err == alertCloseNotify { - // closeNotify is a special case in that it isn't an error. - return writeErr - } - - return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) -} - -// sendAlert sends a TLS alert message. -func (c *Conn) sendAlert(err alert) error { - if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - c.extraConfig.AlternativeRecordLayer.SendAlert(uint8(err)) - return &net.OpError{Op: "local error", Err: err} - } - - c.out.Lock() - defer c.out.Unlock() - return c.sendAlertLocked(err) -} - -const ( - // tcpMSSEstimate is a conservative estimate of the TCP maximum segment - // size (MSS). A constant is used, rather than querying the kernel for - // the actual MSS, to avoid complexity. The value here is the IPv6 - // minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40 - // bytes) and a TCP header with timestamps (32 bytes). - tcpMSSEstimate = 1208 - - // recordSizeBoostThreshold is the number of bytes of application data - // sent after which the TLS record size will be increased to the - // maximum. - recordSizeBoostThreshold = 128 * 1024 -) - -// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the -// next application data record. There is the following trade-off: -// -// - For latency-sensitive applications, such as web browsing, each TLS -// record should fit in one TCP segment. -// - For throughput-sensitive applications, such as large file transfers, -// larger TLS records better amortize framing and encryption overheads. 
-// -// A simple heuristic that works well in practice is to use small records for -// the first 1MB of data, then use larger records for subsequent data, and -// reset back to smaller records after the connection becomes idle. See "High -// Performance Web Networking", Chapter 4, or: -// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/ -// -// In the interests of simplicity and determinism, this code does not attempt -// to reset the record size once the connection is idle, however. -func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { - if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { - return maxPlaintext - } - - if c.bytesSent >= recordSizeBoostThreshold { - return maxPlaintext - } - - // Subtract TLS overheads to get the maximum payload size. - payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen() - if c.out.cipher != nil { - switch ciph := c.out.cipher.(type) { - case cipher.Stream: - payloadBytes -= c.out.mac.Size() - case cipher.AEAD: - payloadBytes -= ciph.Overhead() - case cbcMode: - blockSize := ciph.BlockSize() - // The payload must fit in a multiple of blockSize, with - // room for at least one padding byte. - payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 - // The MAC is appended before padding so affects the - // payload size directly. - payloadBytes -= c.out.mac.Size() - default: - panic("unknown cipher type") - } - } - if c.vers == VersionTLS13 { - payloadBytes-- // encrypted ContentType - } - - // Allow packet growth in arithmetic progression up to max. - pkt := c.packetsSent - c.packetsSent++ - if pkt > 1000 { - return maxPlaintext // avoid overflow in multiply below - } - - n := payloadBytes * int(pkt+1) - if n > maxPlaintext { - n = maxPlaintext - } - return n -} - -func (c *Conn) write(data []byte) (int, error) { - if c.buffering { - c.sendBuf = append(c.sendBuf, data...) - return len(data), nil - } - - n, err := c.conn.Write(data) - c.bytesSent += int64(n) - return n, err -} - -func (c *Conn) flush() (int, error) { - if len(c.sendBuf) == 0 { - return 0, nil - } - - n, err := c.conn.Write(c.sendBuf) - c.bytesSent += int64(n) - c.sendBuf = nil - c.buffering = false - return n, err -} - -// outBufPool pools the record-sized scratch buffers used by writeRecordLocked. -var outBufPool = sync.Pool{ - New: func() interface{} { - return new([]byte) - }, -} - -// writeRecordLocked writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { - outBufPtr := outBufPool.Get().(*[]byte) - outBuf := *outBufPtr - defer func() { - // You might be tempted to simplify this by just passing &outBuf to Put, - // but that would make the local copy of the outBuf slice header escape - // to the heap, causing an allocation. Instead, we keep around the - // pointer to the slice header returned by Get, which is already on the - // heap, and overwrite and return that. - *outBufPtr = outBuf - outBufPool.Put(outBufPtr) - }() - - var n int - for len(data) > 0 { - m := len(data) - if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload { - m = maxPayload - } - - _, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen) - outBuf[0] = byte(typ) - vers := c.vers - if vers == 0 { - // Some TLS servers fail if the record version is - // greater than TLS 1.0 for the initial ClientHello. 
- vers = VersionTLS10 - } else if vers == VersionTLS13 { - // TLS 1.3 froze the record layer version to 1.2. - // See RFC 8446, Section 5.1. - vers = VersionTLS12 - } - outBuf[1] = byte(vers >> 8) - outBuf[2] = byte(vers) - outBuf[3] = byte(m >> 8) - outBuf[4] = byte(m) - - var err error - outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand()) - if err != nil { - return n, err - } - if _, err := c.write(outBuf); err != nil { - return n, err - } - n += m - data = data[m:] - } - - if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 { - if err := c.out.changeCipherSpec(); err != nil { - return n, c.sendAlertLocked(err.(alert)) - } - } - - return n, nil -} - -// writeRecord writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { - if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - if typ == recordTypeChangeCipherSpec { - return len(data), nil - } - return c.extraConfig.AlternativeRecordLayer.WriteRecord(data) - } - - c.out.Lock() - defer c.out.Unlock() - - return c.writeRecordLocked(typ, data) -} - -// readHandshake reads the next handshake message from -// the record layer. -func (c *Conn) readHandshake() (interface{}, error) { - var data []byte - if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - var err error - data, err = c.extraConfig.AlternativeRecordLayer.ReadHandshakeMessage() - if err != nil { - return nil, err - } - } else { - for c.hand.Len() < 4 { - if err := c.readRecord(); err != nil { - return nil, err - } - } - - data = c.hand.Bytes() - n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - if n > maxHandshake { - c.sendAlertLocked(alertInternalError) - return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)) - } - for c.hand.Len() < 4+n { - if err := c.readRecord(); err != nil { - return nil, err - } - } - data = c.hand.Next(4 + n) - } - var m handshakeMessage - switch data[0] { - case typeHelloRequest: - m = new(helloRequestMsg) - case typeClientHello: - m = new(clientHelloMsg) - case typeServerHello: - m = new(serverHelloMsg) - case typeNewSessionTicket: - if c.vers == VersionTLS13 { - m = new(newSessionTicketMsgTLS13) - } else { - m = new(newSessionTicketMsg) - } - case typeCertificate: - if c.vers == VersionTLS13 { - m = new(certificateMsgTLS13) - } else { - m = new(certificateMsg) - } - case typeCertificateRequest: - if c.vers == VersionTLS13 { - m = new(certificateRequestMsgTLS13) - } else { - m = &certificateRequestMsg{ - hasSignatureAlgorithm: c.vers >= VersionTLS12, - } - } - case typeCertificateStatus: - m = new(certificateStatusMsg) - case typeServerKeyExchange: - m = new(serverKeyExchangeMsg) - case typeServerHelloDone: - m = new(serverHelloDoneMsg) - case typeClientKeyExchange: - m = new(clientKeyExchangeMsg) - case typeCertificateVerify: - m = &certificateVerifyMsg{ - hasSignatureAlgorithm: c.vers >= VersionTLS12, - } - case typeFinished: - m = new(finishedMsg) - case typeEncryptedExtensions: - m = new(encryptedExtensionsMsg) - case typeEndOfEarlyData: - m = new(endOfEarlyDataMsg) - case typeKeyUpdate: - m = new(keyUpdateMsg) - default: - return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - - // The handshake message unmarshalers - // expect to be able to keep references to data, - // so pass in a fresh copy that won't be overwritten. - data = append([]byte(nil), data...) 
- - if !m.unmarshal(data) { - return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - } - return m, nil -} - -var ( - errShutdown = errors.New("tls: protocol is shutdown") -) - -// Write writes data to the connection. -// -// As Write calls Handshake, in order to prevent indefinite blocking a deadline -// must be set for both Read and Write before Write is called when the handshake -// has not yet completed. See SetDeadline, SetReadDeadline, and -// SetWriteDeadline. -func (c *Conn) Write(b []byte) (int, error) { - // interlock with Close below - for { - x := atomic.LoadInt32(&c.activeCall) - if x&1 != 0 { - return 0, net.ErrClosed - } - if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { - break - } - } - defer atomic.AddInt32(&c.activeCall, -2) - - if err := c.Handshake(); err != nil { - return 0, err - } - - c.out.Lock() - defer c.out.Unlock() - - if err := c.out.err; err != nil { - return 0, err - } - - if !c.handshakeComplete() { - return 0, alertInternalError - } - - if c.closeNotifySent { - return 0, errShutdown - } - - // TLS 1.0 is susceptible to a chosen-plaintext - // attack when using block mode ciphers due to predictable IVs. - // This can be prevented by splitting each Application Data - // record into two records, effectively randomizing the IV. - // - // https://www.openssl.org/~bodo/tls-cbc.txt - // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 - // https://www.imperialviolet.org/2012/01/15/beastfollowup.html - - var m int - if len(b) > 1 && c.vers == VersionTLS10 { - if _, ok := c.out.cipher.(cipher.BlockMode); ok { - n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) - if err != nil { - return n, c.out.setErrorLocked(err) - } - m, b = 1, b[1:] - } - } - - n, err := c.writeRecordLocked(recordTypeApplicationData, b) - return n + m, c.out.setErrorLocked(err) -} - -// handleRenegotiation processes a HelloRequest handshake message. -func (c *Conn) handleRenegotiation() error { - if c.vers == VersionTLS13 { - return errors.New("tls: internal error: unexpected renegotiation") - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - helloReq, ok := msg.(*helloRequestMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(helloReq, msg) - } - - if !c.isClient { - return c.sendAlert(alertNoRenegotiation) - } - - switch c.config.Renegotiation { - case RenegotiateNever: - return c.sendAlert(alertNoRenegotiation) - case RenegotiateOnceAsClient: - if c.handshakes > 1 { - return c.sendAlert(alertNoRenegotiation) - } - case RenegotiateFreelyAsClient: - // Ok. - default: - c.sendAlert(alertInternalError) - return errors.New("tls: unknown Renegotiation value") - } - - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - atomic.StoreUint32(&c.handshakeStatus, 0) - if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { - c.handshakes++ - } - return c.handshakeErr -} - -func (c *Conn) HandlePostHandshakeMessage() error { - return c.handlePostHandshakeMessage() -} - -// handlePostHandshakeMessage processes a handshake message arrived after the -// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation. 
-func (c *Conn) handlePostHandshakeMessage() error { - if c.vers != VersionTLS13 { - return c.handleRenegotiation() - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - c.retryCount++ - if c.retryCount > maxUselessRecords { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: too many non-advancing records")) - } - - switch msg := msg.(type) { - case *newSessionTicketMsgTLS13: - return c.handleNewSessionTicket(msg) - case *keyUpdateMsg: - return c.handleKeyUpdate(msg) - default: - c.sendAlert(alertUnexpectedMessage) - return fmt.Errorf("tls: received unexpected handshake message of type %T", msg) - } -} - -func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { - cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) - if cipherSuite == nil { - return c.in.setErrorLocked(c.sendAlert(alertInternalError)) - } - - newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) - c.in.setTrafficSecret(cipherSuite, newSecret) - - if keyUpdate.updateRequested { - c.out.Lock() - defer c.out.Unlock() - - msg := &keyUpdateMsg{} - _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) - if err != nil { - // Surface the error at the next write. - c.out.setErrorLocked(err) - return nil - } - - newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) - c.out.setTrafficSecret(cipherSuite, newSecret) - } - - return nil -} - -// Read reads data from the connection. -// -// As Read calls Handshake, in order to prevent indefinite blocking a deadline -// must be set for both Read and Write before Read is called when the handshake -// has not yet completed. See SetDeadline, SetReadDeadline, and -// SetWriteDeadline. -func (c *Conn) Read(b []byte) (int, error) { - if err := c.Handshake(); err != nil { - return 0, err - } - if len(b) == 0 { - // Put this after Handshake, in case people were calling - // Read(nil) for the side effect of the Handshake. - return 0, nil - } - - c.in.Lock() - defer c.in.Unlock() - - for c.input.Len() == 0 { - if err := c.readRecord(); err != nil { - return 0, err - } - for c.hand.Len() > 0 { - if err := c.handlePostHandshakeMessage(); err != nil { - return 0, err - } - } - } - - n, _ := c.input.Read(b) - - // If a close-notify alert is waiting, read it so that we can return (n, - // EOF) instead of (n, nil), to signal to the HTTP response reading - // goroutine that the connection is now closed. This eliminates a race - // where the HTTP response reading goroutine would otherwise not observe - // the EOF until its next read, by which time a client goroutine might - // have already tried to reuse the HTTP connection for a new request. - // See https://golang.org/cl/76400046 and https://golang.org/issue/3514 - if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && - recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { - if err := c.readRecord(); err != nil { - return n, err // will be io.EOF on closeNotify - } - } - - return n, nil -} - -// Close closes the connection. -func (c *Conn) Close() error { - // Interlock with Conn.Write above. - var x int32 - for { - x = atomic.LoadInt32(&c.activeCall) - if x&1 != 0 { - return net.ErrClosed - } - if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { - break - } - } - if x != 0 { - // io.Writer and io.Closer should not be used concurrently. 
- // If Close is called while a Write is currently in-flight, - // interpret that as a sign that this Close is really just - // being used to break the Write and/or clean up resources and - // avoid sending the alertCloseNotify, which may block - // waiting on handshakeMutex or the c.out mutex. - return c.conn.Close() - } - - var alertErr error - if c.handshakeComplete() { - if err := c.closeNotify(); err != nil { - alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) - } - } - - if err := c.conn.Close(); err != nil { - return err - } - return alertErr -} - -var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") - -// CloseWrite shuts down the writing side of the connection. It should only be -// called once the handshake has completed and does not call CloseWrite on the -// underlying connection. Most callers should just use Close. -func (c *Conn) CloseWrite() error { - if !c.handshakeComplete() { - return errEarlyCloseWrite - } - - return c.closeNotify() -} - -func (c *Conn) closeNotify() error { - c.out.Lock() - defer c.out.Unlock() - - if !c.closeNotifySent { - // Set a Write Deadline to prevent possibly blocking forever. - c.SetWriteDeadline(time.Now().Add(time.Second * 5)) - c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) - c.closeNotifySent = true - // Any subsequent writes will fail. - c.SetWriteDeadline(time.Now()) - } - return c.closeNotifyErr -} - -// Handshake runs the client or server handshake -// protocol if it has not yet been run. -// -// Most uses of this package need not call Handshake explicitly: the -// first Read or Write will call it automatically. -// -// For control over canceling or setting a timeout on a handshake, use -// HandshakeContext or the Dialer's DialContext method instead. -func (c *Conn) Handshake() error { - return c.HandshakeContext(context.Background()) -} - -// HandshakeContext runs the client or server handshake -// protocol if it has not yet been run. -// -// The provided Context must be non-nil. If the context is canceled before -// the handshake is complete, the handshake is interrupted and an error is returned. -// Once the handshake has completed, cancellation of the context will not affect the -// connection. -// -// Most uses of this package need not call HandshakeContext explicitly: the -// first Read or Write will call it automatically. -func (c *Conn) HandshakeContext(ctx context.Context) error { - // Delegate to unexported method for named return - // without confusing documented signature. - return c.handshakeContext(ctx) -} - -func (c *Conn) handshakeContext(ctx context.Context) (ret error) { - // Fast sync/atomic-based exit if there is no handshake in flight and the - // last one succeeded without an error. Avoids the expensive context setup - // and mutex for most Read and Write calls. - if c.handshakeComplete() { - return nil - } - - handshakeCtx, cancel := context.WithCancel(ctx) - // Note: defer this before starting the "interrupter" goroutine - // so that we can tell the difference between the input being canceled and - // this cancellation. In the former case, we need to close the connection. - defer cancel() - - // Start the "interrupter" goroutine, if this context might be canceled. - // (The background context cannot). - // - // The interrupter goroutine waits for the input context to be done and - // closes the connection if this happens before the function returns. 
- if ctx.Done() != nil { - done := make(chan struct{}) - interruptRes := make(chan error, 1) - defer func() { - close(done) - if ctxErr := <-interruptRes; ctxErr != nil { - // Return context error to user. - ret = ctxErr - } - }() - go func() { - select { - case <-handshakeCtx.Done(): - // Close the connection, discarding the error - _ = c.conn.Close() - interruptRes <- handshakeCtx.Err() - case <-done: - interruptRes <- nil - } - }() - } - - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - if err := c.handshakeErr; err != nil { - return err - } - if c.handshakeComplete() { - return nil - } - - c.in.Lock() - defer c.in.Unlock() - - c.handshakeErr = c.handshakeFn(handshakeCtx) - if c.handshakeErr == nil { - c.handshakes++ - } else { - // If an error occurred during the handshake try to flush the - // alert that might be left in the buffer. - c.flush() - } - - if c.handshakeErr == nil && !c.handshakeComplete() { - c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") - } - if c.handshakeErr != nil && c.handshakeComplete() { - panic("tls: internal error: handshake returned an error but is marked successful") - } - - return c.handshakeErr -} - -// ConnectionState returns basic TLS details about the connection. -func (c *Conn) ConnectionState() ConnectionState { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - return c.connectionStateLocked() -} - -// ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection. -func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - return ConnectionStateWith0RTT{ - ConnectionState: c.connectionStateLocked(), - Used0RTT: c.used0RTT, - } -} - -func (c *Conn) connectionStateLocked() ConnectionState { - var state connectionState - state.HandshakeComplete = c.handshakeComplete() - state.Version = c.vers - state.NegotiatedProtocol = c.clientProtocol - state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = true - state.ServerName = c.serverName - state.CipherSuite = c.cipherSuite - state.PeerCertificates = c.peerCertificates - state.VerifiedChains = c.verifiedChains - state.SignedCertificateTimestamps = c.scts - state.OCSPResponse = c.ocspResponse - if !c.didResume && c.vers != VersionTLS13 { - if c.clientFinishedIsFirst { - state.TLSUnique = c.clientFinished[:] - } else { - state.TLSUnique = c.serverFinished[:] - } - } - if c.config.Renegotiation != RenegotiateNever { - state.ekm = noExportedKeyingMaterial - } else { - state.ekm = c.ekm - } - return toConnectionState(state) -} - -// OCSPResponse returns the stapled OCSP response from the TLS server, if -// any. (Only valid for client connections.) -func (c *Conn) OCSPResponse() []byte { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - return c.ocspResponse -} - -// VerifyHostname checks that the peer certificate chain is valid for -// connecting to host. If so, it returns nil; if not, it returns an error -// describing the problem. 
-func (c *Conn) VerifyHostname(host string) error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - if !c.isClient { - return errors.New("tls: VerifyHostname called on TLS server connection") - } - if !c.handshakeComplete() { - return errors.New("tls: handshake has not yet been performed") - } - if len(c.verifiedChains) == 0 { - return errors.New("tls: handshake did not verify certificate chain") - } - return c.peerCertificates[0].VerifyHostname(host) -} - -func (c *Conn) handshakeComplete() bool { - return atomic.LoadUint32(&c.handshakeStatus) == 1 -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client.go deleted file mode 100644 index ac5d0b4c2a6..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client.go +++ /dev/null @@ -1,1111 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "net" - "strings" - "sync/atomic" - "time" - - "golang.org/x/crypto/cryptobyte" -) - -const clientSessionStateVersion = 1 - -type clientHandshakeState struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - suite *cipherSuite - finishedHash finishedHash - masterSecret []byte - session *clientSessionState -} - -func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) { - config := c.config - if len(config.ServerName) == 0 && !config.InsecureSkipVerify { - return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") - } - - nextProtosLength := 0 - for _, proto := range config.NextProtos { - if l := len(proto); l == 0 || l > 255 { - return nil, nil, errors.New("tls: invalid NextProtos value") - } else { - nextProtosLength += 1 + l - } - } - if nextProtosLength > 0xffff { - return nil, nil, errors.New("tls: NextProtos values too large") - } - - var supportedVersions []uint16 - var clientHelloVersion uint16 - if c.extraConfig.usesAlternativeRecordLayer() { - if config.maxSupportedVersion() < VersionTLS13 { - return nil, nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") - } - // Only offer TLS 1.3 when QUIC is used. - supportedVersions = []uint16{VersionTLS13} - clientHelloVersion = VersionTLS13 - } else { - supportedVersions = config.supportedVersions() - if len(supportedVersions) == 0 { - return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") - } - clientHelloVersion = config.maxSupportedVersion() - } - - // The version at the beginning of the ClientHello was capped at TLS 1.2 - // for compatibility reasons. The supported_versions extension is used - // to negotiate versions now. See RFC 8446, Section 4.2.1. 
- if clientHelloVersion > VersionTLS12 { - clientHelloVersion = VersionTLS12 - } - - hello := &clientHelloMsg{ - vers: clientHelloVersion, - compressionMethods: []uint8{compressionNone}, - random: make([]byte, 32), - ocspStapling: true, - scts: true, - serverName: hostnameInSNI(config.ServerName), - supportedCurves: config.curvePreferences(), - supportedPoints: []uint8{pointFormatUncompressed}, - secureRenegotiationSupported: true, - alpnProtocols: config.NextProtos, - supportedVersions: supportedVersions, - } - - if c.handshakes > 0 { - hello.secureRenegotiation = c.clientFinished[:] - } - - preferenceOrder := cipherSuitesPreferenceOrder - if !hasAESGCMHardwareSupport { - preferenceOrder = cipherSuitesPreferenceOrderNoAES - } - configCipherSuites := config.cipherSuites() - hello.cipherSuites = make([]uint16, 0, len(configCipherSuites)) - - for _, suiteId := range preferenceOrder { - suite := mutualCipherSuite(configCipherSuites, suiteId) - if suite == nil { - continue - } - // Don't advertise TLS 1.2-only cipher suites unless - // we're attempting TLS 1.2. - if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { - continue - } - hello.cipherSuites = append(hello.cipherSuites, suiteId) - } - - _, err := io.ReadFull(config.rand(), hello.random) - if err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - - // A random session ID is used to detect when the server accepted a ticket - // and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as - // a compatibility measure (see RFC 8446, Section 4.1.2). - if c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { - hello.sessionId = make([]byte, 32) - if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil { - return nil, nil, errors.New("tls: short read from Rand: " + err.Error()) - } - } - - if hello.vers >= VersionTLS12 { - hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms - } - - var params ecdheParameters - if hello.supportedVersions[0] == VersionTLS13 { - var suites []uint16 - for _, suiteID := range configCipherSuites { - for _, suite := range cipherSuitesTLS13 { - if suite.id == suiteID { - suites = append(suites, suiteID) - } - } - } - if len(suites) > 0 { - hello.cipherSuites = suites - } else { - if hasAESGCMHardwareSupport { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) - } else { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) - } - } - - curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err = generateECDHEParameters(config.rand(), curveID) - if err != nil { - return nil, nil, err - } - hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} - } - - if hello.supportedVersions[0] == VersionTLS13 && c.extraConfig != nil && c.extraConfig.GetExtensions != nil { - hello.additionalExtensions = c.extraConfig.GetExtensions(typeClientHello) - } - - return hello, params, nil -} - -func (c *Conn) clientHandshake(ctx context.Context) (err error) { - if c.config == nil { - c.config = fromConfig(defaultConfig()) - } - c.setAlternativeRecordLayer() - - // This may be a renegotiation handshake, in which case some fields - // need to be reset. 
- c.didResume = false - - hello, ecdheParams, err := c.makeClientHello() - if err != nil { - return err - } - c.serverName = hello.serverName - - cacheKey, session, earlySecret, binderKey := c.loadSession(hello) - if cacheKey != "" && session != nil { - var deletedTicket bool - if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT { - // don't reuse a session ticket that enabled 0-RTT - c.config.ClientSessionCache.Put(cacheKey, nil) - deletedTicket = true - - if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil { - h := suite.hash.New() - h.Write(hello.marshal()) - clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h) - c.out.exportKey(Encryption0RTT, suite, clientEarlySecret) - if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil { - c.sendAlert(alertInternalError) - return err - } - } - } - if !deletedTicket { - defer func() { - // If we got a handshake failure when resuming a session, throw away - // the session ticket. See RFC 5077, Section 3.2. - // - // RFC 8446 makes no mention of dropping tickets on failure, but it - // does require servers to abort on invalid binders, so we need to - // delete tickets to recover from a corrupted PSK. - if err != nil { - c.config.ClientSessionCache.Put(cacheKey, nil) - } - }() - } - } - - if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - serverHello, ok := msg.(*serverHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - - if err := c.pickTLSVersion(serverHello); err != nil { - return err - } - - // If we are negotiating a protocol version that's lower than what we - // support, check for the server downgrade canaries. - // See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion() - tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12 - tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11 - if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) || - maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox") - } - - if c.vers == VersionTLS13 { - hs := &clientHandshakeStateTLS13{ - c: c, - ctx: ctx, - serverHello: serverHello, - hello: hello, - ecdheParams: ecdheParams, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, - } - - // In TLS 1.3, session tickets are delivered after the handshake. - return hs.handshake() - } - - hs := &clientHandshakeState{ - c: c, - ctx: ctx, - serverHello: serverHello, - hello: hello, - session: session, - } - - if err := hs.handshake(); err != nil { - return err - } - - // If we had a successful handshake and hs.session is different from - // the one already cached - cache a new one. 
- if cacheKey != "" && hs.session != nil && session != hs.session { - c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session)) - } - - return nil -} - -// extract the app data saved in the session.nonce, -// and set the session.nonce to the actual nonce value -func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max early data */, []byte /* app data */, bool /* ok */) { - s := cryptobyte.String(session.nonce) - var version uint16 - if !s.ReadUint16(&version) { - return 0, nil, false - } - if version != clientSessionStateVersion { - return 0, nil, false - } - var maxEarlyData uint32 - if !s.ReadUint32(&maxEarlyData) { - return 0, nil, false - } - var appData []byte - if !readUint16LengthPrefixed(&s, &appData) { - return 0, nil, false - } - var nonce []byte - if !readUint16LengthPrefixed(&s, &nonce) { - return 0, nil, false - } - session.nonce = nonce - return maxEarlyData, appData, true -} - -func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, - session *clientSessionState, earlySecret, binderKey []byte) { - if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return "", nil, nil, nil - } - - hello.ticketSupported = true - - if hello.supportedVersions[0] == VersionTLS13 { - // Require DHE on resumption as it guarantees forward secrecy against - // compromise of the session ticket key. See RFC 8446, Section 4.2.9. - hello.pskModes = []uint8{pskModeDHE} - } - - // Session resumption is not allowed if renegotiating because - // renegotiation is primarily used to allow a client to send a client - // certificate, which would be skipped if session resumption occurred. - if c.handshakes != 0 { - return "", nil, nil, nil - } - - // Try to resume a previously negotiated TLS session, if available. - cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) - sess, ok := c.config.ClientSessionCache.Get(cacheKey) - if !ok || sess == nil { - return cacheKey, nil, nil, nil - } - session = fromClientSessionState(sess) - - var appData []byte - var maxEarlyData uint32 - if session.vers == VersionTLS13 { - var ok bool - maxEarlyData, appData, ok = c.decodeSessionState(session) - if !ok { // delete it, if parsing failed - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil - } - } - - // Check that version used for the previous session is still valid. - versOk := false - for _, v := range hello.supportedVersions { - if v == session.vers { - versOk = true - break - } - } - if !versOk { - return cacheKey, nil, nil, nil - } - - // Check that the cached server certificate is not expired, and that it's - // valid for the ServerName. This should be ensured by the cache key, but - // protect the application from a faulty ClientSessionCache implementation. - if !c.config.InsecureSkipVerify { - if len(session.verifiedChains) == 0 { - // The original connection had InsecureSkipVerify, while this doesn't. - return cacheKey, nil, nil, nil - } - serverCert := session.serverCertificates[0] - if c.config.time().After(serverCert.NotAfter) { - // Expired certificate, delete the entry. - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil - } - if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { - return cacheKey, nil, nil, nil - } - } - - if session.vers != VersionTLS13 { - // In TLS 1.2 the cipher suite must match the resumed session. Ensure we - // are still offering it. 
- if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { - return cacheKey, nil, nil, nil - } - - hello.sessionTicket = session.sessionTicket - return - } - - // Check that the session ticket is not expired. - if c.config.time().After(session.useBy) { - c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil - } - - // In TLS 1.3 the KDF hash must match the resumed session. Ensure we - // offer at least one cipher suite with that hash. - cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) - if cipherSuite == nil { - return cacheKey, nil, nil, nil - } - cipherSuiteOk := false - for _, offeredID := range hello.cipherSuites { - offeredSuite := cipherSuiteTLS13ByID(offeredID) - if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash { - cipherSuiteOk = true - break - } - } - if !cipherSuiteOk { - return cacheKey, nil, nil, nil - } - - // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. - ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond) - identity := pskIdentity{ - label: session.sessionTicket, - obfuscatedTicketAge: ticketAge + session.ageAdd, - } - hello.pskIdentities = []pskIdentity{identity} - hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())} - - // Compute the PSK binders. See RFC 8446, Section 4.2.11.2. - psk := cipherSuite.expandLabel(session.masterSecret, "resumption", - session.nonce, cipherSuite.hash.Size()) - earlySecret = cipherSuite.extract(psk, nil) - binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) - if c.extraConfig != nil { - hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0 - } - transcript := cipherSuite.hash.New() - transcript.Write(hello.marshalWithoutBinders()) - pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} - hello.updateBinders(pskBinders) - - if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { - c.extraConfig.SetAppDataFromSessionState(appData) - } - return -} - -func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error { - peerVersion := serverHello.vers - if serverHello.supportedVersion != 0 { - peerVersion = serverHello.supportedVersion - } - - vers, ok := c.config.mutualVersion([]uint16{peerVersion}) - if !ok { - c.sendAlert(alertProtocolVersion) - return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) - } - - c.vers = vers - c.haveVers = true - c.in.version = vers - c.out.version = vers - - return nil -} - -// Does the handshake, either a full one or resumes old session. Requires hs.c, -// hs.hello, hs.serverHello, and, optionally, hs.session to be set. -func (hs *clientHandshakeState) handshake() error { - c := hs.c - - isResume, err := hs.processServerHello() - if err != nil { - return err - } - - hs.finishedHash = newFinishedHash(c.vers, hs.suite) - - // No signatures of the handshake are needed in a resumption. - // Otherwise, in a full handshake, if we don't have any certificates - // configured then we will never send a CertificateVerify message and - // thus no signatures are needed in that case either. 
- if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { - hs.finishedHash.discardHandshakeBuffer() - } - - hs.finishedHash.Write(hs.hello.marshal()) - hs.finishedHash.Write(hs.serverHello.marshal()) - - c.buffering = true - c.didResume = isResume - if isResume { - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.readSessionTicket(); err != nil { - return err - } - if err := hs.readFinished(c.serverFinished[:]); err != nil { - return err - } - c.clientFinishedIsFirst = false - // Make sure the connection is still being verified whether or not this - // is a resumption. Resumptions currently don't reverify certificates so - // they don't call verifyServerCertificate. See Issue 31641. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - if err := hs.sendFinished(c.clientFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - } else { - if err := hs.doFullHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.sendFinished(c.clientFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - c.clientFinishedIsFirst = true - if err := hs.readSessionTicket(); err != nil { - return err - } - if err := hs.readFinished(c.serverFinished[:]); err != nil { - return err - } - } - - c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) - atomic.StoreUint32(&c.handshakeStatus, 1) - - return nil -} - -func (hs *clientHandshakeState) pickCipherSuite() error { - if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { - hs.c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server chose an unconfigured cipher suite") - } - - hs.c.cipherSuite = hs.suite.id - return nil -} - -func (hs *clientHandshakeState) doFullHandshake() error { - c := hs.c - - msg, err := c.readHandshake() - if err != nil { - return err - } - certMsg, ok := msg.(*certificateMsg) - if !ok || len(certMsg.certificates) == 0 { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - hs.finishedHash.Write(certMsg.marshal()) - - msg, err = c.readHandshake() - if err != nil { - return err - } - - cs, ok := msg.(*certificateStatusMsg) - if ok { - // RFC4366 on Certificate Status Request: - // The server MAY return a "certificate_status" message. - - if !hs.serverHello.ocspStapling { - // If a server returns a "CertificateStatus" message, then the - // server MUST have included an extension of type "status_request" - // with empty "extension_data" in the extended server hello. - - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: received unexpected CertificateStatus message") - } - hs.finishedHash.Write(cs.marshal()) - - c.ocspResponse = cs.response - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - - if c.handshakes == 0 { - // If this is the first handshake on a connection, process and - // (optionally) verify the server's certificates. - if err := c.verifyServerCertificate(certMsg.certificates); err != nil { - return err - } - } else { - // This is a renegotiation handshake. We require that the - // server's identity (i.e. leaf certificate) is unchanged and - // thus any previous trust decision is still valid. 
- // - // See https://mitls.org/pages/attacks/3SHAKE for the - // motivation behind this requirement. - if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: server's identity changed during renegotiation") - } - } - - keyAgreement := hs.suite.ka(c.vers) - - skx, ok := msg.(*serverKeyExchangeMsg) - if ok { - hs.finishedHash.Write(skx.marshal()) - err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) - if err != nil { - c.sendAlert(alertUnexpectedMessage) - return err - } - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - - var chainToSend *Certificate - var certRequested bool - certReq, ok := msg.(*certificateRequestMsg) - if ok { - certRequested = true - hs.finishedHash.Write(certReq.marshal()) - - cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) - if chainToSend, err = c.getClientCertificate(cri); err != nil { - c.sendAlert(alertInternalError) - return err - } - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - - shd, ok := msg.(*serverHelloDoneMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(shd, msg) - } - hs.finishedHash.Write(shd.marshal()) - - // If the server requested a certificate then we have to send a - // Certificate message, even if it's empty because we don't have a - // certificate to send. - if certRequested { - certMsg = new(certificateMsg) - certMsg.certificates = chainToSend.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { - return err - } - } - - preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - if ckx != nil { - hs.finishedHash.Write(ckx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { - return err - } - } - - if chainToSend != nil && len(chainToSend.Certificate) > 0 { - certVerify := &certificateVerifyMsg{} - - key, ok := chainToSend.PrivateKey.(crypto.Signer) - if !ok { - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) - } - - var sigType uint8 - var sigHash crypto.Hash - if c.vers >= VersionTLS12 { - signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - certVerify.hasSignatureAlgorithm = true - certVerify.signatureAlgorithm = signatureAlgorithm - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - } - - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - hs.finishedHash.Write(certVerify.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != 
nil { - return err - } - } - - hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) - if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: failed to write to key log: " + err.Error()) - } - - hs.finishedHash.discardHandshakeBuffer() - - return nil -} - -func (hs *clientHandshakeState) establishKeys() error { - c := hs.c - - clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := - keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) - var clientCipher, serverCipher interface{} - var clientHash, serverHash hash.Hash - if hs.suite.cipher != nil { - clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) - clientHash = hs.suite.mac(clientMAC) - serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) - serverHash = hs.suite.mac(serverMAC) - } else { - clientCipher = hs.suite.aead(clientKey, clientIV) - serverCipher = hs.suite.aead(serverKey, serverIV) - } - - c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) - c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) - return nil -} - -func (hs *clientHandshakeState) serverResumedSession() bool { - // If the server responded with the same sessionId then it means the - // sessionTicket is being used to resume a TLS session. - return hs.session != nil && hs.hello.sessionId != nil && - bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) -} - -func (hs *clientHandshakeState) processServerHello() (bool, error) { - c := hs.c - - if err := hs.pickCipherSuite(); err != nil { - return false, err - } - - if hs.serverHello.compressionMethod != compressionNone { - c.sendAlert(alertUnexpectedMessage) - return false, errors.New("tls: server selected unsupported compression format") - } - - if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported { - c.secureRenegotiation = true - if len(hs.serverHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: initial handshake had non-empty renegotiation extension") - } - } - - if c.handshakes > 0 && c.secureRenegotiation { - var expectedSecureRenegotiation [24]byte - copy(expectedSecureRenegotiation[:], c.clientFinished[:]) - copy(expectedSecureRenegotiation[12:], c.serverFinished[:]) - if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: incorrect renegotiation extension contents") - } - } - - if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil { - c.sendAlert(alertUnsupportedExtension) - return false, err - } - c.clientProtocol = hs.serverHello.alpnProtocol - - c.scts = hs.serverHello.scts - - if !hs.serverResumedSession() { - return false, nil - } - - if hs.session.vers != c.vers { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server resumed a session with a different version") - } - - if hs.session.cipherSuite != hs.suite.id { - c.sendAlert(alertHandshakeFailure) - return false, errors.New("tls: server resumed a session with a different cipher suite") - } - - // Restore masterSecret, peerCerts, and ocspResponse from previous state - hs.masterSecret = hs.session.masterSecret - c.peerCertificates = hs.session.serverCertificates - c.verifiedChains = hs.session.verifiedChains 
- c.ocspResponse = hs.session.ocspResponse - // Let the ServerHello SCTs override the session SCTs from the original - // connection, if any are provided - if len(c.scts) == 0 && len(hs.session.scts) != 0 { - c.scts = hs.session.scts - } - - return true, nil -} - -// checkALPN ensure that the server's choice of ALPN protocol is compatible with -// the protocols that we advertised in the Client Hello. -func checkALPN(clientProtos []string, serverProto string) error { - if serverProto == "" { - return nil - } - if len(clientProtos) == 0 { - return errors.New("tls: server advertised unrequested ALPN extension") - } - for _, proto := range clientProtos { - if proto == serverProto { - return nil - } - } - return errors.New("tls: server selected unadvertised ALPN protocol") -} - -func (hs *clientHandshakeState) readFinished(out []byte) error { - c := hs.c - - if err := c.readChangeCipherSpec(); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - serverFinished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverFinished, msg) - } - - verify := hs.finishedHash.serverSum(hs.masterSecret) - if len(verify) != len(serverFinished.verifyData) || - subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server's Finished message was incorrect") - } - hs.finishedHash.Write(serverFinished.marshal()) - copy(out, verify) - return nil -} - -func (hs *clientHandshakeState) readSessionTicket() error { - if !hs.serverHello.ticketSupported { - return nil - } - - c := hs.c - msg, err := c.readHandshake() - if err != nil { - return err - } - sessionTicketMsg, ok := msg.(*newSessionTicketMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(sessionTicketMsg, msg) - } - hs.finishedHash.Write(sessionTicketMsg.marshal()) - - hs.session = &clientSessionState{ - sessionTicket: sessionTicketMsg.ticket, - vers: c.vers, - cipherSuite: hs.suite.id, - masterSecret: hs.masterSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - ocspResponse: c.ocspResponse, - scts: c.scts, - } - - return nil -} - -func (hs *clientHandshakeState) sendFinished(out []byte) error { - c := hs.c - - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { - return err - } - - finished := new(finishedMsg) - finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { - return err - } - copy(out, finished.verifyData) - return nil -} - -// verifyServerCertificate parses and verifies the provided chain, setting -// c.verifiedChains and c.peerCertificates or sending the appropriate alert. 
-func (c *Conn) verifyServerCertificate(certificates [][]byte) error { - certs := make([]*x509.Certificate, len(certificates)) - for i, asn1Data := range certificates { - cert, err := x509.ParseCertificate(asn1Data) - if err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to parse certificate from server: " + err.Error()) - } - certs[i] = cert - } - - if !c.config.InsecureSkipVerify { - opts := x509.VerifyOptions{ - Roots: c.config.RootCAs, - CurrentTime: c.config.time(), - DNSName: c.config.ServerName, - Intermediates: x509.NewCertPool(), - } - for _, cert := range certs[1:] { - opts.Intermediates.AddCert(cert) - } - var err error - c.verifiedChains, err = certs[0].Verify(opts) - if err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - switch certs[0].PublicKey.(type) { - case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: - break - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) - } - - c.peerCertificates = certs - - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - return nil -} - -// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS -// <= 1.2 CertificateRequest, making an effort to fill in missing information. -func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { - cri := &certificateRequestInfo{ - AcceptableCAs: certReq.certificateAuthorities, - Version: vers, - ctx: ctx, - } - - var rsaAvail, ecAvail bool - for _, certType := range certReq.certificateTypes { - switch certType { - case certTypeRSASign: - rsaAvail = true - case certTypeECDSASign: - ecAvail = true - } - } - - if !certReq.hasSignatureAlgorithm { - // Prior to TLS 1.2, signature schemes did not exist. In this case we - // make up a list based on the acceptable certificate types, to help - // GetClientCertificate and SupportsCertificate select the right certificate. - // The hash part of the SignatureScheme is a lie here, because - // TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA. - switch { - case rsaAvail && ecAvail: - cri.SignatureSchemes = []SignatureScheme{ - ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, - PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, - } - case rsaAvail: - cri.SignatureSchemes = []SignatureScheme{ - PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1, - } - case ecAvail: - cri.SignatureSchemes = []SignatureScheme{ - ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, - } - } - return toCertificateRequestInfo(cri) - } - - // Filter the signature schemes based on the certificate types. - // See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated"). 
- cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms)) - for _, sigScheme := range certReq.supportedSignatureAlgorithms { - sigType, _, err := typeAndHashFromSignatureScheme(sigScheme) - if err != nil { - continue - } - switch sigType { - case signatureECDSA, signatureEd25519: - if ecAvail { - cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) - } - case signatureRSAPSS, signaturePKCS1v15: - if rsaAvail { - cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme) - } - } - } - - return toCertificateRequestInfo(cri) -} - -func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) { - if c.config.GetClientCertificate != nil { - return c.config.GetClientCertificate(cri) - } - - for _, chain := range c.config.Certificates { - if err := cri.SupportsCertificate(&chain); err != nil { - continue - } - return &chain, nil - } - - // No acceptable certificate found. Don't send a certificate. - return new(Certificate), nil -} - -const clientSessionCacheKeyPrefix = "qtls-" - -// clientSessionCacheKey returns a key used to cache sessionTickets that could -// be used to resume previously negotiated TLS sessions with a server. -func clientSessionCacheKey(serverAddr net.Addr, config *config) string { - if len(config.ServerName) > 0 { - return clientSessionCacheKeyPrefix + config.ServerName - } - return clientSessionCacheKeyPrefix + serverAddr.String() -} - -// hostnameInSNI converts name into an appropriate hostname for SNI. -// Literal IP addresses and absolute FQDNs are not permitted as SNI values. -// See RFC 6066, Section 3. -func hostnameInSNI(name string) string { - host := name - if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { - host = host[1 : len(host)-1] - } - if i := strings.LastIndex(host, "%"); i > 0 { - host = host[:i] - } - if net.ParseIP(host) != nil { - return "" - } - for len(name) > 0 && name[len(name)-1] == '.' { - name = name[:len(name)-1] - } - return name -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client_tls13.go deleted file mode 100644 index 0de59fc1e16..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_client_tls13.go +++ /dev/null @@ -1,732 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "context" - "crypto" - "crypto/hmac" - "crypto/rsa" - "encoding/binary" - "errors" - "hash" - "sync/atomic" - "time" - - "golang.org/x/crypto/cryptobyte" -) - -type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - ecdheParams ecdheParameters - - session *clientSessionState - earlySecret []byte - binderKey []byte - - certReq *certificateRequestMsgTLS13 - usingPSK bool - sentDummyCCS bool - suite *cipherSuiteTLS13 - transcript hash.Hash - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 -} - -// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, -// optionally, hs.session, hs.earlySecret and hs.binderKey to be set. -func (hs *clientHandshakeStateTLS13) handshake() error { - c := hs.c - - // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, - // sections 4.1.2 and 4.1.3. 
- if c.handshakes > 0 { - c.sendAlert(alertProtocolVersion) - return errors.New("tls: server selected TLS 1.3 in a renegotiation") - } - - // Consistency check on the presence of a keyShare and its parameters. - if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { - return c.sendAlert(alertInternalError) - } - - if err := hs.checkServerHelloOrHRR(); err != nil { - return err - } - - hs.transcript = hs.suite.hash.New() - hs.transcript.Write(hs.hello.marshal()) - - if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - if err := hs.processHelloRetryRequest(); err != nil { - return err - } - } - - hs.transcript.Write(hs.serverHello.marshal()) - - c.buffering = true - if err := hs.processServerHello(); err != nil { - return err - } - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - if err := hs.establishHandshakeKeys(); err != nil { - return err - } - if err := hs.readServerParameters(); err != nil { - return err - } - if err := hs.readServerCertificate(); err != nil { - return err - } - if err := hs.readServerFinished(); err != nil { - return err - } - if err := hs.sendClientCertificate(); err != nil { - return err - } - if err := hs.sendClientFinished(); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - - atomic.StoreUint32(&c.handshakeStatus, 1) - - return nil -} - -// checkServerHelloOrHRR does validity checks that apply to both ServerHello and -// HelloRetryRequest messages. It sets hs.suite. -func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error { - c := hs.c - - if hs.serverHello.supportedVersion == 0 { - c.sendAlert(alertMissingExtension) - return errors.New("tls: server selected TLS 1.3 using the legacy version field") - } - - if hs.serverHello.supportedVersion != VersionTLS13 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid version after a HelloRetryRequest") - } - - if hs.serverHello.vers != VersionTLS12 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an incorrect legacy version") - } - - if hs.serverHello.ocspStapling || - hs.serverHello.ticketSupported || - hs.serverHello.secureRenegotiationSupported || - len(hs.serverHello.secureRenegotiation) != 0 || - len(hs.serverHello.alpnProtocol) != 0 || - len(hs.serverHello.scts) != 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3") - } - - if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server did not echo the legacy session ID") - } - - if hs.serverHello.compressionMethod != compressionNone { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported compression format") - } - - selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite) - if hs.suite != nil && selectedSuite != hs.suite { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server changed cipher suite after a HelloRetryRequest") - } - if selectedSuite == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server chose an unconfigured cipher suite") - } - hs.suite = selectedSuite - c.cipherSuite = hs.suite.id - - return nil -} - -// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility -// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. 
-func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { - if hs.sentDummyCCS { - return nil - } - hs.sentDummyCCS = true - - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err -} - -// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and -// resends hs.hello, and reads the new ServerHello into hs.serverHello. -func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { - c := hs.c - - // The first ClientHello gets double-hashed into the transcript upon a - // HelloRetryRequest. (The idea is that the server might offload transcript - // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. - chHash := hs.transcript.Sum(nil) - hs.transcript.Reset() - hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - hs.transcript.Write(chHash) - hs.transcript.Write(hs.serverHello.marshal()) - - // The only HelloRetryRequest extensions we support are key_share and - // cookie, and clients must abort the handshake if the HRR would not result - // in any change in the ClientHello. - if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an unnecessary HelloRetryRequest message") - } - - if hs.serverHello.cookie != nil { - hs.hello.cookie = hs.serverHello.cookie - } - - if hs.serverHello.serverShare.group != 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: received malformed key_share extension") - } - - // If the server sent a key_share extension selecting a group, ensure it's - // a group we advertised but did not send a key share for, and send a key - // share for it this time. - if curveID := hs.serverHello.selectedGroup; curveID != 0 { - curveOK := false - for _, id := range hs.hello.supportedCurves { - if id == curveID { - curveOK = true - break - } - } - if !curveOK { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") - } - if hs.ecdheParams.CurveID() == curveID { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") - } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err := generateECDHEParameters(c.config.rand(), curveID) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - hs.ecdheParams = params - hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} - } - - hs.hello.raw = nil - if len(hs.hello.pskIdentities) > 0 { - pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) - if pskSuite == nil { - return c.sendAlert(alertInternalError) - } - if pskSuite.hash == hs.suite.hash { - // Update binders and obfuscated_ticket_age. - ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond) - hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd - - transcript := hs.suite.hash.New() - transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - transcript.Write(chHash) - transcript.Write(hs.serverHello.marshal()) - transcript.Write(hs.hello.marshalWithoutBinders()) - pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} - hs.hello.updateBinders(pskBinders) - } else { - // Server selected a cipher suite incompatible with the PSK. 
- hs.hello.pskIdentities = nil - hs.hello.pskBinders = nil - } - } - - if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { - c.extraConfig.Rejected0RTT() - } - hs.hello.earlyData = false // disable 0-RTT - - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - serverHello, ok := msg.(*serverHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - hs.serverHello = serverHello - - if err := hs.checkServerHelloOrHRR(); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) processServerHello() error { - c := hs.c - - if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: server sent two HelloRetryRequest messages") - } - - if len(hs.serverHello.cookie) != 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: server sent a cookie in a normal ServerHello") - } - - if hs.serverHello.selectedGroup != 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: malformed key_share extension") - } - - if hs.serverHello.serverShare.group == 0 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server did not send a key share") - } - if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected unsupported group") - } - - if !hs.serverHello.selectedIdentityPresent { - return nil - } - - if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid PSK") - } - - if len(hs.hello.pskIdentities) != 1 || hs.session == nil { - return c.sendAlert(alertInternalError) - } - pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) - if pskSuite == nil { - return c.sendAlert(alertInternalError) - } - if pskSuite.hash != hs.suite.hash { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: server selected an invalid PSK and cipher suite pair") - } - - hs.usingPSK = true - c.didResume = true - c.peerCertificates = hs.session.serverCertificates - c.verifiedChains = hs.session.verifiedChains - c.ocspResponse = hs.session.ocspResponse - c.scts = hs.session.scts - return nil -} - -func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { - c := hs.c - - sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) - if sharedKey == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") - } - - earlySecret := hs.earlySecret - if !hs.usingPSK { - earlySecret = hs.suite.extract(nil, nil) - } - handshakeSecret := hs.suite.extract(sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) - - clientSecret := hs.suite.deriveSecret(handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) - c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret) - c.out.setTrafficSecret(hs.suite, clientSecret) - serverSecret := hs.suite.deriveSecret(handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) - c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret) - c.in.setTrafficSecret(hs.suite, serverSecret) - - err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = 
c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(handshakeSecret, "derived", nil)) - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerParameters() error { - c := hs.c - - msg, err := c.readHandshake() - if err != nil { - return err - } - - encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(encryptedExtensions, msg) - } - // Notify the caller if 0-RTT was rejected. - if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil { - c.extraConfig.Rejected0RTT() - } - c.used0RTT = encryptedExtensions.earlyData - if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil { - hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions) - } - hs.transcript.Write(encryptedExtensions.marshal()) - - if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { - c.sendAlert(alertUnsupportedExtension) - return err - } - c.clientProtocol = encryptedExtensions.alpnProtocol - - if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection { - if len(encryptedExtensions.alpnProtocol) == 0 { - // the server didn't select an ALPN - c.sendAlert(alertNoApplicationProtocol) - return errors.New("ALPN negotiation failed. Server didn't offer any protocols") - } - } - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerCertificate() error { - c := hs.c - - // Either a PSK or a certificate is always used, but not both. - // See RFC 8446, Section 4.1.1. - if hs.usingPSK { - // Make sure the connection is still being verified whether or not this - // is a resumption. Resumptions currently don't reverify certificates so - // they don't call verifyServerCertificate. See Issue 31641. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - return nil - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - certReq, ok := msg.(*certificateRequestMsgTLS13) - if ok { - hs.transcript.Write(certReq.marshal()) - - hs.certReq = certReq - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - - certMsg, ok := msg.(*certificateMsgTLS13) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - if len(certMsg.certificate.Certificate) == 0 { - c.sendAlert(alertDecodeError) - return errors.New("tls: received empty certificates message") - } - hs.transcript.Write(certMsg.marshal()) - - c.scts = certMsg.certificate.SignedCertificateTimestamps - c.ocspResponse = certMsg.certificate.OCSPStaple - - if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil { - return err - } - - msg, err = c.readHandshake() - if err != nil { - return err - } - - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - // See RFC 8446, Section 4.4.3. 
- if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: certificate used with invalid signature algorithm") - } - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: certificate used with invalid signature algorithm") - } - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) - if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, - sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the server certificate: " + err.Error()) - } - - hs.transcript.Write(certVerify.marshal()) - - return nil -} - -func (hs *clientHandshakeStateTLS13) readServerFinished() error { - c := hs.c - - msg, err := c.readHandshake() - if err != nil { - return err - } - - finished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(finished, msg) - } - - expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) - if !hmac.Equal(expectedMAC, finished.verifyData) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid server finished hash") - } - - hs.transcript.Write(finished.marshal()) - - // Derive secrets that take context through the server Finished. - - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) - c.in.exportKey(EncryptionApplication, hs.suite, serverSecret) - c.in.setTrafficSecret(hs.suite, serverSecret) - - err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) - - return nil -} - -func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { - c := hs.c - - if hs.certReq == nil { - return nil - } - - cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{ - AcceptableCAs: hs.certReq.certificateAuthorities, - SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, - Version: c.vers, - ctx: hs.ctx, - })) - if err != nil { - return err - } - - certMsg := new(certificateMsgTLS13) - - certMsg.certificate = *cert - certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 - certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 - - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { - return err - } - - // If we sent an empty certificate message, skip the CertificateVerify. 
- if len(cert.Certificate) == 0 { - return nil - } - - certVerifyMsg := new(certificateVerifyMsg) - certVerifyMsg.hasSignatureAlgorithm = true - - certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms) - if err != nil { - // getClientCertificate returned a certificate incompatible with the - // CertificateRequestInfo supported signature algorithms. - c.sendAlert(alertHandshakeFailure) - return err - } - - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) - if err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: failed to sign handshake: " + err.Error()) - } - certVerifyMsg.signature = sig - - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *clientHandshakeStateTLS13) sendClientFinished() error { - c := hs.c - - finished := &finishedMsg{ - verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), - } - - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { - return err - } - - c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) - c.out.setTrafficSecret(hs.suite, hs.trafficSecret) - - if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { - c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) - } - - return nil -} - -func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error { - if !c.isClient { - c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: received new session ticket from a client") - } - - if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return nil - } - - // See RFC 8446, Section 4.6.1. - if msg.lifetime == 0 { - return nil - } - lifetime := time.Duration(msg.lifetime) * time.Second - if lifetime > maxSessionTicketLifetime { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: received a session ticket with invalid lifetime") - } - - cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite) - if cipherSuite == nil || c.resumptionSecret == nil { - return c.sendAlert(alertInternalError) - } - - // We need to save the max_early_data_size that the server sent us, in order - // to decide if we're going to try 0-RTT with this ticket. - // However, at the same time, the qtls.ClientSessionTicket needs to be equal to - // the tls.ClientSessionTicket, so we can't just add a new field to the struct. 
- // We therefore abuse the nonce field (which is a byte slice) - nonceWithEarlyData := make([]byte, len(msg.nonce)+4) - binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData) - copy(nonceWithEarlyData[4:], msg.nonce) - - var appData []byte - if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil { - appData = c.extraConfig.GetAppDataForSessionState() - } - var b cryptobyte.Builder - b.AddUint16(clientSessionStateVersion) // revision - b.AddUint32(msg.maxEarlyData) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(appData) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(msg.nonce) - }) - - // Save the resumption_master_secret and nonce instead of deriving the PSK - // to do the least amount of work on NewSessionTicket messages before we - // know if the ticket will be used. Forward secrecy of resumed connections - // is guaranteed by the requirement for pskModeDHE. - session := &clientSessionState{ - sessionTicket: msg.label, - vers: c.vers, - cipherSuite: c.cipherSuite, - masterSecret: c.resumptionSecret, - serverCertificates: c.peerCertificates, - verifiedChains: c.verifiedChains, - receivedAt: c.config.time(), - nonce: b.BytesOrPanic(), - useBy: c.config.time().Add(lifetime), - ageAdd: msg.ageAdd, - ocspResponse: c.ocspResponse, - scts: c.scts, - } - - cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) - c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session)) - - return nil -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_messages.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_messages.go deleted file mode 100644 index 1ab75762637..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_messages.go +++ /dev/null @@ -1,1832 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "fmt" - "strings" - - "golang.org/x/crypto/cryptobyte" -) - -// The marshalingFunction type is an adapter to allow the use of ordinary -// functions as cryptobyte.MarshalingValue. -type marshalingFunction func(b *cryptobyte.Builder) error - -func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { - return f(b) -} - -// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If -// the length of the sequence is not the value specified, it produces an error. -func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { - b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { - if len(v) != n { - return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) - } - b.AddBytes(v) - return nil - })) -} - -// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder. -func addUint64(b *cryptobyte.Builder, v uint64) { - b.AddUint32(uint32(v >> 32)) - b.AddUint32(uint32(v)) -} - -// readUint64 decodes a big-endian, 64-bit value into out and advances over it. -// It reports whether the read was successful. -func readUint64(s *cryptobyte.String, out *uint64) bool { - var hi, lo uint32 - if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) { - return false - } - *out = uint64(hi)<<32 | uint64(lo) - return true -} - -// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. 
-func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) -} - -// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) -} - -// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a -// []byte instead of a cryptobyte.String. -func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { - return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) -} - -type clientHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuites []uint16 - compressionMethods []uint8 - serverName string - ocspStapling bool - supportedCurves []CurveID - supportedPoints []uint8 - ticketSupported bool - sessionTicket []uint8 - supportedSignatureAlgorithms []SignatureScheme - supportedSignatureAlgorithmsCert []SignatureScheme - secureRenegotiationSupported bool - secureRenegotiation []byte - alpnProtocols []string - scts bool - supportedVersions []uint16 - cookie []byte - keyShares []keyShare - earlyData bool - pskModes []uint8 - pskIdentities []pskIdentity - pskBinders [][]byte - additionalExtensions []Extension -} - -func (m *clientHelloMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeClientHello) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.vers) - addBytesWithLength(b, m.random, 32) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionId) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, suite := range m.cipherSuites { - b.AddUint16(suite) - } - }) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.compressionMethods) - }) - - // If extensions aren't present, omit them. 
- var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.serverName) > 0 { - // RFC 6066, Section 3 - b.AddUint16(extensionServerName) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // name_type = host_name - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.serverName)) - }) - }) - }) - } - if m.ocspStapling { - // RFC 4366, Section 3.6 - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(1) // status_type = ocsp - b.AddUint16(0) // empty responder_id_list - b.AddUint16(0) // empty request_extensions - }) - } - if len(m.supportedCurves) > 0 { - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - b.AddUint16(extensionSupportedCurves) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, curve := range m.supportedCurves { - b.AddUint16(uint16(curve)) - } - }) - }) - } - if len(m.supportedPoints) > 0 { - // RFC 4492, Section 5.1.2 - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - if m.ticketSupported { - // RFC 5077, Section 3.2 - b.AddUint16(extensionSessionTicket) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionTicket) - }) - } - if len(m.supportedSignatureAlgorithms) > 0 { - // RFC 5246, Section 7.4.1.4.1 - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - // RFC 8446, Section 4.2.3 - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if m.secureRenegotiationSupported { - // RFC 5746, Section 3.2 - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocols) > 0 { - // RFC 7301, Section 3.1 - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, proto := range m.alpnProtocols { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(proto)) - }) - } - }) - }) - } - if m.scts { - // RFC 6962, Section 3.3.1 - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedVersions) > 0 { - // RFC 8446, Section 4.2.1 - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - for _, vers := range m.supportedVersions { - b.AddUint16(vers) - } - }) - }) - } - if len(m.cookie) > 0 { - // RFC 8446, Section 4.2.2 - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if 
len(m.keyShares) > 0 { - // RFC 8446, Section 4.2.8 - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ks := range m.keyShares { - b.AddUint16(uint16(ks.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ks.data) - }) - } - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - b.AddUint16(extensionEarlyData) - b.AddUint16(0) // empty extension_data - } - if len(m.pskModes) > 0 { - // RFC 8446, Section 4.2.9 - b.AddUint16(extensionPSKModes) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.pskModes) - }) - }) - } - for _, ext := range m.additionalExtensions { - b.AddUint16(ext.Type) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ext.Data) - }) - } - if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension - // RFC 8446, Section 4.2.11 - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, psk := range m.pskIdentities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(psk.label) - }) - b.AddUint32(psk.obfuscatedTicketAge) - } - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions - } - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -// marshalWithoutBinders returns the ClientHello through the -// PreSharedKeyExtension.identities field, according to RFC 8446, Section -// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. -func (m *clientHelloMsg) marshalWithoutBinders() []byte { - bindersLen := 2 // uint16 length prefix - for _, binder := range m.pskBinders { - bindersLen += 1 // uint8 length prefix - bindersLen += len(binder) - } - - fullMessage := m.marshal() - return fullMessage[:len(fullMessage)-bindersLen] -} - -// updateBinders updates the m.pskBinders field, if necessary updating the -// cached marshaled representation. The supplied binders must have the same -// length as the current m.pskBinders. -func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { - if len(pskBinders) != len(m.pskBinders) { - panic("tls: internal error: pskBinders length mismatch") - } - for i := range m.pskBinders { - if len(pskBinders[i]) != len(m.pskBinders[i]) { - panic("tls: internal error: pskBinders length mismatch") - } - } - m.pskBinders = pskBinders - if m.raw != nil { - lenWithoutBinders := len(m.marshalWithoutBinders()) - // TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported. 
- b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders]) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - if len(b.BytesOrPanic()) != len(m.raw) { - panic("tls: internal error: failed to update binders") - } - } -} - -func (m *clientHelloMsg) unmarshal(data []byte) bool { - *m = clientHelloMsg{raw: data} - s := cryptobyte.String(data) - - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || - !readUint8LengthPrefixed(&s, &m.sessionId) { - return false - } - - var cipherSuites cryptobyte.String - if !s.ReadUint16LengthPrefixed(&cipherSuites) { - return false - } - m.cipherSuites = []uint16{} - m.secureRenegotiationSupported = false - for !cipherSuites.Empty() { - var suite uint16 - if !cipherSuites.ReadUint16(&suite) { - return false - } - if suite == scsvRenegotiation { - m.secureRenegotiationSupported = true - } - m.cipherSuites = append(m.cipherSuites, suite) - } - - if !readUint8LengthPrefixed(&s, &m.compressionMethods) { - return false - } - - if s.Empty() { - // ClientHello is optionally followed by extension data - return true - } - - var extensions cryptobyte.String - if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - for !extensions.Empty() { - var ext uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&ext) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch ext { - case extensionServerName: - // RFC 6066, Section 3 - var nameList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { - return false - } - for !nameList.Empty() { - var nameType uint8 - var serverName cryptobyte.String - if !nameList.ReadUint8(&nameType) || - !nameList.ReadUint16LengthPrefixed(&serverName) || - serverName.Empty() { - return false - } - if nameType != 0 { - continue - } - if len(m.serverName) != 0 { - // Multiple names of the same name_type are prohibited. - return false - } - m.serverName = string(serverName) - // An SNI value may not include a trailing dot. 
- if strings.HasSuffix(m.serverName, ".") { - return false - } - } - case extensionStatusRequest: - // RFC 4366, Section 3.6 - var statusType uint8 - var ignored cryptobyte.String - if !extData.ReadUint8(&statusType) || - !extData.ReadUint16LengthPrefixed(&ignored) || - !extData.ReadUint16LengthPrefixed(&ignored) { - return false - } - m.ocspStapling = statusType == statusTypeOCSP - case extensionSupportedCurves: - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - var curves cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { - return false - } - for !curves.Empty() { - var curve uint16 - if !curves.ReadUint16(&curve) { - return false - } - m.supportedCurves = append(m.supportedCurves, CurveID(curve)) - } - case extensionSupportedPoints: - // RFC 4492, Section 5.1.2 - if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || - len(m.supportedPoints) == 0 { - return false - } - case extensionSessionTicket: - // RFC 5077, Section 3.2 - m.ticketSupported = true - extData.ReadBytes(&m.sessionTicket, len(extData)) - case extensionSignatureAlgorithms: - // RFC 5246, Section 7.4.1.4.1 - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithms = append( - m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) - } - case extensionSignatureAlgorithmsCert: - // RFC 8446, Section 4.2.3 - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithmsCert = append( - m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) - } - case extensionRenegotiationInfo: - // RFC 5746, Section 3.2 - if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { - return false - } - m.secureRenegotiationSupported = true - case extensionALPN: - // RFC 7301, Section 3.1 - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - for !protoList.Empty() { - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { - return false - } - m.alpnProtocols = append(m.alpnProtocols, string(proto)) - } - case extensionSCT: - // RFC 6962, Section 3.3.1 - m.scts = true - case extensionSupportedVersions: - // RFC 8446, Section 4.2.1 - var versList cryptobyte.String - if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { - return false - } - for !versList.Empty() { - var vers uint16 - if !versList.ReadUint16(&vers) { - return false - } - m.supportedVersions = append(m.supportedVersions, vers) - } - case extensionCookie: - // RFC 8446, Section 4.2.2 - if !readUint16LengthPrefixed(&extData, &m.cookie) || - len(m.cookie) == 0 { - return false - } - case extensionKeyShare: - // RFC 8446, Section 4.2.8 - var clientShares cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&clientShares) { - return false - } - for !clientShares.Empty() { - var ks keyShare - if !clientShares.ReadUint16((*uint16)(&ks.group)) || - !readUint16LengthPrefixed(&clientShares, &ks.data) || - len(ks.data) == 0 { - return false - } - m.keyShares = append(m.keyShares, ks) - } - case extensionEarlyData: - // RFC 8446, Section 4.2.10 - m.earlyData = true - case 
extensionPSKModes: - // RFC 8446, Section 4.2.9 - if !readUint8LengthPrefixed(&extData, &m.pskModes) { - return false - } - case extensionPreSharedKey: - // RFC 8446, Section 4.2.11 - if !extensions.Empty() { - return false // pre_shared_key must be the last extension - } - var identities cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { - return false - } - for !identities.Empty() { - var psk pskIdentity - if !readUint16LengthPrefixed(&identities, &psk.label) || - !identities.ReadUint32(&psk.obfuscatedTicketAge) || - len(psk.label) == 0 { - return false - } - m.pskIdentities = append(m.pskIdentities, psk) - } - var binders cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { - return false - } - for !binders.Empty() { - var binder []byte - if !readUint8LengthPrefixed(&binders, &binder) || - len(binder) == 0 { - return false - } - m.pskBinders = append(m.pskBinders, binder) - } - default: - m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type serverHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuite uint16 - compressionMethod uint8 - ocspStapling bool - ticketSupported bool - secureRenegotiationSupported bool - secureRenegotiation []byte - alpnProtocol string - scts [][]byte - supportedVersion uint16 - serverShare keyShare - selectedIdentityPresent bool - selectedIdentity uint16 - supportedPoints []uint8 - - // HelloRetryRequest extensions - cookie []byte - selectedGroup CurveID -} - -func (m *serverHelloMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeServerHello) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.vers) - addBytesWithLength(b, m.random, 32) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionId) - }) - b.AddUint16(m.cipherSuite) - b.AddUint8(m.compressionMethod) - - // If extensions aren't present, omit them. 
- var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.ticketSupported { - b.AddUint16(extensionSessionTicket) - b.AddUint16(0) // empty extension_data - } - if m.secureRenegotiationSupported { - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if len(m.scts) > 0 { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range m.scts { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - if m.supportedVersion != 0 { - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.supportedVersion) - }) - } - if m.serverShare.group != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.serverShare.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.serverShare.data) - }) - }) - } - if m.selectedIdentityPresent { - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.selectedIdentity) - }) - } - - if len(m.cookie) > 0 { - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if m.selectedGroup != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.selectedGroup)) - }) - } - if len(m.supportedPoints) > 0 { - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions - } - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *serverHelloMsg) unmarshal(data []byte) bool { - *m = serverHelloMsg{raw: data} - s := cryptobyte.String(data) - - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || - !readUint8LengthPrefixed(&s, &m.sessionId) || - !s.ReadUint16(&m.cipherSuite) || - !s.ReadUint8(&m.compressionMethod) { - return false - } - - if s.Empty() { - // ServerHello is optionally followed by extension data - return true - } - - var extensions cryptobyte.String - if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionStatusRequest: - m.ocspStapling = true - case extensionSessionTicket: - m.ticketSupported = true - case extensionRenegotiationInfo: - if 
!readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { - return false - } - m.secureRenegotiationSupported = true - case extensionALPN: - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - var proto cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || - proto.Empty() || !protoList.Empty() { - return false - } - m.alpnProtocol = string(proto) - case extensionSCT: - var sctList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { - return false - } - for !sctList.Empty() { - var sct []byte - if !readUint16LengthPrefixed(&sctList, &sct) || - len(sct) == 0 { - return false - } - m.scts = append(m.scts, sct) - } - case extensionSupportedVersions: - if !extData.ReadUint16(&m.supportedVersion) { - return false - } - case extensionCookie: - if !readUint16LengthPrefixed(&extData, &m.cookie) || - len(m.cookie) == 0 { - return false - } - case extensionKeyShare: - // This extension has different formats in SH and HRR, accept either - // and let the handshake logic decide. See RFC 8446, Section 4.2.8. - if len(extData) == 2 { - if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { - return false - } - } else { - if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || - !readUint16LengthPrefixed(&extData, &m.serverShare.data) { - return false - } - } - case extensionPreSharedKey: - m.selectedIdentityPresent = true - if !extData.ReadUint16(&m.selectedIdentity) { - return false - } - case extensionSupportedPoints: - // RFC 4492, Section 5.1.2 - if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || - len(m.supportedPoints) == 0 { - return false - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type encryptedExtensionsMsg struct { - raw []byte - alpnProtocol string - earlyData bool - - additionalExtensions []Extension -} - -func (m *encryptedExtensionsMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeEncryptedExtensions) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - b.AddUint16(extensionEarlyData) - b.AddUint16(0) // empty extension_data - } - for _, ext := range m.additionalExtensions { - b.AddUint16(ext.Type) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ext.Data) - }) - } - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { - *m = encryptedExtensionsMsg{raw: data} - s := cryptobyte.String(data) - - var extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return false - } - - for !extensions.Empty() { - var ext uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&ext) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch ext { - case extensionALPN: - var protoList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { - return false - } - var proto 
cryptobyte.String - if !protoList.ReadUint8LengthPrefixed(&proto) || - proto.Empty() || !protoList.Empty() { - return false - } - m.alpnProtocol = string(proto) - case extensionEarlyData: - m.earlyData = true - default: - m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type endOfEarlyDataMsg struct{} - -func (m *endOfEarlyDataMsg) marshal() []byte { - x := make([]byte, 4) - x[0] = typeEndOfEarlyData - return x -} - -func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type keyUpdateMsg struct { - raw []byte - updateRequested bool -} - -func (m *keyUpdateMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeKeyUpdate) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - if m.updateRequested { - b.AddUint8(1) - } else { - b.AddUint8(0) - } - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *keyUpdateMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - var updateRequested uint8 - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8(&updateRequested) || !s.Empty() { - return false - } - switch updateRequested { - case 0: - m.updateRequested = false - case 1: - m.updateRequested = true - default: - return false - } - return true -} - -type newSessionTicketMsgTLS13 struct { - raw []byte - lifetime uint32 - ageAdd uint32 - nonce []byte - label []byte - maxEarlyData uint32 -} - -func (m *newSessionTicketMsgTLS13) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeNewSessionTicket) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint32(m.lifetime) - b.AddUint32(m.ageAdd) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.nonce) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.label) - }) - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.maxEarlyData > 0 { - b.AddUint16(extensionEarlyData) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint32(m.maxEarlyData) - }) - } - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { - *m = newSessionTicketMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint32(&m.lifetime) || - !s.ReadUint32(&m.ageAdd) || - !readUint8LengthPrefixed(&s, &m.nonce) || - !readUint16LengthPrefixed(&s, &m.label) || - !s.ReadUint16LengthPrefixed(&extensions) || - !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionEarlyData: - if !extData.ReadUint32(&m.maxEarlyData) { - return false - } - default: - // Ignore unknown extensions. 
- continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type certificateRequestMsgTLS13 struct { - raw []byte - ocspStapling bool - scts bool - supportedSignatureAlgorithms []SignatureScheme - supportedSignatureAlgorithmsCert []SignatureScheme - certificateAuthorities [][]byte -} - -func (m *certificateRequestMsgTLS13) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateRequest) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - // certificate_request_context (SHALL be zero length unless used for - // post-handshake authentication) - b.AddUint8(0) - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.scts { - // RFC 8446, Section 4.4.2.1 makes no mention of - // signed_certificate_timestamp in CertificateRequest, but - // "Extensions in the Certificate message from the client MUST - // correspond to extensions in the CertificateRequest message - // from the server." and it appears in the table in Section 4.2. - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedSignatureAlgorithms) > 0 { - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.certificateAuthorities) > 0 { - b.AddUint16(extensionCertificateAuthorities) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ca := range m.certificateAuthorities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ca) - }) - } - }) - }) - } - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { - *m = certificateRequestMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var context, extensions cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || - !s.ReadUint16LengthPrefixed(&extensions) || - !s.Empty() { - return false - } - - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - - switch extension { - case extensionStatusRequest: - m.ocspStapling = true - case extensionSCT: - m.scts = true - case extensionSignatureAlgorithms: - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for !sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithms = append( - m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) - } - case extensionSignatureAlgorithmsCert: - var sigAndAlgs cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { - return false - } - for 
!sigAndAlgs.Empty() { - var sigAndAlg uint16 - if !sigAndAlgs.ReadUint16(&sigAndAlg) { - return false - } - m.supportedSignatureAlgorithmsCert = append( - m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) - } - case extensionCertificateAuthorities: - var auths cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() { - return false - } - for !auths.Empty() { - var ca []byte - if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 { - return false - } - m.certificateAuthorities = append(m.certificateAuthorities, ca) - } - default: - // Ignore unknown extensions. - continue - } - - if !extData.Empty() { - return false - } - } - - return true -} - -type certificateMsg struct { - raw []byte - certificates [][]byte -} - -func (m *certificateMsg) marshal() (x []byte) { - if m.raw != nil { - return m.raw - } - - var i int - for _, slice := range m.certificates { - i += len(slice) - } - - length := 3 + 3*len(m.certificates) + i - x = make([]byte, 4+length) - x[0] = typeCertificate - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - - certificateOctets := length - 3 - x[4] = uint8(certificateOctets >> 16) - x[5] = uint8(certificateOctets >> 8) - x[6] = uint8(certificateOctets) - - y := x[7:] - for _, slice := range m.certificates { - y[0] = uint8(len(slice) >> 16) - y[1] = uint8(len(slice) >> 8) - y[2] = uint8(len(slice)) - copy(y[3:], slice) - y = y[3+len(slice):] - } - - m.raw = x - return -} - -func (m *certificateMsg) unmarshal(data []byte) bool { - if len(data) < 7 { - return false - } - - m.raw = data - certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) - if uint32(len(data)) != certsLen+7 { - return false - } - - numCerts := 0 - d := data[7:] - for certsLen > 0 { - if len(d) < 4 { - return false - } - certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) - if uint32(len(d)) < 3+certLen { - return false - } - d = d[3+certLen:] - certsLen -= 3 + certLen - numCerts++ - } - - m.certificates = make([][]byte, numCerts) - d = data[7:] - for i := 0; i < numCerts; i++ { - certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) - m.certificates[i] = d[3 : 3+certLen] - d = d[3+certLen:] - } - - return true -} - -type certificateMsgTLS13 struct { - raw []byte - certificate Certificate - ocspStapling bool - scts bool -} - -func (m *certificateMsgTLS13) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificate) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // certificate_request_context - - certificate := m.certificate - if !m.ocspStapling { - certificate.OCSPStaple = nil - } - if !m.scts { - certificate.SignedCertificateTimestamps = nil - } - marshalCertificate(b, certificate) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - for i, cert := range certificate.Certificate { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(cert) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if i > 0 { - // This library only supports OCSP and SCT for leaf certificates. 
- return - } - if certificate.OCSPStaple != nil { - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(statusTypeOCSP) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(certificate.OCSPStaple) - }) - }) - } - if certificate.SignedCertificateTimestamps != nil { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range certificate.SignedCertificateTimestamps { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - }) - } - }) -} - -func (m *certificateMsgTLS13) unmarshal(data []byte) bool { - *m = certificateMsgTLS13{raw: data} - s := cryptobyte.String(data) - - var context cryptobyte.String - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || - !unmarshalCertificate(&s, &m.certificate) || - !s.Empty() { - return false - } - - m.scts = m.certificate.SignedCertificateTimestamps != nil - m.ocspStapling = m.certificate.OCSPStaple != nil - - return true -} - -func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool { - var certList cryptobyte.String - if !s.ReadUint24LengthPrefixed(&certList) { - return false - } - for !certList.Empty() { - var cert []byte - var extensions cryptobyte.String - if !readUint24LengthPrefixed(&certList, &cert) || - !certList.ReadUint16LengthPrefixed(&extensions) { - return false - } - certificate.Certificate = append(certificate.Certificate, cert) - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return false - } - if len(certificate.Certificate) > 1 { - // This library only supports OCSP and SCT for leaf certificates. - continue - } - - switch extension { - case extensionStatusRequest: - var statusType uint8 - if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) || - len(certificate.OCSPStaple) == 0 { - return false - } - case extensionSCT: - var sctList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { - return false - } - for !sctList.Empty() { - var sct []byte - if !readUint16LengthPrefixed(&sctList, &sct) || - len(sct) == 0 { - return false - } - certificate.SignedCertificateTimestamps = append( - certificate.SignedCertificateTimestamps, sct) - } - default: - // Ignore unknown extensions. 
- continue - } - - if !extData.Empty() { - return false - } - } - } - return true -} - -type serverKeyExchangeMsg struct { - raw []byte - key []byte -} - -func (m *serverKeyExchangeMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - length := len(m.key) - x := make([]byte, length+4) - x[0] = typeServerKeyExchange - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - copy(x[4:], m.key) - - m.raw = x - return x -} - -func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 4 { - return false - } - m.key = data[4:] - return true -} - -type certificateStatusMsg struct { - raw []byte - response []byte -} - -func (m *certificateStatusMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateStatus) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(statusTypeOCSP) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.response) - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *certificateStatusMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - var statusType uint8 - if !s.Skip(4) || // message type and uint24 length field - !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || - !readUint24LengthPrefixed(&s, &m.response) || - len(m.response) == 0 || !s.Empty() { - return false - } - return true -} - -type serverHelloDoneMsg struct{} - -func (m *serverHelloDoneMsg) marshal() []byte { - x := make([]byte, 4) - x[0] = typeServerHelloDone - return x -} - -func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} - -type clientKeyExchangeMsg struct { - raw []byte - ciphertext []byte -} - -func (m *clientKeyExchangeMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - length := len(m.ciphertext) - x := make([]byte, length+4) - x[0] = typeClientKeyExchange - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - copy(x[4:], m.ciphertext) - - m.raw = x - return x -} - -func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { - m.raw = data - if len(data) < 4 { - return false - } - l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - if l != len(data)-4 { - return false - } - m.ciphertext = data[4:] - return true -} - -type finishedMsg struct { - raw []byte - verifyData []byte -} - -func (m *finishedMsg) marshal() []byte { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeFinished) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.verifyData) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *finishedMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - return s.Skip(1) && - readUint24LengthPrefixed(&s, &m.verifyData) && - s.Empty() -} - -type certificateRequestMsg struct { - raw []byte - // hasSignatureAlgorithm indicates whether this message includes a list of - // supported signature algorithms. This change was introduced with TLS 1.2. - hasSignatureAlgorithm bool - - certificateTypes []byte - supportedSignatureAlgorithms []SignatureScheme - certificateAuthorities [][]byte -} - -func (m *certificateRequestMsg) marshal() (x []byte) { - if m.raw != nil { - return m.raw - } - - // See RFC 4346, Section 7.4.4. 
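The older, pre-TLS 1.3 messages in this file (serverKeyExchangeMsg, clientKeyExchangeMsg, certificateMsg) predate the cryptobyte helpers and pack the 4-byte handshake header by hand: one type byte followed by a 3-byte big-endian body length. A small sketch of that framing under the same conventions (packHandshake and unpackHandshake are illustrative names, not functions from the vendored code):

package main

import (
	"errors"
	"fmt"
)

// packHandshake prepends the classic 4-byte handshake header: one type
// byte followed by a 3-byte big-endian body length, as the manual
// marshalers above do.
func packHandshake(msgType byte, body []byte) []byte {
	x := make([]byte, 4+len(body))
	x[0] = msgType
	x[1] = byte(len(body) >> 16)
	x[2] = byte(len(body) >> 8)
	x[3] = byte(len(body))
	copy(x[4:], body)
	return x
}

// unpackHandshake checks that the uint24 length matches the remaining
// bytes before handing the body back, the same sanity check the
// unmarshal methods above perform.
func unpackHandshake(data []byte) (byte, []byte, error) {
	if len(data) < 4 {
		return 0, nil, errors.New("short handshake message")
	}
	n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
	if n != len(data)-4 {
		return 0, nil, errors.New("handshake length mismatch")
	}
	return data[0], data[4:], nil
}

func main() {
	msg := packHandshake(16, []byte{0x01, 0x02}) // 16 = client_key_exchange
	msgType, body, err := unpackHandshake(msg)
	fmt.Println(msgType, body, err)
}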
- length := 1 + len(m.certificateTypes) + 2 - casLength := 0 - for _, ca := range m.certificateAuthorities { - casLength += 2 + len(ca) - } - length += casLength - - if m.hasSignatureAlgorithm { - length += 2 + 2*len(m.supportedSignatureAlgorithms) - } - - x = make([]byte, 4+length) - x[0] = typeCertificateRequest - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - - x[4] = uint8(len(m.certificateTypes)) - - copy(x[5:], m.certificateTypes) - y := x[5+len(m.certificateTypes):] - - if m.hasSignatureAlgorithm { - n := len(m.supportedSignatureAlgorithms) * 2 - y[0] = uint8(n >> 8) - y[1] = uint8(n) - y = y[2:] - for _, sigAlgo := range m.supportedSignatureAlgorithms { - y[0] = uint8(sigAlgo >> 8) - y[1] = uint8(sigAlgo) - y = y[2:] - } - } - - y[0] = uint8(casLength >> 8) - y[1] = uint8(casLength) - y = y[2:] - for _, ca := range m.certificateAuthorities { - y[0] = uint8(len(ca) >> 8) - y[1] = uint8(len(ca)) - y = y[2:] - copy(y, ca) - y = y[len(ca):] - } - - m.raw = x - return -} - -func (m *certificateRequestMsg) unmarshal(data []byte) bool { - m.raw = data - - if len(data) < 5 { - return false - } - - length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) - if uint32(len(data))-4 != length { - return false - } - - numCertTypes := int(data[4]) - data = data[5:] - if numCertTypes == 0 || len(data) <= numCertTypes { - return false - } - - m.certificateTypes = make([]byte, numCertTypes) - if copy(m.certificateTypes, data) != numCertTypes { - return false - } - - data = data[numCertTypes:] - - if m.hasSignatureAlgorithm { - if len(data) < 2 { - return false - } - sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - if sigAndHashLen&1 != 0 { - return false - } - if len(data) < int(sigAndHashLen) { - return false - } - numSigAlgos := sigAndHashLen / 2 - m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos) - for i := range m.supportedSignatureAlgorithms { - m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) - data = data[2:] - } - } - - if len(data) < 2 { - return false - } - casLength := uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - if len(data) < int(casLength) { - return false - } - cas := make([]byte, casLength) - copy(cas, data) - data = data[casLength:] - - m.certificateAuthorities = nil - for len(cas) > 0 { - if len(cas) < 2 { - return false - } - caLen := uint16(cas[0])<<8 | uint16(cas[1]) - cas = cas[2:] - - if len(cas) < int(caLen) { - return false - } - - m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) - cas = cas[caLen:] - } - - return len(data) == 0 -} - -type certificateVerifyMsg struct { - raw []byte - hasSignatureAlgorithm bool // format change introduced in TLS 1.2 - signatureAlgorithm SignatureScheme - signature []byte -} - -func (m *certificateVerifyMsg) marshal() (x []byte) { - if m.raw != nil { - return m.raw - } - - var b cryptobyte.Builder - b.AddUint8(typeCertificateVerify) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - if m.hasSignatureAlgorithm { - b.AddUint16(uint16(m.signatureAlgorithm)) - } - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.signature) - }) - }) - - m.raw = b.BytesOrPanic() - return m.raw -} - -func (m *certificateVerifyMsg) unmarshal(data []byte) bool { - m.raw = data - s := cryptobyte.String(data) - - if !s.Skip(4) { // message type and uint24 length field - return false - } - if m.hasSignatureAlgorithm { - if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { - 
return false - } - } - return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() -} - -type newSessionTicketMsg struct { - raw []byte - ticket []byte -} - -func (m *newSessionTicketMsg) marshal() (x []byte) { - if m.raw != nil { - return m.raw - } - - // See RFC 5077, Section 3.3. - ticketLen := len(m.ticket) - length := 2 + 4 + ticketLen - x = make([]byte, 4+length) - x[0] = typeNewSessionTicket - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - x[8] = uint8(ticketLen >> 8) - x[9] = uint8(ticketLen) - copy(x[10:], m.ticket) - - m.raw = x - - return -} - -func (m *newSessionTicketMsg) unmarshal(data []byte) bool { - m.raw = data - - if len(data) < 10 { - return false - } - - length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) - if uint32(len(data))-4 != length { - return false - } - - ticketLen := int(data[8])<<8 + int(data[9]) - if len(data)-10 != ticketLen { - return false - } - - m.ticket = data[10:] - - return true -} - -type helloRequestMsg struct { -} - -func (*helloRequestMsg) marshal() []byte { - return []byte{typeHelloRequest, 0, 0, 0} -} - -func (*helloRequestMsg) unmarshal(data []byte) bool { - return len(data) == 4 -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server.go deleted file mode 100644 index 7d3557d7279..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server.go +++ /dev/null @@ -1,905 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/subtle" - "crypto/x509" - "errors" - "fmt" - "hash" - "io" - "sync/atomic" - "time" -) - -// serverHandshakeState contains details of a server handshake in progress. -// It's discarded once the handshake has completed. -type serverHandshakeState struct { - c *Conn - ctx context.Context - clientHello *clientHelloMsg - hello *serverHelloMsg - suite *cipherSuite - ecdheOk bool - ecSignOk bool - rsaDecryptOk bool - rsaSignOk bool - sessionState *sessionState - finishedHash finishedHash - masterSecret []byte - cert *Certificate -} - -// serverHandshake performs a TLS handshake as a server. -func (c *Conn) serverHandshake(ctx context.Context) error { - c.setAlternativeRecordLayer() - - clientHello, err := c.readClientHello(ctx) - if err != nil { - return err - } - - if c.vers == VersionTLS13 { - hs := serverHandshakeStateTLS13{ - c: c, - ctx: ctx, - clientHello: clientHello, - } - return hs.handshake() - } else if c.extraConfig.usesAlternativeRecordLayer() { - // This should already have been caught by the check that the ClientHello doesn't - // offer any (supported) versions older than TLS 1.3. - // Check again to make sure we can't be tricked into using an older version. - c.sendAlert(alertProtocolVersion) - return errors.New("tls: negotiated TLS < 1.3 when using QUIC") - } - - hs := serverHandshakeState{ - c: c, - ctx: ctx, - clientHello: clientHello, - } - return hs.handshake() -} - -func (hs *serverHandshakeState) handshake() error { - c := hs.c - - if err := hs.processClientHello(); err != nil { - return err - } - - // For an overview of TLS handshaking, see RFC 5246, Section 7.3. - c.buffering = true - if hs.checkForResumption() { - // The client has included a session ticket and so we do an abbreviated handshake. 
- c.didResume = true - if err := hs.doResumeHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.sendSessionTicket(); err != nil { - return err - } - if err := hs.sendFinished(c.serverFinished[:]); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - c.clientFinishedIsFirst = false - if err := hs.readFinished(nil); err != nil { - return err - } - } else { - // The client didn't include a session ticket, or it wasn't - // valid so we do a full handshake. - if err := hs.pickCipherSuite(); err != nil { - return err - } - if err := hs.doFullHandshake(); err != nil { - return err - } - if err := hs.establishKeys(); err != nil { - return err - } - if err := hs.readFinished(c.clientFinished[:]); err != nil { - return err - } - c.clientFinishedIsFirst = true - c.buffering = true - if err := hs.sendSessionTicket(); err != nil { - return err - } - if err := hs.sendFinished(nil); err != nil { - return err - } - if _, err := c.flush(); err != nil { - return err - } - } - - c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) - atomic.StoreUint32(&c.handshakeStatus, 1) - - return nil -} - -// readClientHello reads a ClientHello message and selects the protocol version. -func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { - msg, err := c.readHandshake() - if err != nil { - return nil, err - } - clientHello, ok := msg.(*clientHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return nil, unexpectedMessageError(clientHello, msg) - } - - var configForClient *config - originalConfig := c.config - if c.config.GetConfigForClient != nil { - chi := newClientHelloInfo(ctx, c, clientHello) - if cfc, err := c.config.GetConfigForClient(chi); err != nil { - c.sendAlert(alertInternalError) - return nil, err - } else if cfc != nil { - configForClient = fromConfig(cfc) - c.config = configForClient - } - } - c.ticketKeys = originalConfig.ticketKeys(configForClient) - - clientVersions := clientHello.supportedVersions - if len(clientHello.supportedVersions) == 0 { - clientVersions = supportedVersionsFromMax(clientHello.vers) - } - if c.extraConfig.usesAlternativeRecordLayer() { - // In QUIC, the client MUST NOT offer any old TLS versions. - // Here, we can only check that none of the other supported versions of this library - // (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here. - for _, ver := range clientVersions { - if ver == VersionTLS13 { - continue - } - for _, v := range supportedVersions { - if ver == v { - c.sendAlert(alertProtocolVersion) - return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver) - } - } - } - // Make the config we're using allows us to use TLS 1.3. - if c.config.maxSupportedVersion() < VersionTLS13 { - c.sendAlert(alertInternalError) - return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3") - } - } - c.vers, ok = c.config.mutualVersion(clientVersions) - if !ok { - c.sendAlert(alertProtocolVersion) - return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) - } - c.haveVers = true - c.in.version = c.vers - c.out.version = c.vers - - return clientHello, nil -} - -func (hs *serverHandshakeState) processClientHello() error { - c := hs.c - - hs.hello = new(serverHelloMsg) - hs.hello.vers = c.vers - - foundCompression := false - // We only support null compression, so check that the client offered it. 
- for _, compression := range hs.clientHello.compressionMethods { - if compression == compressionNone { - foundCompression = true - break - } - } - - if !foundCompression { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: client does not support uncompressed connections") - } - - hs.hello.random = make([]byte, 32) - serverRandom := hs.hello.random - // Downgrade protection canaries. See RFC 8446, Section 4.1.3. - maxVers := c.config.maxSupportedVersion() - if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary { - if c.vers == VersionTLS12 { - copy(serverRandom[24:], downgradeCanaryTLS12) - } else { - copy(serverRandom[24:], downgradeCanaryTLS11) - } - serverRandom = serverRandom[:24] - } - _, err := io.ReadFull(c.config.rand(), serverRandom) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - if len(hs.clientHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: initial handshake had non-empty renegotiation extension") - } - - hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported - hs.hello.compressionMethod = compressionNone - if len(hs.clientHello.serverName) > 0 { - c.serverName = hs.clientHello.serverName - } - - selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) - if err != nil { - c.sendAlert(alertNoApplicationProtocol) - return err - } - hs.hello.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - - hs.cert, err = c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) - if err != nil { - if err == errNoCertificates { - c.sendAlert(alertUnrecognizedName) - } else { - c.sendAlert(alertInternalError) - } - return err - } - if hs.clientHello.scts { - hs.hello.scts = hs.cert.SignedCertificateTimestamps - } - - hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints) - - if hs.ecdheOk { - // Although omitting the ec_point_formats extension is permitted, some - // old OpenSSL version will refuse to handshake if not present. - // - // Per RFC 4492, section 5.1.2, implementations MUST support the - // uncompressed point format. See golang.org/issue/31943. - hs.hello.supportedPoints = []uint8{pointFormatUncompressed} - } - - if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { - switch priv.Public().(type) { - case *ecdsa.PublicKey: - hs.ecSignOk = true - case ed25519.PublicKey: - hs.ecSignOk = true - case *rsa.PublicKey: - hs.rsaSignOk = true - default: - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public()) - } - } - if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { - switch priv.Public().(type) { - case *rsa.PublicKey: - hs.rsaDecryptOk = true - default: - c.sendAlert(alertInternalError) - return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public()) - } - } - - return nil -} - -// negotiateALPN picks a shared ALPN protocol that both sides support in server -// preference order. If ALPN is not configured or the peer doesn't support it, -// it returns "" and no error. 
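processClientHello above also plants the RFC 8446, Section 4.1.3 downgrade-protection canary: when the server negotiates a version below its maximum, the last 8 bytes of ServerHello.random become a fixed sentinel instead of random data. A rough standalone sketch of that placement (serverRandom and the inline version constant are hypothetical; the vendored code uses its own downgradeCanaryTLS12/TLS11 values, which should carry the same RFC-defined bytes):

package main

import (
	"crypto/rand"
	"fmt"
	"io"
)

// Downgrade sentinels from RFC 8446, Section 4.1.3: the last 8 bytes of
// ServerHello.random when the server negotiates below its maximum.
var (
	canaryTLS12 = []byte("DOWNGRD\x01") // negotiated exactly TLS 1.2
	canaryTLS11 = []byte("DOWNGRD\x00") // negotiated TLS 1.1 or below
)

// serverRandom fills a 32-byte ServerHello.random, planting the
// appropriate canary in the final 8 bytes when a downgrade happened.
func serverRandom(negotiated, max uint16) ([]byte, error) {
	const versionTLS12 = 0x0303
	r := make([]byte, 32)
	fill := r
	if max >= versionTLS12 && negotiated < max {
		canary := canaryTLS11
		if negotiated == versionTLS12 {
			canary = canaryTLS12
		}
		copy(r[24:], canary)
		fill = r[:24] // only the first 24 bytes stay random
	}
	if _, err := io.ReadFull(rand.Reader, fill); err != nil {
		return nil, err
	}
	return r, nil
}

func main() {
	r, _ := serverRandom(0x0303, 0x0304) // TLS 1.2 picked, TLS 1.3 supported
	fmt.Printf("% x\n", r[24:])          // ends with the DOWNGRD 01 sentinel
}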
-func negotiateALPN(serverProtos, clientProtos []string) (string, error) { - if len(serverProtos) == 0 || len(clientProtos) == 0 { - return "", nil - } - var http11fallback bool - for _, s := range serverProtos { - for _, c := range clientProtos { - if s == c { - return s, nil - } - if s == "h2" && c == "http/1.1" { - http11fallback = true - } - } - } - // As a special case, let http/1.1 clients connect to h2 servers as if they - // didn't support ALPN. We used not to enforce protocol overlap, so over - // time a number of HTTP servers were configured with only "h2", but - // expected to accept connections from "http/1.1" clients. See Issue 46310. - if http11fallback { - return "", nil - } - return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos) -} - -// supportsECDHE returns whether ECDHE key exchanges can be used with this -// pre-TLS 1.3 client. -func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool { - supportsCurve := false - for _, curve := range supportedCurves { - if c.supportsCurve(curve) { - supportsCurve = true - break - } - } - - supportsPointFormat := false - for _, pointFormat := range supportedPoints { - if pointFormat == pointFormatUncompressed { - supportsPointFormat = true - break - } - } - - return supportsCurve && supportsPointFormat -} - -func (hs *serverHandshakeState) pickCipherSuite() error { - c := hs.c - - preferenceOrder := cipherSuitesPreferenceOrder - if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceOrder = cipherSuitesPreferenceOrderNoAES - } - - configCipherSuites := c.config.cipherSuites() - preferenceList := make([]uint16, 0, len(configCipherSuites)) - for _, suiteID := range preferenceOrder { - for _, id := range configCipherSuites { - if id == suiteID { - preferenceList = append(preferenceList, id) - break - } - } - } - - hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk) - if hs.suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no cipher suite supported by both client and server") - } - c.cipherSuite = hs.suite.id - - for _, id := range hs.clientHello.cipherSuites { - if id == TLS_FALLBACK_SCSV { - // The client is doing a fallback connection. See RFC 7507. - if hs.clientHello.vers < c.config.maxSupportedVersion() { - c.sendAlert(alertInappropriateFallback) - return errors.New("tls: client using inappropriate protocol fallback") - } - break - } - } - - return nil -} - -func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool { - if c.flags&suiteECDHE != 0 { - if !hs.ecdheOk { - return false - } - if c.flags&suiteECSign != 0 { - if !hs.ecSignOk { - return false - } - } else if !hs.rsaSignOk { - return false - } - } else if !hs.rsaDecryptOk { - return false - } - if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 { - return false - } - return true -} - -// checkForResumption reports whether we should perform resumption on this connection. 
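negotiateALPN above selects in server-preference order and deliberately treats an http/1.1-only client hitting an h2-only server as "no ALPN" rather than a handshake failure (see the Go issue 46310 reference in its comment). A compact standalone restatement of that decision table, for illustration only (pickALPN is a hypothetical name, not the vendored function):

package main

import (
	"errors"
	"fmt"
)

// pickALPN mirrors the server-preference ALPN selection above: empty
// lists mean "no ALPN", an h2-only server quietly tolerates an
// http/1.1-only client, and any other mismatch is an error.
func pickALPN(serverProtos, clientProtos []string) (string, error) {
	if len(serverProtos) == 0 || len(clientProtos) == 0 {
		return "", nil
	}
	http11Fallback := false
	for _, s := range serverProtos {
		for _, c := range clientProtos {
			if s == c {
				return s, nil
			}
			if s == "h2" && c == "http/1.1" {
				http11Fallback = true
			}
		}
	}
	if http11Fallback {
		return "", nil
	}
	return "", errors.New("no overlapping ALPN protocol")
}

func main() {
	fmt.Println(pickALPN([]string{"h2", "http/1.1"}, []string{"http/1.1"})) // http/1.1 <nil>
	fmt.Println(pickALPN([]string{"h2"}, []string{"http/1.1"}))             // "" <nil> (tolerated fallback)
	fmt.Println(pickALPN([]string{"h2"}, []string{"spdy/3"}))               // "" plus an error
}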
-func (hs *serverHandshakeState) checkForResumption() bool { - c := hs.c - - if c.config.SessionTicketsDisabled { - return false - } - - plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket) - if plaintext == nil { - return false - } - hs.sessionState = &sessionState{usedOldKey: usedOldKey} - ok := hs.sessionState.unmarshal(plaintext) - if !ok { - return false - } - - createdAt := time.Unix(int64(hs.sessionState.createdAt), 0) - if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { - return false - } - - // Never resume a session for a different TLS version. - if c.vers != hs.sessionState.vers { - return false - } - - cipherSuiteOk := false - // Check that the client is still offering the ciphersuite in the session. - for _, id := range hs.clientHello.cipherSuites { - if id == hs.sessionState.cipherSuite { - cipherSuiteOk = true - break - } - } - if !cipherSuiteOk { - return false - } - - // Check that we also support the ciphersuite from the session. - hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite}, - c.config.cipherSuites(), hs.cipherSuiteOk) - if hs.suite == nil { - return false - } - - sessionHasClientCerts := len(hs.sessionState.certificates) != 0 - needClientCerts := requiresClientCert(c.config.ClientAuth) - if needClientCerts && !sessionHasClientCerts { - return false - } - if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { - return false - } - - return true -} - -func (hs *serverHandshakeState) doResumeHandshake() error { - c := hs.c - - hs.hello.cipherSuite = hs.suite.id - c.cipherSuite = hs.suite.id - // We echo the client's session ID in the ServerHello to let it know - // that we're doing a resumption. - hs.hello.sessionId = hs.clientHello.sessionId - hs.hello.ticketSupported = hs.sessionState.usedOldKey - hs.finishedHash = newFinishedHash(c.vers, hs.suite) - hs.finishedHash.discardHandshakeBuffer() - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { - return err - } - - if err := c.processCertsFromClient(Certificate{ - Certificate: hs.sessionState.certificates, - }); err != nil { - return err - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - hs.masterSecret = hs.sessionState.masterSecret - - return nil -} - -func (hs *serverHandshakeState) doFullHandshake() error { - c := hs.c - - if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { - hs.hello.ocspStapling = true - } - - hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled - hs.hello.cipherSuite = hs.suite.id - - hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) - if c.config.ClientAuth == NoClientCert { - // No need to keep a full record of the handshake if client - // certificates won't be used. 
- hs.finishedHash.discardHandshakeBuffer() - } - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { - return err - } - - certMsg := new(certificateMsg) - certMsg.certificates = hs.cert.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { - return err - } - - if hs.hello.ocspStapling { - certStatus := new(certificateStatusMsg) - certStatus.response = hs.cert.OCSPStaple - hs.finishedHash.Write(certStatus.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { - return err - } - } - - keyAgreement := hs.suite.ka(c.vers) - skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) - if err != nil { - c.sendAlert(alertHandshakeFailure) - return err - } - if skx != nil { - hs.finishedHash.Write(skx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { - return err - } - } - - var certReq *certificateRequestMsg - if c.config.ClientAuth >= RequestClientCert { - // Request a client certificate - certReq = new(certificateRequestMsg) - certReq.certificateTypes = []byte{ - byte(certTypeRSASign), - byte(certTypeECDSASign), - } - if c.vers >= VersionTLS12 { - certReq.hasSignatureAlgorithm = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms - } - - // An empty list of certificateAuthorities signals to - // the client that it may send any certificate in response - // to our request. When we know the CAs we trust, then - // we can send them down, so that the client can choose - // an appropriate certificate to give to us. - if c.config.ClientCAs != nil { - certReq.certificateAuthorities = c.config.ClientCAs.Subjects() - } - hs.finishedHash.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { - return err - } - } - - helloDone := new(serverHelloDoneMsg) - hs.finishedHash.Write(helloDone.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { - return err - } - - if _, err := c.flush(); err != nil { - return err - } - - var pub crypto.PublicKey // public key for client auth, if any - - msg, err := c.readHandshake() - if err != nil { - return err - } - - // If we requested a client certificate, then the client must send a - // certificate message, even if it's empty. 
- if c.config.ClientAuth >= RequestClientCert { - certMsg, ok := msg.(*certificateMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - hs.finishedHash.Write(certMsg.marshal()) - - if err := c.processCertsFromClient(Certificate{ - Certificate: certMsg.certificates, - }); err != nil { - return err - } - if len(certMsg.certificates) != 0 { - pub = c.peerCertificates[0].PublicKey - } - - msg, err = c.readHandshake() - if err != nil { - return err - } - } - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - // Get client key exchange - ckx, ok := msg.(*clientKeyExchangeMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(ckx, msg) - } - hs.finishedHash.Write(ckx.marshal()) - - preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) - if err != nil { - c.sendAlert(alertHandshakeFailure) - return err - } - hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) - if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { - c.sendAlert(alertInternalError) - return err - } - - // If we received a client cert in response to our certificate request message, - // the client will send us a certificateVerifyMsg immediately after the - // clientKeyExchangeMsg. This message is a digest of all preceding - // handshake-layer messages that is signed using the private key corresponding - // to the client's certificate. This allows us to verify that the client is in - // possession of the private key of the certificate. 
- if len(c.peerCertificates) > 0 { - msg, err = c.readHandshake() - if err != nil { - return err - } - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - var sigType uint8 - var sigHash crypto.Hash - if c.vers >= VersionTLS12 { - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) - if err != nil { - c.sendAlert(alertIllegalParameter) - return err - } - } - - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) - if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the client certificate: " + err.Error()) - } - - hs.finishedHash.Write(certVerify.marshal()) - } - - hs.finishedHash.discardHandshakeBuffer() - - return nil -} - -func (hs *serverHandshakeState) establishKeys() error { - c := hs.c - - clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := - keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) - - var clientCipher, serverCipher interface{} - var clientHash, serverHash hash.Hash - - if hs.suite.aead == nil { - clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) - clientHash = hs.suite.mac(clientMAC) - serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) - serverHash = hs.suite.mac(serverMAC) - } else { - clientCipher = hs.suite.aead(clientKey, clientIV) - serverCipher = hs.suite.aead(serverKey, serverIV) - } - - c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) - c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) - - return nil -} - -func (hs *serverHandshakeState) readFinished(out []byte) error { - c := hs.c - - if err := c.readChangeCipherSpec(); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - clientFinished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(clientFinished, msg) - } - - verify := hs.finishedHash.clientSum(hs.masterSecret) - if len(verify) != len(clientFinished.verifyData) || - subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: client's Finished message is incorrect") - } - - hs.finishedHash.Write(clientFinished.marshal()) - copy(out, verify) - return nil -} - -func (hs *serverHandshakeState) sendSessionTicket() error { - // ticketSupported is set in a resumption handshake if the - // ticket from the client was encrypted with an old session - // ticket key and thus a refreshed ticket should be sent. - if !hs.hello.ticketSupported { - return nil - } - - c := hs.c - m := new(newSessionTicketMsg) - - createdAt := uint64(c.config.time().Unix()) - if hs.sessionState != nil { - // If this is re-wrapping an old key, then keep - // the original time it was created. 
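readFinished above compares the computed verify_data against the client's with a length check plus subtle.ConstantTimeCompare, so the comparison cost does not depend on where the first mismatching byte sits. A tiny sketch of the same guard (verifyDataEqual and the HMAC inputs are illustrative only):

package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"crypto/subtle"
	"fmt"
)

// verifyDataEqual compares an expected Finished verify_data value with
// the one received on the wire without leaking timing information,
// following the same length-check-then-constant-time-compare shape as
// readFinished above.
func verifyDataEqual(expected, received []byte) bool {
	return len(expected) == len(received) &&
		subtle.ConstantTimeCompare(expected, received) == 1
}

func main() {
	mac := hmac.New(sha256.New, []byte("finished-key"))
	mac.Write([]byte("transcript-hash"))
	verify := mac.Sum(nil)

	fmt.Println(verifyDataEqual(verify, verify))                 // true
	fmt.Println(verifyDataEqual(verify, verify[:len(verify)-1])) // false (length mismatch)
}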
- createdAt = hs.sessionState.createdAt - } - - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := sessionState{ - vers: c.vers, - cipherSuite: hs.suite.id, - createdAt: createdAt, - masterSecret: hs.masterSecret, - certificates: certsFromClient, - } - var err error - m.ticket, err = c.encryptTicket(state.marshal()) - if err != nil { - return err - } - - hs.finishedHash.Write(m.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeState) sendFinished(out []byte) error { - c := hs.c - - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { - return err - } - - finished := new(finishedMsg) - finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { - return err - } - - copy(out, finished.verifyData) - - return nil -} - -// processCertsFromClient takes a chain of client certificates either from a -// Certificates message or from a sessionState and verifies them. It returns -// the public key of the leaf certificate. -func (c *Conn) processCertsFromClient(certificate Certificate) error { - certificates := certificate.Certificate - certs := make([]*x509.Certificate, len(certificates)) - var err error - for i, asn1Data := range certificates { - if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to parse client certificate: " + err.Error()) - } - } - - if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: client didn't provide a certificate") - } - - if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { - opts := x509.VerifyOptions{ - Roots: c.config.ClientCAs, - CurrentTime: c.config.time(), - Intermediates: x509.NewCertPool(), - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - } - - for _, cert := range certs[1:] { - opts.Intermediates.AddCert(cert) - } - - chains, err := certs[0].Verify(opts) - if err != nil { - c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to verify client certificate: " + err.Error()) - } - - c.verifiedChains = chains - } - - c.peerCertificates = certs - c.ocspResponse = certificate.OCSPStaple - c.scts = certificate.SignedCertificateTimestamps - - if len(certs) > 0 { - switch certs[0].PublicKey.(type) { - case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) - } - } - - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - return nil -} - -func newClientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { - supportedVersions := clientHello.supportedVersions - if len(clientHello.supportedVersions) == 0 { - supportedVersions = supportedVersionsFromMax(clientHello.vers) - } - - return toClientHelloInfo(&clientHelloInfo{ - CipherSuites: clientHello.cipherSuites, - ServerName: clientHello.serverName, - SupportedCurves: clientHello.supportedCurves, - SupportedPoints: clientHello.supportedPoints, - 
SignatureSchemes: clientHello.supportedSignatureAlgorithms, - SupportedProtos: clientHello.alpnProtocols, - SupportedVersions: supportedVersions, - Conn: c.conn, - config: toConfig(c.config), - ctx: ctx, - }) -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server_tls13.go b/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server_tls13.go deleted file mode 100644 index 0c200605e6e..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/handshake_server_tls13.go +++ /dev/null @@ -1,896 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "context" - "crypto" - "crypto/hmac" - "crypto/rsa" - "errors" - "hash" - "io" - "sync/atomic" - "time" -) - -// maxClientPSKIdentities is the number of client PSK identities the server will -// attempt to validate. It will ignore the rest not to let cheap ClientHello -// messages cause too much work in session ticket decryption attempts. -const maxClientPSKIdentities = 5 - -type serverHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - clientHello *clientHelloMsg - hello *serverHelloMsg - alpnNegotiationErr error - encryptedExtensions *encryptedExtensionsMsg - sentDummyCCS bool - usingPSK bool - suite *cipherSuiteTLS13 - cert *Certificate - sigAlg SignatureScheme - earlySecret []byte - sharedKey []byte - handshakeSecret []byte - masterSecret []byte - trafficSecret []byte // client_application_traffic_secret_0 - transcript hash.Hash - clientFinished []byte -} - -func (hs *serverHandshakeStateTLS13) handshake() error { - c := hs.c - - // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. - if err := hs.processClientHello(); err != nil { - return err - } - if err := hs.checkForResumption(); err != nil { - return err - } - if err := hs.pickCertificate(); err != nil { - return err - } - c.buffering = true - if err := hs.sendServerParameters(); err != nil { - return err - } - if err := hs.sendServerCertificate(); err != nil { - return err - } - if err := hs.sendServerFinished(); err != nil { - return err - } - // Note that at this point we could start sending application data without - // waiting for the client's second flight, but the application might not - // expect the lack of replay protection of the ClientHello parameters. - if _, err := c.flush(); err != nil { - return err - } - if err := hs.readClientCertificate(); err != nil { - return err - } - if err := hs.readClientFinished(); err != nil { - return err - } - - atomic.StoreUint32(&c.handshakeStatus, 1) - - return nil -} - -func (hs *serverHandshakeStateTLS13) processClientHello() error { - c := hs.c - - hs.hello = new(serverHelloMsg) - hs.encryptedExtensions = new(encryptedExtensionsMsg) - - // TLS 1.3 froze the ServerHello.legacy_version field, and uses - // supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1. - hs.hello.vers = VersionTLS12 - hs.hello.supportedVersion = c.vers - - if len(hs.clientHello.supportedVersions) == 0 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client used the legacy version field to negotiate TLS 1.3") - } - - // Abort if the client is doing a fallback and landing lower than what we - // support. See RFC 7507, which however does not specify the interaction - // with supported_versions. 
The only difference is that with - // supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4] - // handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case, - // it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to - // TLS 1.2, because a TLS 1.3 server would abort here. The situation before - // supported_versions was not better because there was just no way to do a - // TLS 1.4 handshake without risking the server selecting TLS 1.3. - for _, id := range hs.clientHello.cipherSuites { - if id == TLS_FALLBACK_SCSV { - // Use c.vers instead of max(supported_versions) because an attacker - // could defeat this by adding an arbitrary high version otherwise. - if c.vers < c.config.maxSupportedVersion() { - c.sendAlert(alertInappropriateFallback) - return errors.New("tls: client using inappropriate protocol fallback") - } - break - } - } - - if len(hs.clientHello.compressionMethods) != 1 || - hs.clientHello.compressionMethods[0] != compressionNone { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: TLS 1.3 client supports illegal compression methods") - } - - hs.hello.random = make([]byte, 32) - if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil { - c.sendAlert(alertInternalError) - return err - } - - if len(hs.clientHello.secureRenegotiation) != 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: initial handshake had non-empty renegotiation extension") - } - - hs.hello.sessionId = hs.clientHello.sessionId - hs.hello.compressionMethod = compressionNone - - if hs.suite == nil { - var preferenceList []uint16 - for _, suiteID := range c.config.CipherSuites { - for _, suite := range cipherSuitesTLS13 { - if suite.id == suiteID { - preferenceList = append(preferenceList, suiteID) - break - } - } - } - if len(preferenceList) == 0 { - preferenceList = defaultCipherSuitesTLS13 - if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceList = defaultCipherSuitesTLS13NoAES - } - } - for _, suiteID := range preferenceList { - hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) - if hs.suite != nil { - break - } - } - } - if hs.suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no cipher suite supported by both client and server") - } - c.cipherSuite = hs.suite.id - hs.hello.cipherSuite = hs.suite.id - hs.transcript = hs.suite.hash.New() - - // Pick the ECDHE group in server preference order, but give priority to - // groups with a key share, to avoid a HelloRetryRequest round-trip. 
- var selectedGroup CurveID - var clientKeyShare *keyShare -GroupSelection: - for _, preferredGroup := range c.config.curvePreferences() { - for _, ks := range hs.clientHello.keyShares { - if ks.group == preferredGroup { - selectedGroup = ks.group - clientKeyShare = &ks - break GroupSelection - } - } - if selectedGroup != 0 { - continue - } - for _, group := range hs.clientHello.supportedCurves { - if group == preferredGroup { - selectedGroup = group - break - } - } - } - if selectedGroup == 0 { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: no ECDHE curve supported by both client and server") - } - if clientKeyShare == nil { - if err := hs.doHelloRetryRequest(selectedGroup); err != nil { - return err - } - clientKeyShare = &hs.clientHello.keyShares[0] - } - - if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err := generateECDHEParameters(c.config.rand(), selectedGroup) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} - hs.sharedKey = params.SharedKey(clientKeyShare.data) - if hs.sharedKey == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid client key share") - } - - c.serverName = hs.clientHello.serverName - - if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil { - c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions) - } - - selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols) - if err != nil { - hs.alpnNegotiationErr = err - } - hs.encryptedExtensions.alpnProtocol = selectedProto - c.clientProtocol = selectedProto - - return nil -} - -func (hs *serverHandshakeStateTLS13) checkForResumption() error { - c := hs.c - - if c.config.SessionTicketsDisabled { - return nil - } - - modeOK := false - for _, mode := range hs.clientHello.pskModes { - if mode == pskModeDHE { - modeOK = true - break - } - } - if !modeOK { - return nil - } - - if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid or missing PSK binders") - } - if len(hs.clientHello.pskIdentities) == 0 { - return nil - } - - for i, identity := range hs.clientHello.pskIdentities { - if i >= maxClientPSKIdentities { - break - } - - plaintext, _ := c.decryptTicket(identity.label) - if plaintext == nil { - continue - } - sessionState := new(sessionStateTLS13) - if ok := sessionState.unmarshal(plaintext); !ok { - continue - } - - if hs.clientHello.earlyData { - if sessionState.maxEarlyData == 0 { - c.sendAlert(alertUnsupportedExtension) - return errors.New("tls: client sent unexpected early data") - } - - if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol && - c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 && - c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) { - hs.encryptedExtensions.earlyData = true - c.used0RTT = true - } - } - - createdAt := time.Unix(int64(sessionState.createdAt), 0) - if c.config.time().Sub(createdAt) > maxSessionTicketLifetime { - continue - } - - // We don't check the obfuscated ticket age because it's affected by - // clock skew and it's only a freshness signal useful for shrinking the - // window for replay attacks, which don't affect us as we don't do 0-RTT. 
- - pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite) - if pskSuite == nil || pskSuite.hash != hs.suite.hash { - continue - } - - // PSK connections don't re-establish client certificates, but carry - // them over in the session ticket. Ensure the presence of client certs - // in the ticket is consistent with the configured requirements. - sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0 - needClientCerts := requiresClientCert(c.config.ClientAuth) - if needClientCerts && !sessionHasClientCerts { - continue - } - if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { - continue - } - - psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption", - nil, hs.suite.hash.Size()) - hs.earlySecret = hs.suite.extract(psk, nil) - binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil) - // Clone the transcript in case a HelloRetryRequest was recorded. - transcript := cloneHash(hs.transcript, hs.suite.hash) - if transcript == nil { - c.sendAlert(alertInternalError) - return errors.New("tls: internal error: failed to clone hash") - } - transcript.Write(hs.clientHello.marshalWithoutBinders()) - pskBinder := hs.suite.finishedHash(binderKey, transcript) - if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid PSK binder") - } - - c.didResume = true - if err := c.processCertsFromClient(sessionState.certificate); err != nil { - return err - } - - h := cloneHash(hs.transcript, hs.suite.hash) - h.Write(hs.clientHello.marshal()) - if hs.encryptedExtensions.earlyData { - clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h) - c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret) - if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil { - c.sendAlert(alertInternalError) - return err - } - } - - hs.hello.selectedIdentityPresent = true - hs.hello.selectedIdentity = uint16(i) - hs.usingPSK = true - return nil - } - - return nil -} - -// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler -// interfaces implemented by standard library hashes to clone the state of in -// to a new instance of h. It returns nil if the operation fails. -func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { - // Recreate the interface to avoid importing encoding. - type binaryMarshaler interface { - MarshalBinary() (data []byte, err error) - UnmarshalBinary(data []byte) error - } - marshaler, ok := in.(binaryMarshaler) - if !ok { - return nil - } - state, err := marshaler.MarshalBinary() - if err != nil { - return nil - } - out := h.New() - unmarshaler, ok := out.(binaryMarshaler) - if !ok { - return nil - } - if err := unmarshaler.UnmarshalBinary(state); err != nil { - return nil - } - return out -} - -func (hs *serverHandshakeStateTLS13) pickCertificate() error { - c := hs.c - - // Only one of PSK and certificates are used at a time. - if hs.usingPSK { - return nil - } - - // signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3. 
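cloneHash above branches the handshake transcript by round-tripping the running hash state through encoding.BinaryMarshaler/BinaryUnmarshaler, which the standard-library digests have implemented since Go 1.8. A standalone sketch of that trick (snapshotHash is a hypothetical name; the vendored helper declares the marshaler interface locally instead of importing encoding):

package main

import (
	"crypto"
	_ "crypto/sha256"
	"encoding"
	"fmt"
	"hash"
)

// snapshotHash copies the running state of a standard-library hash by
// marshaling and unmarshaling its binary state, the same technique
// cloneHash above uses to fork the handshake transcript.
func snapshotHash(in hash.Hash, h crypto.Hash) hash.Hash {
	m, ok := in.(encoding.BinaryMarshaler)
	if !ok {
		return nil
	}
	state, err := m.MarshalBinary()
	if err != nil {
		return nil
	}
	out := h.New()
	u, ok := out.(encoding.BinaryUnmarshaler)
	if !ok || u.UnmarshalBinary(state) != nil {
		return nil
	}
	return out
}

func main() {
	t := crypto.SHA256.New()
	t.Write([]byte("ClientHello"))

	branch := snapshotHash(t, crypto.SHA256) // transcript up to ClientHello
	t.Write([]byte("ServerHello"))           // original keeps advancing

	fmt.Printf("branch: %x\n", branch.Sum(nil))
	fmt.Printf("full:   %x\n", t.Sum(nil))
}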
- if len(hs.clientHello.supportedSignatureAlgorithms) == 0 { - return c.sendAlert(alertMissingExtension) - } - - certificate, err := c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello)) - if err != nil { - if err == errNoCertificates { - c.sendAlert(alertUnrecognizedName) - } else { - c.sendAlert(alertInternalError) - } - return err - } - hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms) - if err != nil { - // getCertificate returned a certificate that is unsupported or - // incompatible with the client's signature algorithms. - c.sendAlert(alertHandshakeFailure) - return err - } - hs.cert = certificate - - return nil -} - -// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility -// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4. -func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { - if hs.sentDummyCCS { - return nil - } - hs.sentDummyCCS = true - - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err -} - -func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { - c := hs.c - - // The first ClientHello gets double-hashed into the transcript upon a - // HelloRetryRequest. See RFC 8446, Section 4.4.1. - hs.transcript.Write(hs.clientHello.marshal()) - chHash := hs.transcript.Sum(nil) - hs.transcript.Reset() - hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) - hs.transcript.Write(chHash) - - helloRetryRequest := &serverHelloMsg{ - vers: hs.hello.vers, - random: helloRetryRequestRandom, - sessionId: hs.hello.sessionId, - cipherSuite: hs.hello.cipherSuite, - compressionMethod: hs.hello.compressionMethod, - supportedVersion: hs.hello.supportedVersion, - selectedGroup: selectedGroup, - } - - hs.transcript.Write(helloRetryRequest.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { - return err - } - - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - - msg, err := c.readHandshake() - if err != nil { - return err - } - - clientHello, ok := msg.(*clientHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(clientHello, msg) - } - - if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client sent invalid key share in second ClientHello") - } - - if clientHello.earlyData { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client indicated early data in second ClientHello") - } - - if illegalClientHelloChange(clientHello, hs.clientHello) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client illegally modified second ClientHello") - } - - if clientHello.earlyData { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client offered 0-RTT data in second ClientHello") - } - - hs.clientHello = clientHello - return nil -} - -// illegalClientHelloChange reports whether the two ClientHello messages are -// different, with the exception of the changes allowed before and after a -// HelloRetryRequest. See RFC 8446, Section 4.1.2. 
-func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { - if len(ch.supportedVersions) != len(ch1.supportedVersions) || - len(ch.cipherSuites) != len(ch1.cipherSuites) || - len(ch.supportedCurves) != len(ch1.supportedCurves) || - len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) || - len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) || - len(ch.alpnProtocols) != len(ch1.alpnProtocols) { - return true - } - for i := range ch.supportedVersions { - if ch.supportedVersions[i] != ch1.supportedVersions[i] { - return true - } - } - for i := range ch.cipherSuites { - if ch.cipherSuites[i] != ch1.cipherSuites[i] { - return true - } - } - for i := range ch.supportedCurves { - if ch.supportedCurves[i] != ch1.supportedCurves[i] { - return true - } - } - for i := range ch.supportedSignatureAlgorithms { - if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] { - return true - } - } - for i := range ch.supportedSignatureAlgorithmsCert { - if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] { - return true - } - } - for i := range ch.alpnProtocols { - if ch.alpnProtocols[i] != ch1.alpnProtocols[i] { - return true - } - } - return ch.vers != ch1.vers || - !bytes.Equal(ch.random, ch1.random) || - !bytes.Equal(ch.sessionId, ch1.sessionId) || - !bytes.Equal(ch.compressionMethods, ch1.compressionMethods) || - ch.serverName != ch1.serverName || - ch.ocspStapling != ch1.ocspStapling || - !bytes.Equal(ch.supportedPoints, ch1.supportedPoints) || - ch.ticketSupported != ch1.ticketSupported || - !bytes.Equal(ch.sessionTicket, ch1.sessionTicket) || - ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported || - !bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) || - ch.scts != ch1.scts || - !bytes.Equal(ch.cookie, ch1.cookie) || - !bytes.Equal(ch.pskModes, ch1.pskModes) -} - -func (hs *serverHandshakeStateTLS13) sendServerParameters() error { - c := hs.c - - hs.transcript.Write(hs.clientHello.marshal()) - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { - return err - } - - if err := hs.sendDummyChangeCipherSpec(); err != nil { - return err - } - - earlySecret := hs.earlySecret - if earlySecret == nil { - earlySecret = hs.suite.extract(nil, nil) - } - hs.handshakeSecret = hs.suite.extract(hs.sharedKey, - hs.suite.deriveSecret(earlySecret, "derived", nil)) - - clientSecret := hs.suite.deriveSecret(hs.handshakeSecret, - clientHandshakeTrafficLabel, hs.transcript) - c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret) - c.in.setTrafficSecret(hs.suite, clientSecret) - serverSecret := hs.suite.deriveSecret(hs.handshakeSecret, - serverHandshakeTrafficLabel, hs.transcript) - c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret) - c.out.setTrafficSecret(hs.suite, serverSecret) - - err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - if hs.alpnNegotiationErr != nil { - c.sendAlert(alertNoApplicationProtocol) - return hs.alpnNegotiationErr - } - if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil { - hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions) - } - 
- hs.transcript.Write(hs.encryptedExtensions.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.encryptedExtensions.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) requestClientCert() bool { - return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK -} - -func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { - c := hs.c - - // Only one of PSK and certificates are used at a time. - if hs.usingPSK { - return nil - } - - if hs.requestClientCert() { - // Request a client certificate - certReq := new(certificateRequestMsgTLS13) - certReq.ocspStapling = true - certReq.scts = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms - if c.config.ClientCAs != nil { - certReq.certificateAuthorities = c.config.ClientCAs.Subjects() - } - - hs.transcript.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { - return err - } - } - - certMsg := new(certificateMsgTLS13) - - certMsg.certificate = *hs.cert - certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 - certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 - - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { - return err - } - - certVerifyMsg := new(certificateVerifyMsg) - certVerifyMsg.hasSignatureAlgorithm = true - certVerifyMsg.signatureAlgorithm = hs.sigAlg - - sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg) - if err != nil { - return c.sendAlert(alertInternalError) - } - - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) - if err != nil { - public := hs.cert.PrivateKey.(crypto.Signer).Public() - if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && - rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS - c.sendAlert(alertHandshakeFailure) - } else { - c.sendAlert(alertInternalError) - } - return errors.New("tls: failed to sign handshake: " + err.Error()) - } - certVerifyMsg.signature = sig - - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) sendServerFinished() error { - c := hs.c - - finished := &finishedMsg{ - verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), - } - - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { - return err - } - - // Derive secrets that take context through the server Finished. 
- - hs.masterSecret = hs.suite.extract(nil, - hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil)) - - hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret, - clientApplicationTrafficLabel, hs.transcript) - serverSecret := hs.suite.deriveSecret(hs.masterSecret, - serverApplicationTrafficLabel, hs.transcript) - c.out.exportKey(EncryptionApplication, hs.suite, serverSecret) - c.out.setTrafficSecret(hs.suite, serverSecret) - - err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - - c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript) - - // If we did not request client certificates, at this point we can - // precompute the client finished and roll the transcript forward to send - // session tickets in our first flight. - if !hs.requestClientCert() { - if err := hs.sendSessionTickets(); err != nil { - return err - } - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool { - if hs.c.config.SessionTicketsDisabled { - return false - } - - // Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9. - for _, pskMode := range hs.clientHello.pskModes { - if pskMode == pskModeDHE { - return true - } - } - return false -} - -func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { - c := hs.c - - hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript) - finishedMsg := &finishedMsg{ - verifyData: hs.clientFinished, - } - hs.transcript.Write(finishedMsg.marshal()) - - if !hs.shouldSendSessionTickets() { - return nil - } - - c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret, - resumptionLabel, hs.transcript) - - // Don't send session tickets when the alternative record layer is set. - // Instead, save the resumption secret on the Conn. - // Session tickets can then be generated by calling Conn.GetSessionTicket(). - if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil { - return nil - } - - m, err := hs.c.getSessionTicketMsg(nil) - if err != nil { - return err - } - - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) readClientCertificate() error { - c := hs.c - - if !hs.requestClientCert() { - // Make sure the connection is still being verified whether or not - // the server requested a client certificate. - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - return nil - } - - // If we requested a client certificate, then the client must send a - // certificate message. If it's empty, no CertificateVerify is sent. 
- - msg, err := c.readHandshake() - if err != nil { - return err - } - - certMsg, ok := msg.(*certificateMsgTLS13) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certMsg, msg) - } - hs.transcript.Write(certMsg.marshal()) - - if err := c.processCertsFromClient(certMsg.certificate); err != nil { - return err - } - - if c.config.VerifyConnection != nil { - if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - - if len(certMsg.certificate.Certificate) != 0 { - msg, err = c.readHandshake() - if err != nil { - return err - } - - certVerify, ok := msg.(*certificateVerifyMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(certVerify, msg) - } - - // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm) - if err != nil { - return c.sendAlert(alertInternalError) - } - if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: client certificate used with invalid signature algorithm") - } - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) - if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, - sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the client certificate: " + err.Error()) - } - - hs.transcript.Write(certVerify.marshal()) - } - - // If we waited until the client certificates to send session tickets, we - // are ready to do it now. - if err := hs.sendSessionTickets(); err != nil { - return err - } - - return nil -} - -func (hs *serverHandshakeStateTLS13) readClientFinished() error { - c := hs.c - - msg, err := c.readHandshake() - if err != nil { - return err - } - - finished, ok := msg.(*finishedMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(finished, msg) - } - - if !hmac.Equal(hs.clientFinished, finished.verifyData) { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid client finished hash") - } - - c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret) - c.in.setTrafficSecret(hs.suite, hs.trafficSecret) - - return nil -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/key_agreement.go b/vendor/github.com/marten-seemann/qtls-go1-17/key_agreement.go deleted file mode 100644 index 453a8dcf080..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-17/key_agreement.go +++ /dev/null @@ -1,357 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/md5" - "crypto/rsa" - "crypto/sha1" - "crypto/x509" - "errors" - "fmt" - "io" -) - -// a keyAgreement implements the client and server side of a TLS key agreement -// protocol by generating and processing key exchange messages. -type keyAgreement interface { - // On the server side, the first two methods are called in order. - - // In the case that the key agreement protocol doesn't use a - // ServerKeyExchange message, generateServerKeyExchange can return nil, - // nil. 
- generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) - processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) - - // On the client side, the next two methods are called in order. - - // This method may not be called if the server doesn't send a - // ServerKeyExchange message. - processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error - generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) -} - -var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") -var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") - -// rsaKeyAgreement implements the standard TLS key agreement where the client -// encrypts the pre-master secret to the server's public key. -type rsaKeyAgreement struct{} - -func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { - return nil, nil -} - -func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { - if len(ckx.ciphertext) < 2 { - return nil, errClientKeyExchange - } - ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) - if ciphertextLen != len(ckx.ciphertext)-2 { - return nil, errClientKeyExchange - } - ciphertext := ckx.ciphertext[2:] - - priv, ok := cert.PrivateKey.(crypto.Decrypter) - if !ok { - return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") - } - // Perform constant time RSA PKCS #1 v1.5 decryption - preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48}) - if err != nil { - return nil, err - } - // We don't check the version number in the premaster secret. For one, - // by checking it, we would leak information about the validity of the - // encrypted pre-master secret. Secondly, it provides only a small - // benefit against a downgrade attack and some implementations send the - // wrong version anyway. See the discussion at the end of section - // 7.4.7.1 of RFC 4346. 
- return preMasterSecret, nil -} - -func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { - return errors.New("tls: unexpected ServerKeyExchange") -} - -func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { - preMasterSecret := make([]byte, 48) - preMasterSecret[0] = byte(clientHello.vers >> 8) - preMasterSecret[1] = byte(clientHello.vers) - _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) - if err != nil { - return nil, nil, err - } - - rsaKey, ok := cert.PublicKey.(*rsa.PublicKey) - if !ok { - return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite") - } - encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret) - if err != nil { - return nil, nil, err - } - ckx := new(clientKeyExchangeMsg) - ckx.ciphertext = make([]byte, len(encrypted)+2) - ckx.ciphertext[0] = byte(len(encrypted) >> 8) - ckx.ciphertext[1] = byte(len(encrypted)) - copy(ckx.ciphertext[2:], encrypted) - return preMasterSecret, ckx, nil -} - -// sha1Hash calculates a SHA1 hash over the given byte slices. -func sha1Hash(slices [][]byte) []byte { - hsha1 := sha1.New() - for _, slice := range slices { - hsha1.Write(slice) - } - return hsha1.Sum(nil) -} - -// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the -// concatenation of an MD5 and SHA1 hash. -func md5SHA1Hash(slices [][]byte) []byte { - md5sha1 := make([]byte, md5.Size+sha1.Size) - hmd5 := md5.New() - for _, slice := range slices { - hmd5.Write(slice) - } - copy(md5sha1, hmd5.Sum(nil)) - copy(md5sha1[md5.Size:], sha1Hash(slices)) - return md5sha1 -} - -// hashForServerKeyExchange hashes the given slices and returns their digest -// using the given hash function (for >= TLS 1.2) or using a default based on -// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't -// do pre-hashing, it returns the concatenation of the slices. -func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { - if sigType == signatureEd25519 { - var signed []byte - for _, slice := range slices { - signed = append(signed, slice...) - } - return signed - } - if version >= VersionTLS12 { - h := hashFunc.New() - for _, slice := range slices { - h.Write(slice) - } - digest := h.Sum(nil) - return digest - } - if sigType == signatureECDSA { - return sha1Hash(slices) - } - return md5SHA1Hash(slices) -} - -// ecdheKeyAgreement implements a TLS key agreement where the server -// generates an ephemeral EC public/private key pair and signs it. The -// pre-master secret is then calculated using ECDH. The signature may -// be ECDSA, Ed25519 or RSA. -type ecdheKeyAgreement struct { - version uint16 - isRSA bool - params ecdheParameters - - // ckx and preMasterSecret are generated in processServerKeyExchange - // and returned in generateClientKeyExchange. 
- ckx *clientKeyExchangeMsg - preMasterSecret []byte -} - -func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { - var curveID CurveID - for _, c := range clientHello.supportedCurves { - if config.supportsCurve(c) { - curveID = c - break - } - } - - if curveID == 0 { - return nil, errors.New("tls: no supported elliptic curves offered") - } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - return nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - - params, err := generateECDHEParameters(config.rand(), curveID) - if err != nil { - return nil, err - } - ka.params = params - - // See RFC 4492, Section 5.4. - ecdhePublic := params.PublicKey() - serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) - serverECDHEParams[0] = 3 // named curve - serverECDHEParams[1] = byte(curveID >> 8) - serverECDHEParams[2] = byte(curveID) - serverECDHEParams[3] = byte(len(ecdhePublic)) - copy(serverECDHEParams[4:], ecdhePublic) - - priv, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) - } - - var signatureAlgorithm SignatureScheme - var sigType uint8 - var sigHash crypto.Hash - if ka.version >= VersionTLS12 { - signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) - if err != nil { - return nil, err - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return nil, err - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) - if err != nil { - return nil, err - } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") - } - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) - - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := priv.Sign(config.rand(), signed, signOpts) - if err != nil { - return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) - } - - skx := new(serverKeyExchangeMsg) - sigAndHashLen := 0 - if ka.version >= VersionTLS12 { - sigAndHashLen = 2 - } - skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig)) - copy(skx.key, serverECDHEParams) - k := skx.key[len(serverECDHEParams):] - if ka.version >= VersionTLS12 { - k[0] = byte(signatureAlgorithm >> 8) - k[1] = byte(signatureAlgorithm) - k = k[2:] - } - k[0] = byte(len(sig) >> 8) - k[1] = byte(len(sig)) - copy(k[2:], sig) - - return skx, nil -} - -func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { - if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { - return nil, errClientKeyExchange - } - - preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) - if preMasterSecret == nil { - return nil, errClientKeyExchange - } - - return preMasterSecret, nil -} - -func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { - if len(skx.key) < 4 { - return 
errServerKeyExchange - } - if skx.key[0] != 3 { // named curve - return errors.New("tls: server selected unsupported curve") - } - curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) - - publicLen := int(skx.key[3]) - if publicLen+4 > len(skx.key) { - return errServerKeyExchange - } - serverECDHEParams := skx.key[:4+publicLen] - publicKey := serverECDHEParams[4:] - - sig := skx.key[4+publicLen:] - if len(sig) < 2 { - return errServerKeyExchange - } - - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - return errors.New("tls: server selected unsupported curve") - } - - params, err := generateECDHEParameters(config.rand(), curveID) - if err != nil { - return err - } - ka.params = params - - ka.preMasterSecret = params.SharedKey(publicKey) - if ka.preMasterSecret == nil { - return errServerKeyExchange - } - - ourPublicKey := params.PublicKey() - ka.ckx = new(clientKeyExchangeMsg) - ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) - ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) - copy(ka.ckx.ciphertext[1:], ourPublicKey) - - var sigType uint8 - var sigHash crypto.Hash - if ka.version >= VersionTLS12 { - signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) - sig = sig[2:] - if len(sig) < 2 { - return errServerKeyExchange - } - - if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { - return errors.New("tls: certificate used with invalid signature algorithm") - } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) - if err != nil { - return err - } - } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) - if err != nil { - return err - } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return errServerKeyExchange - } - - sigLen := int(sig[0])<<8 | int(sig[1]) - if sigLen+2 != len(sig) { - return errServerKeyExchange - } - sig = sig[2:] - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) - if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { - return errors.New("tls: invalid signature by the server certificate: " + err.Error()) - } - return nil -} - -func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { - if ka.ckx == nil { - return nil, nil, errors.New("tls: missing ServerKeyExchange message") - } - - return ka.preMasterSecret, ka.ckx, nil -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/README.md b/vendor/github.com/marten-seemann/qtls-go1-18/README.md deleted file mode 100644 index 3e902212721..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-18/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# qtls - -[![Go Reference](https://pkg.go.dev/badge/github.com/marten-seemann/qtls-go1-17.svg)](https://pkg.go.dev/github.com/marten-seemann/qtls-go1-17) -[![.github/workflows/go-test.yml](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml/badge.svg)](https://github.com/marten-seemann/qtls-go1-17/actions/workflows/go-test.yml) - -This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go). 
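The ecdheKeyAgreement code removed above derives the TLS pre-master secret through an ephemeral Diffie-Hellman exchange; for X25519 the deleted helpers reduce to two calls into golang.org/x/crypto/curve25519 (derive a public key from a random 32-byte scalar against the base point, then compute the shared secret against the peer's public key). The following is a minimal, self-contained sketch of that exchange for illustration only; it is not taken from any of the removed files.

// x25519 ECDH sketch: each side generates a 32-byte private scalar, derives
// its public key against the curve25519 base point, and the shared secret is
// X25519(ownPrivate, peerPublic). Both sides arrive at the same value.
package main

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"log"

	"golang.org/x/crypto/curve25519"
)

func generateKeyPair() (priv, pub []byte, err error) {
	priv = make([]byte, curve25519.ScalarSize)
	if _, err = rand.Read(priv); err != nil {
		return nil, nil, err
	}
	pub, err = curve25519.X25519(priv, curve25519.Basepoint)
	return priv, pub, err
}

func main() {
	clientPriv, clientPub, err := generateKeyPair()
	if err != nil {
		log.Fatal(err)
	}
	serverPriv, serverPub, err := generateKeyPair()
	if err != nil {
		log.Fatal(err)
	}

	// Each peer combines its own private scalar with the other's public key.
	clientShared, err := curve25519.X25519(clientPriv, serverPub)
	if err != nil {
		log.Fatal(err)
	}
	serverShared, err := curve25519.X25519(serverPriv, clientPub)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("shared secrets match:", bytes.Equal(clientShared, serverShared))
}

In the removed qtls code the equivalent logic lives in generateECDHEParameters and x25519Parameters.SharedKey, and the resulting shared key is what the key-agreement paths above return as the pre-master secret.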
diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/alert.go b/vendor/github.com/marten-seemann/qtls-go1-18/alert.go deleted file mode 100644 index 3feac79be84..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-18/alert.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import "strconv" - -type alert uint8 - -// Alert is a TLS alert -type Alert = alert - -const ( - // alert level - alertLevelWarning = 1 - alertLevelError = 2 -) - -const ( - alertCloseNotify alert = 0 - alertUnexpectedMessage alert = 10 - alertBadRecordMAC alert = 20 - alertDecryptionFailed alert = 21 - alertRecordOverflow alert = 22 - alertDecompressionFailure alert = 30 - alertHandshakeFailure alert = 40 - alertBadCertificate alert = 42 - alertUnsupportedCertificate alert = 43 - alertCertificateRevoked alert = 44 - alertCertificateExpired alert = 45 - alertCertificateUnknown alert = 46 - alertIllegalParameter alert = 47 - alertUnknownCA alert = 48 - alertAccessDenied alert = 49 - alertDecodeError alert = 50 - alertDecryptError alert = 51 - alertExportRestriction alert = 60 - alertProtocolVersion alert = 70 - alertInsufficientSecurity alert = 71 - alertInternalError alert = 80 - alertInappropriateFallback alert = 86 - alertUserCanceled alert = 90 - alertNoRenegotiation alert = 100 - alertMissingExtension alert = 109 - alertUnsupportedExtension alert = 110 - alertCertificateUnobtainable alert = 111 - alertUnrecognizedName alert = 112 - alertBadCertificateStatusResponse alert = 113 - alertBadCertificateHashValue alert = 114 - alertUnknownPSKIdentity alert = 115 - alertCertificateRequired alert = 116 - alertNoApplicationProtocol alert = 120 -) - -var alertText = map[alert]string{ - alertCloseNotify: "close notify", - alertUnexpectedMessage: "unexpected message", - alertBadRecordMAC: "bad record MAC", - alertDecryptionFailed: "decryption failed", - alertRecordOverflow: "record overflow", - alertDecompressionFailure: "decompression failure", - alertHandshakeFailure: "handshake failure", - alertBadCertificate: "bad certificate", - alertUnsupportedCertificate: "unsupported certificate", - alertCertificateRevoked: "revoked certificate", - alertCertificateExpired: "expired certificate", - alertCertificateUnknown: "unknown certificate", - alertIllegalParameter: "illegal parameter", - alertUnknownCA: "unknown certificate authority", - alertAccessDenied: "access denied", - alertDecodeError: "error decoding message", - alertDecryptError: "error decrypting message", - alertExportRestriction: "export restriction", - alertProtocolVersion: "protocol version not supported", - alertInsufficientSecurity: "insufficient security level", - alertInternalError: "internal error", - alertInappropriateFallback: "inappropriate fallback", - alertUserCanceled: "user canceled", - alertNoRenegotiation: "no renegotiation", - alertMissingExtension: "missing extension", - alertUnsupportedExtension: "unsupported extension", - alertCertificateUnobtainable: "certificate unobtainable", - alertUnrecognizedName: "unrecognized name", - alertBadCertificateStatusResponse: "bad certificate status response", - alertBadCertificateHashValue: "bad certificate hash value", - alertUnknownPSKIdentity: "unknown PSK identity", - alertCertificateRequired: "certificate required", - alertNoApplicationProtocol: "no application protocol", -} - -func (e alert) String() string { - s, ok := alertText[e] - 
if ok { - return "tls: " + s - } - return "tls: alert(" + strconv.Itoa(int(e)) + ")" -} - -func (e alert) Error() string { - return e.String() -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/prf.go b/vendor/github.com/marten-seemann/qtls-go1-18/prf.go deleted file mode 100644 index 9eb0221a0c0..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-18/prf.go +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/hmac" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "errors" - "fmt" - "hash" -) - -// Split a premaster secret in two as specified in RFC 4346, Section 5. -func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { - s1 = secret[0 : (len(secret)+1)/2] - s2 = secret[len(secret)/2:] - return -} - -// pHash implements the P_hash function, as defined in RFC 4346, Section 5. -func pHash(result, secret, seed []byte, hash func() hash.Hash) { - h := hmac.New(hash, secret) - h.Write(seed) - a := h.Sum(nil) - - j := 0 - for j < len(result) { - h.Reset() - h.Write(a) - h.Write(seed) - b := h.Sum(nil) - copy(result[j:], b) - j += len(b) - - h.Reset() - h.Write(a) - a = h.Sum(nil) - } -} - -// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. -func prf10(result, secret, label, seed []byte) { - hashSHA1 := sha1.New - hashMD5 := md5.New - - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) - - s1, s2 := splitPreMasterSecret(secret) - pHash(result, s1, labelAndSeed, hashMD5) - result2 := make([]byte, len(result)) - pHash(result2, s2, labelAndSeed, hashSHA1) - - for i, b := range result2 { - result[i] ^= b - } -} - -// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. -func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { - return func(result, secret, label, seed []byte) { - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) - - pHash(result, secret, labelAndSeed, hashFunc) - } -} - -const ( - masterSecretLength = 48 // Length of a master secret in TLS 1.1. - finishedVerifyLength = 12 // Length of verify_data in a Finished message. -) - -var masterSecretLabel = []byte("master secret") -var keyExpansionLabel = []byte("key expansion") -var clientFinishedLabel = []byte("client finished") -var serverFinishedLabel = []byte("server finished") - -func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { - switch version { - case VersionTLS10, VersionTLS11: - return prf10, crypto.Hash(0) - case VersionTLS12: - if suite.flags&suiteSHA384 != 0 { - return prf12(sha512.New384), crypto.SHA384 - } - return prf12(sha256.New), crypto.SHA256 - default: - panic("unknown version") - } -} - -func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { - prf, _ := prfAndHashForVersion(version, suite) - return prf -} - -// masterFromPreMasterSecret generates the master secret from the pre-master -// secret. See RFC 5246, Section 8.1. -func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { - seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) - seed = append(seed, clientRandom...) 
- seed = append(seed, serverRandom...) - - masterSecret := make([]byte, masterSecretLength) - prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) - return masterSecret -} - -// keysFromMasterSecret generates the connection keys from the master -// secret, given the lengths of the MAC key, cipher key and IV, as defined in -// RFC 2246, Section 6.3. -func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { - seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) - seed = append(seed, serverRandom...) - seed = append(seed, clientRandom...) - - n := 2*macLen + 2*keyLen + 2*ivLen - keyMaterial := make([]byte, n) - prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) - clientMAC = keyMaterial[:macLen] - keyMaterial = keyMaterial[macLen:] - serverMAC = keyMaterial[:macLen] - keyMaterial = keyMaterial[macLen:] - clientKey = keyMaterial[:keyLen] - keyMaterial = keyMaterial[keyLen:] - serverKey = keyMaterial[:keyLen] - keyMaterial = keyMaterial[keyLen:] - clientIV = keyMaterial[:ivLen] - keyMaterial = keyMaterial[ivLen:] - serverIV = keyMaterial[:ivLen] - return -} - -func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { - var buffer []byte - if version >= VersionTLS12 { - buffer = []byte{} - } - - prf, hash := prfAndHashForVersion(version, cipherSuite) - if hash != 0 { - return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} - } - - return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} -} - -// A finishedHash calculates the hash of a set of handshake messages suitable -// for including in a Finished message. -type finishedHash struct { - client hash.Hash - server hash.Hash - - // Prior to TLS 1.2, an additional MD5 hash is required. - clientMD5 hash.Hash - serverMD5 hash.Hash - - // In TLS 1.2, a full buffer is sadly required. - buffer []byte - - version uint16 - prf func(result, secret, label, seed []byte) -} - -func (h *finishedHash) Write(msg []byte) (n int, err error) { - h.client.Write(msg) - h.server.Write(msg) - - if h.version < VersionTLS12 { - h.clientMD5.Write(msg) - h.serverMD5.Write(msg) - } - - if h.buffer != nil { - h.buffer = append(h.buffer, msg...) - } - - return len(msg), nil -} - -func (h finishedHash) Sum() []byte { - if h.version >= VersionTLS12 { - return h.client.Sum(nil) - } - - out := make([]byte, 0, md5.Size+sha1.Size) - out = h.clientMD5.Sum(out) - return h.client.Sum(out) -} - -// clientSum returns the contents of the verify_data member of a client's -// Finished message. -func (h finishedHash) clientSum(masterSecret []byte) []byte { - out := make([]byte, finishedVerifyLength) - h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) - return out -} - -// serverSum returns the contents of the verify_data member of a server's -// Finished message. -func (h finishedHash) serverSum(masterSecret []byte) []byte { - out := make([]byte, finishedVerifyLength) - h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) - return out -} - -// hashForClientCertificate returns the handshake messages so far, pre-hashed if -// necessary, suitable for signing by a TLS client certificate. 
-func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { - if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { - panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") - } - - if sigType == signatureEd25519 { - return h.buffer - } - - if h.version >= VersionTLS12 { - hash := hashAlg.New() - hash.Write(h.buffer) - return hash.Sum(nil) - } - - if sigType == signatureECDSA { - return h.server.Sum(nil) - } - - return h.Sum() -} - -// discardHandshakeBuffer is called when there is no more need to -// buffer the entirety of the handshake messages. -func (h *finishedHash) discardHandshakeBuffer() { - h.buffer = nil -} - -// noExportedKeyingMaterial is used as a value of -// ConnectionState.ekm when renegotiation is enabled and thus -// we wish to fail all key-material export requests. -func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { - return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") -} - -// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. -func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { - return func(label string, context []byte, length int) ([]byte, error) { - switch label { - case "client finished", "server finished", "master secret", "key expansion": - // These values are reserved and may not be used. - return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) - } - - seedLen := len(serverRandom) + len(clientRandom) - if context != nil { - seedLen += 2 + len(context) - } - seed := make([]byte, 0, seedLen) - - seed = append(seed, clientRandom...) - seed = append(seed, serverRandom...) - - if context != nil { - if len(context) >= 1<<16 { - return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") - } - seed = append(seed, byte(len(context)>>8), byte(len(context))) - seed = append(seed, context...) 
- } - - keyMaterial := make([]byte, length) - prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) - return keyMaterial, nil - } -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/unsafe.go b/vendor/github.com/marten-seemann/qtls-go1-18/unsafe.go deleted file mode 100644 index 55fa01b3d61..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-18/unsafe.go +++ /dev/null @@ -1,96 +0,0 @@ -package qtls - -import ( - "crypto/tls" - "reflect" - "unsafe" -) - -func init() { - if !structsEqual(&tls.ConnectionState{}, &connectionState{}) { - panic("qtls.ConnectionState doesn't match") - } - if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) { - panic("qtls.ClientSessionState doesn't match") - } - if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) { - panic("qtls.CertificateRequestInfo doesn't match") - } - if !structsEqual(&tls.Config{}, &config{}) { - panic("qtls.Config doesn't match") - } - if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) { - panic("qtls.ClientHelloInfo doesn't match") - } -} - -func toConnectionState(c connectionState) ConnectionState { - return *(*ConnectionState)(unsafe.Pointer(&c)) -} - -func toClientSessionState(s *clientSessionState) *ClientSessionState { - return (*ClientSessionState)(unsafe.Pointer(s)) -} - -func fromClientSessionState(s *ClientSessionState) *clientSessionState { - return (*clientSessionState)(unsafe.Pointer(s)) -} - -func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo { - return (*CertificateRequestInfo)(unsafe.Pointer(i)) -} - -func toConfig(c *config) *Config { - return (*Config)(unsafe.Pointer(c)) -} - -func fromConfig(c *Config) *config { - return (*config)(unsafe.Pointer(c)) -} - -func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo { - return (*ClientHelloInfo)(unsafe.Pointer(chi)) -} - -func structsEqual(a, b interface{}) bool { - return compare(reflect.ValueOf(a), reflect.ValueOf(b)) -} - -func compare(a, b reflect.Value) bool { - sa := a.Elem() - sb := b.Elem() - if sa.NumField() != sb.NumField() { - return false - } - for i := 0; i < sa.NumField(); i++ { - fa := sa.Type().Field(i) - fb := sb.Type().Field(i) - if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { - if fa.Type.Kind() != fb.Type.Kind() { - return false - } - if fa.Type.Kind() == reflect.Slice { - if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) { - return false - } - continue - } - return false - } - } - return true -} - -func compareStruct(a, b reflect.Type) bool { - if a.NumField() != b.NumField() { - return false - } - for i := 0; i < a.NumField(); i++ { - fa := a.Field(i) - fb := b.Field(i) - if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { - return false - } - } - return true -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/LICENSE b/vendor/github.com/marten-seemann/qtls-go1-19/LICENSE deleted file mode 100644 index 6a66aea5eaf..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. 
- -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/README.md b/vendor/github.com/marten-seemann/qtls-go1-19/README.md deleted file mode 100644 index db260ba2a09..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# qtls - -[![Go Reference](https://pkg.go.dev/badge/github.com/marten-seemann/qtls-go1-19.svg)](https://pkg.go.dev/github.com/marten-seemann/qtls-go1-19) -[![.github/workflows/go-test.yml](https://github.com/marten-seemann/qtls-go1-19/actions/workflows/go-test.yml/badge.svg)](https://github.com/marten-seemann/qtls-go1-19/actions/workflows/go-test.yml) - -This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go). diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/alert.go b/vendor/github.com/marten-seemann/qtls-go1-19/alert.go deleted file mode 100644 index 3feac79be84..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/alert.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package qtls - -import "strconv" - -type alert uint8 - -// Alert is a TLS alert -type Alert = alert - -const ( - // alert level - alertLevelWarning = 1 - alertLevelError = 2 -) - -const ( - alertCloseNotify alert = 0 - alertUnexpectedMessage alert = 10 - alertBadRecordMAC alert = 20 - alertDecryptionFailed alert = 21 - alertRecordOverflow alert = 22 - alertDecompressionFailure alert = 30 - alertHandshakeFailure alert = 40 - alertBadCertificate alert = 42 - alertUnsupportedCertificate alert = 43 - alertCertificateRevoked alert = 44 - alertCertificateExpired alert = 45 - alertCertificateUnknown alert = 46 - alertIllegalParameter alert = 47 - alertUnknownCA alert = 48 - alertAccessDenied alert = 49 - alertDecodeError alert = 50 - alertDecryptError alert = 51 - alertExportRestriction alert = 60 - alertProtocolVersion alert = 70 - alertInsufficientSecurity alert = 71 - alertInternalError alert = 80 - alertInappropriateFallback alert = 86 - alertUserCanceled alert = 90 - alertNoRenegotiation alert = 100 - alertMissingExtension alert = 109 - alertUnsupportedExtension alert = 110 - alertCertificateUnobtainable alert = 111 - alertUnrecognizedName alert = 112 - alertBadCertificateStatusResponse alert = 113 - alertBadCertificateHashValue alert = 114 - alertUnknownPSKIdentity alert = 115 - alertCertificateRequired alert = 116 - alertNoApplicationProtocol alert = 120 -) - -var alertText = map[alert]string{ - alertCloseNotify: "close notify", - alertUnexpectedMessage: "unexpected message", - alertBadRecordMAC: "bad record MAC", - alertDecryptionFailed: "decryption failed", - alertRecordOverflow: "record overflow", - alertDecompressionFailure: "decompression failure", - alertHandshakeFailure: "handshake failure", - alertBadCertificate: "bad certificate", - alertUnsupportedCertificate: "unsupported certificate", - alertCertificateRevoked: "revoked certificate", - alertCertificateExpired: "expired certificate", - alertCertificateUnknown: "unknown certificate", - alertIllegalParameter: "illegal parameter", - alertUnknownCA: "unknown certificate authority", - alertAccessDenied: "access denied", - alertDecodeError: "error decoding message", - alertDecryptError: "error decrypting message", - alertExportRestriction: "export restriction", - alertProtocolVersion: "protocol version not supported", - alertInsufficientSecurity: "insufficient security level", - alertInternalError: "internal error", - alertInappropriateFallback: "inappropriate fallback", - alertUserCanceled: "user canceled", - alertNoRenegotiation: "no renegotiation", - alertMissingExtension: "missing extension", - alertUnsupportedExtension: "unsupported extension", - alertCertificateUnobtainable: "certificate unobtainable", - alertUnrecognizedName: "unrecognized name", - alertBadCertificateStatusResponse: "bad certificate status response", - alertBadCertificateHashValue: "bad certificate hash value", - alertUnknownPSKIdentity: "unknown PSK identity", - alertCertificateRequired: "certificate required", - alertNoApplicationProtocol: "no application protocol", -} - -func (e alert) String() string { - s, ok := alertText[e] - if ok { - return "tls: " + s - } - return "tls: alert(" + strconv.Itoa(int(e)) + ")" -} - -func (e alert) Error() string { - return e.String() -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/cpu.go b/vendor/github.com/marten-seemann/qtls-go1-19/cpu.go deleted file mode 100644 index 12194508790..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/cpu.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build !js -// 
+build !js - -package qtls - -import ( - "runtime" - - "golang.org/x/sys/cpu" -) - -var ( - hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ - hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL - // Keep in sync with crypto/aes/cipher_s390x.go. - hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && - (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) - - hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 || - runtime.GOARCH == "arm64" && hasGCMAsmARM64 || - runtime.GOARCH == "s390x" && hasGCMAsmS390X -) diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/cpu_other.go b/vendor/github.com/marten-seemann/qtls-go1-19/cpu_other.go deleted file mode 100644 index 33f7d21942f..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/cpu_other.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build js -// +build js - -package qtls - -var ( - hasGCMAsmAMD64 = false - hasGCMAsmARM64 = false - hasGCMAsmS390X = false - - hasAESGCMHardwareSupport = false -) diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/key_schedule.go b/vendor/github.com/marten-seemann/qtls-go1-19/key_schedule.go deleted file mode 100644 index da13904a6e8..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/key_schedule.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto/elliptic" - "crypto/hmac" - "errors" - "hash" - "io" - "math/big" - - "golang.org/x/crypto/cryptobyte" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/hkdf" -) - -// This file contains the functions necessary to compute the TLS 1.3 key -// schedule. See RFC 8446, Section 7. - -const ( - resumptionBinderLabel = "res binder" - clientHandshakeTrafficLabel = "c hs traffic" - serverHandshakeTrafficLabel = "s hs traffic" - clientApplicationTrafficLabel = "c ap traffic" - serverApplicationTrafficLabel = "s ap traffic" - exporterLabel = "exp master" - resumptionLabel = "res master" - trafficUpdateLabel = "traffic upd" -) - -// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { - var hkdfLabel cryptobyte.Builder - hkdfLabel.AddUint16(uint16(length)) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte("tls13 ")) - b.AddBytes([]byte(label)) - }) - hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(context) - }) - out := make([]byte, length) - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) - if err != nil || n != length { - panic("tls: HKDF-Expand-Label invocation failed unexpectedly") - } - return out -} - -// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. -func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { - if transcript == nil { - transcript = c.hash.New() - } - return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) -} - -// extract implements HKDF-Extract with the cipher suite hash. -func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { - if newSecret == nil { - newSecret = make([]byte, c.hash.Size()) - } - return hkdf.Extract(c.hash.New, newSecret, currentSecret) -} - -// nextTrafficSecret generates the next traffic secret, given the current one, -// according to RFC 8446, Section 7.2. 
-func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { - return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) -} - -// trafficKey generates traffic keys according to RFC 8446, Section 7.3. -func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { - key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) - iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) - return -} - -// finishedHash generates the Finished verify_data or PskBinderEntry according -// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey -// selection. -func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte { - finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size()) - verifyData := hmac.New(c.hash.New, finishedKey) - verifyData.Write(transcript.Sum(nil)) - return verifyData.Sum(nil) -} - -// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to -// RFC 8446, Section 7.5. -func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { - expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) - return func(label string, context []byte, length int) ([]byte, error) { - secret := c.deriveSecret(expMasterSecret, label, nil) - h := c.hash.New() - h.Write(context) - return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil - } -} - -// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, -// according to RFC 8446, Section 4.2.8.2. -type ecdheParameters interface { - CurveID() CurveID - PublicKey() []byte - SharedKey(peerPublicKey []byte) []byte -} - -func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { - if curveID == X25519 { - privateKey := make([]byte, curve25519.ScalarSize) - if _, err := io.ReadFull(rand, privateKey); err != nil { - return nil, err - } - publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) - if err != nil { - return nil, err - } - return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil - } - - curve, ok := curveForCurveID(curveID) - if !ok { - return nil, errors.New("tls: internal error: unsupported curve") - } - - p := &nistParameters{curveID: curveID} - var err error - p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) - if err != nil { - return nil, err - } - return p, nil -} - -func curveForCurveID(id CurveID) (elliptic.Curve, bool) { - switch id { - case CurveP256: - return elliptic.P256(), true - case CurveP384: - return elliptic.P384(), true - case CurveP521: - return elliptic.P521(), true - default: - return nil, false - } -} - -type nistParameters struct { - privateKey []byte - x, y *big.Int // public key - curveID CurveID -} - -func (p *nistParameters) CurveID() CurveID { - return p.curveID -} - -func (p *nistParameters) PublicKey() []byte { - curve, _ := curveForCurveID(p.curveID) - return elliptic.Marshal(curve, p.x, p.y) -} - -func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { - curve, _ := curveForCurveID(p.curveID) - // Unmarshal also checks whether the given point is on the curve. 
- x, y := elliptic.Unmarshal(curve, peerPublicKey) - if x == nil { - return nil - } - - xShared, _ := curve.ScalarMult(x, y, p.privateKey) - sharedKey := make([]byte, (curve.Params().BitSize+7)/8) - return xShared.FillBytes(sharedKey) -} - -type x25519Parameters struct { - privateKey []byte - publicKey []byte -} - -func (p *x25519Parameters) CurveID() CurveID { - return X25519 -} - -func (p *x25519Parameters) PublicKey() []byte { - return p.publicKey[:] -} - -func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { - sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) - if err != nil { - return nil - } - return sharedKey -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/prf.go b/vendor/github.com/marten-seemann/qtls-go1-19/prf.go deleted file mode 100644 index 9eb0221a0c0..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/prf.go +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "crypto" - "crypto/hmac" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "errors" - "fmt" - "hash" -) - -// Split a premaster secret in two as specified in RFC 4346, Section 5. -func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { - s1 = secret[0 : (len(secret)+1)/2] - s2 = secret[len(secret)/2:] - return -} - -// pHash implements the P_hash function, as defined in RFC 4346, Section 5. -func pHash(result, secret, seed []byte, hash func() hash.Hash) { - h := hmac.New(hash, secret) - h.Write(seed) - a := h.Sum(nil) - - j := 0 - for j < len(result) { - h.Reset() - h.Write(a) - h.Write(seed) - b := h.Sum(nil) - copy(result[j:], b) - j += len(b) - - h.Reset() - h.Write(a) - a = h.Sum(nil) - } -} - -// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5. -func prf10(result, secret, label, seed []byte) { - hashSHA1 := sha1.New - hashMD5 := md5.New - - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) - - s1, s2 := splitPreMasterSecret(secret) - pHash(result, s1, labelAndSeed, hashMD5) - result2 := make([]byte, len(result)) - pHash(result2, s2, labelAndSeed, hashSHA1) - - for i, b := range result2 { - result[i] ^= b - } -} - -// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5. -func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { - return func(result, secret, label, seed []byte) { - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) - - pHash(result, secret, labelAndSeed, hashFunc) - } -} - -const ( - masterSecretLength = 48 // Length of a master secret in TLS 1.1. - finishedVerifyLength = 12 // Length of verify_data in a Finished message. 
-) - -var masterSecretLabel = []byte("master secret") -var keyExpansionLabel = []byte("key expansion") -var clientFinishedLabel = []byte("client finished") -var serverFinishedLabel = []byte("server finished") - -func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { - switch version { - case VersionTLS10, VersionTLS11: - return prf10, crypto.Hash(0) - case VersionTLS12: - if suite.flags&suiteSHA384 != 0 { - return prf12(sha512.New384), crypto.SHA384 - } - return prf12(sha256.New), crypto.SHA256 - default: - panic("unknown version") - } -} - -func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { - prf, _ := prfAndHashForVersion(version, suite) - return prf -} - -// masterFromPreMasterSecret generates the master secret from the pre-master -// secret. See RFC 5246, Section 8.1. -func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { - seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) - seed = append(seed, clientRandom...) - seed = append(seed, serverRandom...) - - masterSecret := make([]byte, masterSecretLength) - prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed) - return masterSecret -} - -// keysFromMasterSecret generates the connection keys from the master -// secret, given the lengths of the MAC key, cipher key and IV, as defined in -// RFC 2246, Section 6.3. -func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { - seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) - seed = append(seed, serverRandom...) - seed = append(seed, clientRandom...) - - n := 2*macLen + 2*keyLen + 2*ivLen - keyMaterial := make([]byte, n) - prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed) - clientMAC = keyMaterial[:macLen] - keyMaterial = keyMaterial[macLen:] - serverMAC = keyMaterial[:macLen] - keyMaterial = keyMaterial[macLen:] - clientKey = keyMaterial[:keyLen] - keyMaterial = keyMaterial[keyLen:] - serverKey = keyMaterial[:keyLen] - keyMaterial = keyMaterial[keyLen:] - clientIV = keyMaterial[:ivLen] - keyMaterial = keyMaterial[ivLen:] - serverIV = keyMaterial[:ivLen] - return -} - -func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { - var buffer []byte - if version >= VersionTLS12 { - buffer = []byte{} - } - - prf, hash := prfAndHashForVersion(version, cipherSuite) - if hash != 0 { - return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} - } - - return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} -} - -// A finishedHash calculates the hash of a set of handshake messages suitable -// for including in a Finished message. -type finishedHash struct { - client hash.Hash - server hash.Hash - - // Prior to TLS 1.2, an additional MD5 hash is required. - clientMD5 hash.Hash - serverMD5 hash.Hash - - // In TLS 1.2, a full buffer is sadly required. - buffer []byte - - version uint16 - prf func(result, secret, label, seed []byte) -} - -func (h *finishedHash) Write(msg []byte) (n int, err error) { - h.client.Write(msg) - h.server.Write(msg) - - if h.version < VersionTLS12 { - h.clientMD5.Write(msg) - h.serverMD5.Write(msg) - } - - if h.buffer != nil { - h.buffer = append(h.buffer, msg...) 
- } - - return len(msg), nil -} - -func (h finishedHash) Sum() []byte { - if h.version >= VersionTLS12 { - return h.client.Sum(nil) - } - - out := make([]byte, 0, md5.Size+sha1.Size) - out = h.clientMD5.Sum(out) - return h.client.Sum(out) -} - -// clientSum returns the contents of the verify_data member of a client's -// Finished message. -func (h finishedHash) clientSum(masterSecret []byte) []byte { - out := make([]byte, finishedVerifyLength) - h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) - return out -} - -// serverSum returns the contents of the verify_data member of a server's -// Finished message. -func (h finishedHash) serverSum(masterSecret []byte) []byte { - out := make([]byte, finishedVerifyLength) - h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) - return out -} - -// hashForClientCertificate returns the handshake messages so far, pre-hashed if -// necessary, suitable for signing by a TLS client certificate. -func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { - if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { - panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") - } - - if sigType == signatureEd25519 { - return h.buffer - } - - if h.version >= VersionTLS12 { - hash := hashAlg.New() - hash.Write(h.buffer) - return hash.Sum(nil) - } - - if sigType == signatureECDSA { - return h.server.Sum(nil) - } - - return h.Sum() -} - -// discardHandshakeBuffer is called when there is no more need to -// buffer the entirety of the handshake messages. -func (h *finishedHash) discardHandshakeBuffer() { - h.buffer = nil -} - -// noExportedKeyingMaterial is used as a value of -// ConnectionState.ekm when renegotiation is enabled and thus -// we wish to fail all key-material export requests. -func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) { - return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") -} - -// ekmFromMasterSecret generates exported keying material as defined in RFC 5705. -func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { - return func(label string, context []byte, length int) ([]byte, error) { - switch label { - case "client finished", "server finished", "master secret", "key expansion": - // These values are reserved and may not be used. - return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label) - } - - seedLen := len(serverRandom) + len(clientRandom) - if context != nil { - seedLen += 2 + len(context) - } - seed := make([]byte, 0, seedLen) - - seed = append(seed, clientRandom...) - seed = append(seed, serverRandom...) - - if context != nil { - if len(context) >= 1<<16 { - return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long") - } - seed = append(seed, byte(len(context)>>8), byte(len(context))) - seed = append(seed, context...) - } - - keyMaterial := make([]byte, length) - prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed) - return keyMaterial, nil - } -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/ticket.go b/vendor/github.com/marten-seemann/qtls-go1-19/ticket.go deleted file mode 100644 index 81e8a52eac6..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/ticket.go +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright 2012 The Go Authors. 
All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package qtls - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/sha256" - "crypto/subtle" - "encoding/binary" - "errors" - "io" - "time" - - "golang.org/x/crypto/cryptobyte" -) - -// sessionState contains the information that is serialized into a session -// ticket in order to later resume a connection. -type sessionState struct { - vers uint16 - cipherSuite uint16 - createdAt uint64 - masterSecret []byte // opaque master_secret<1..2^16-1>; - // struct { opaque certificate<1..2^24-1> } Certificate; - certificates [][]byte // Certificate certificate_list<0..2^24-1>; - - // usedOldKey is true if the ticket from which this session came from - // was encrypted with an older key and thus should be refreshed. - usedOldKey bool -} - -func (m *sessionState) marshal() []byte { - var b cryptobyte.Builder - b.AddUint16(m.vers) - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.masterSecret) - }) - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - for _, cert := range m.certificates { - b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(cert) - }) - } - }) - return b.BytesOrPanic() -} - -func (m *sessionState) unmarshal(data []byte) bool { - *m = sessionState{usedOldKey: m.usedOldKey} - s := cryptobyte.String(data) - if ok := s.ReadUint16(&m.vers) && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint16LengthPrefixed(&s, &m.masterSecret) && - len(m.masterSecret) != 0; !ok { - return false - } - var certList cryptobyte.String - if !s.ReadUint24LengthPrefixed(&certList) { - return false - } - for !certList.Empty() { - var cert []byte - if !readUint24LengthPrefixed(&certList, &cert) { - return false - } - m.certificates = append(m.certificates, cert) - } - return s.Empty() -} - -// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first -// version (revision = 0) doesn't carry any of the information needed for 0-RTT -// validation and the nonce is always empty. -// version (revision = 1) carries the max_early_data_size sent in the ticket. -// version (revision = 2) carries the ALPN sent in the ticket. 
-type sessionStateTLS13 struct { - // uint8 version = 0x0304; - // uint8 revision = 2; - cipherSuite uint16 - createdAt uint64 - resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>; - certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; - maxEarlyData uint32 - alpn string - - appData []byte -} - -func (m *sessionStateTLS13) marshal() []byte { - var b cryptobyte.Builder - b.AddUint16(VersionTLS13) - b.AddUint8(2) // revision - b.AddUint16(m.cipherSuite) - addUint64(&b, m.createdAt) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.resumptionSecret) - }) - marshalCertificate(&b, m.certificate) - b.AddUint32(m.maxEarlyData) - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpn)) - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.appData) - }) - return b.BytesOrPanic() -} - -func (m *sessionStateTLS13) unmarshal(data []byte) bool { - *m = sessionStateTLS13{} - s := cryptobyte.String(data) - var version uint16 - var revision uint8 - var alpn []byte - ret := s.ReadUint16(&version) && - version == VersionTLS13 && - s.ReadUint8(&revision) && - revision == 2 && - s.ReadUint16(&m.cipherSuite) && - readUint64(&s, &m.createdAt) && - readUint8LengthPrefixed(&s, &m.resumptionSecret) && - len(m.resumptionSecret) != 0 && - unmarshalCertificate(&s, &m.certificate) && - s.ReadUint32(&m.maxEarlyData) && - readUint8LengthPrefixed(&s, &alpn) && - readUint16LengthPrefixed(&s, &m.appData) && - s.Empty() - m.alpn = string(alpn) - return ret -} - -func (c *Conn) encryptTicket(state []byte) ([]byte, error) { - if len(c.ticketKeys) == 0 { - return nil, errors.New("tls: internal error: session ticket keys unavailable") - } - - encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size) - keyName := encrypted[:ticketKeyNameLen] - iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] - macBytes := encrypted[len(encrypted)-sha256.Size:] - - if _, err := io.ReadFull(c.config.rand(), iv); err != nil { - return nil, err - } - key := c.ticketKeys[0] - copy(keyName, key.keyName[:]) - block, err := aes.NewCipher(key.aesKey[:]) - if err != nil { - return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) - } - cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state) - - mac := hmac.New(sha256.New, key.hmacKey[:]) - mac.Write(encrypted[:len(encrypted)-sha256.Size]) - mac.Sum(macBytes[:0]) - - return encrypted, nil -} - -func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) { - if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { - return nil, false - } - - keyName := encrypted[:ticketKeyNameLen] - iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] - macBytes := encrypted[len(encrypted)-sha256.Size:] - ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] - - keyIndex := -1 - for i, candidateKey := range c.ticketKeys { - if bytes.Equal(keyName, candidateKey.keyName[:]) { - keyIndex = i - break - } - } - if keyIndex == -1 { - return nil, false - } - key := &c.ticketKeys[keyIndex] - - mac := hmac.New(sha256.New, key.hmacKey[:]) - mac.Write(encrypted[:len(encrypted)-sha256.Size]) - expected := mac.Sum(nil) - - if subtle.ConstantTimeCompare(macBytes, expected) != 1 { - return nil, false - } - - block, err := aes.NewCipher(key.aesKey[:]) - if err != nil { - return nil, false - } - plaintext = make([]byte, len(ciphertext)) - 
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) - - return plaintext, keyIndex > 0 -} - -func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) { - m := new(newSessionTicketMsgTLS13) - - var certsFromClient [][]byte - for _, cert := range c.peerCertificates { - certsFromClient = append(certsFromClient, cert.Raw) - } - state := sessionStateTLS13{ - cipherSuite: c.cipherSuite, - createdAt: uint64(c.config.time().Unix()), - resumptionSecret: c.resumptionSecret, - certificate: Certificate{ - Certificate: certsFromClient, - OCSPStaple: c.ocspResponse, - SignedCertificateTimestamps: c.scts, - }, - appData: appData, - alpn: c.clientProtocol, - } - if c.extraConfig != nil { - state.maxEarlyData = c.extraConfig.MaxEarlyData - } - var err error - m.label, err = c.encryptTicket(state.marshal()) - if err != nil { - return nil, err - } - m.lifetime = uint32(maxSessionTicketLifetime / time.Second) - - // ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1 - // The value is not stored anywhere; we never need to check the ticket age - // because 0-RTT is not supported. - ageAdd := make([]byte, 4) - _, err = c.config.rand().Read(ageAdd) - if err != nil { - return nil, err - } - m.ageAdd = binary.LittleEndian.Uint32(ageAdd) - - // ticket_nonce, which must be unique per connection, is always left at - // zero because we only ever send one ticket per connection. - - if c.extraConfig != nil { - m.maxEarlyData = c.extraConfig.MaxEarlyData - } - return m, nil -} - -// GetSessionTicket generates a new session ticket. -// It should only be called after the handshake completes. -// It can only be used for servers, and only if the alternative record layer is set. -// The ticket may be nil if config.SessionTicketsDisabled is set, -// or if the client isn't able to receive session tickets. -func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) { - if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { - return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.") - } - if c.config.SessionTicketsDisabled { - return nil, nil - } - - m, err := c.getSessionTicketMsg(appData) - if err != nil { - return nil, err - } - return m.marshal(), nil -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/tls.go b/vendor/github.com/marten-seemann/qtls-go1-19/tls.go deleted file mode 100644 index 42207c235fd..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/tls.go +++ /dev/null @@ -1,362 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// package qtls partially implements TLS 1.2, as specified in RFC 5246, -// and TLS 1.3, as specified in RFC 8446. -package qtls - -// BUG(agl): The crypto/tls package only implements some countermeasures -// against Lucky13 attacks on CBC-mode encryption, and only on SHA1 -// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and -// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. - -import ( - "bytes" - "context" - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" - "net" - "os" - "strings" -) - -// Server returns a new TLS server side connection -// using conn as the underlying transport. 
-// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - c := &Conn{ - conn: conn, - config: fromConfig(config), - extraConfig: extraConfig, - } - c.handshakeFn = c.serverHandshake - return c -} - -// Client returns a new TLS client side connection -// using conn as the underlying transport. -// The config cannot be nil: users must set either ServerName or -// InsecureSkipVerify in the config. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - c := &Conn{ - conn: conn, - config: fromConfig(config), - extraConfig: extraConfig, - isClient: true, - } - c.handshakeFn = c.clientHandshake - return c -} - -// A listener implements a network listener (net.Listener) for TLS connections. -type listener struct { - net.Listener - config *Config - extraConfig *ExtraConfig -} - -// Accept waits for and returns the next incoming TLS connection. -// The returned connection is of type *Conn. -func (l *listener) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - return Server(c, l.config, l.extraConfig), nil -} - -// NewListener creates a Listener which accepts connections from an inner -// Listener and wraps each connection with Server. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener { - l := new(listener) - l.Listener = inner - l.config = config - l.extraConfig = extraConfig - return l -} - -// Listen creates a TLS listener accepting connections on the -// given network address using net.Listen. -// The configuration config must be non-nil and must include -// at least one certificate or else set GetCertificate. -func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) { - if config == nil || len(config.Certificates) == 0 && - config.GetCertificate == nil && config.GetConfigForClient == nil { - return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") - } - l, err := net.Listen(network, laddr) - if err != nil { - return nil, err - } - return NewListener(l, config, extraConfig), nil -} - -type timeoutError struct{} - -func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } -func (timeoutError) Timeout() bool { return true } -func (timeoutError) Temporary() bool { return true } - -// DialWithDialer connects to the given network address using dialer.Dial and -// then initiates a TLS handshake, returning the resulting TLS connection. Any -// timeout or deadline given in the dialer apply to connection and TLS -// handshake as a whole. -// -// DialWithDialer interprets a nil configuration as equivalent to the zero -// configuration; see the documentation of Config for the defaults. -// -// DialWithDialer uses context.Background internally; to specify the context, -// use Dialer.DialContext with NetDialer set to the desired dialer. 
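
Editor's aside, not part of the patch: the dial helpers being removed here apply a net.Dialer's Timeout/Deadline to the TCP connect and the TLS handshake as one unit by funnelling both through a context. A minimal sketch of that same behaviour using plain crypto/tls (address, timeout and ServerName below are illustrative, not values from this repository):

```Go
package main

import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"time"
)

func main() {
	d := &net.Dialer{Timeout: 5 * time.Second}
	// Derive one context that bounds both the TCP dial and the TLS handshake.
	ctx, cancel := context.WithTimeout(context.Background(), d.Timeout)
	defer cancel()

	raw, err := d.DialContext(ctx, "tcp", "example.com:443")
	if err != nil {
		panic(err)
	}
	// The deleted helper infers ServerName from the address; here it is set explicitly.
	conn := tls.Client(raw, &tls.Config{ServerName: "example.com"})
	if err := conn.HandshakeContext(ctx); err != nil {
		raw.Close()
		panic(err)
	}
	defer conn.Close()
	fmt.Printf("negotiated TLS version 0x%x\n", conn.ConnectionState().Version)
}
```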
-func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { - return dial(context.Background(), dialer, network, addr, config, extraConfig) -} - -func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { - if netDialer.Timeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) - defer cancel() - } - - if !netDialer.Deadline.IsZero() { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) - defer cancel() - } - - rawConn, err := netDialer.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - - colonPos := strings.LastIndex(addr, ":") - if colonPos == -1 { - colonPos = len(addr) - } - hostname := addr[:colonPos] - - if config == nil { - config = defaultConfig() - } - // If no ServerName is set, infer the ServerName - // from the hostname we're connecting to. - if config.ServerName == "" { - // Make a copy to avoid polluting argument or default. - c := config.Clone() - c.ServerName = hostname - config = c - } - - conn := Client(rawConn, config, extraConfig) - if err := conn.HandshakeContext(ctx); err != nil { - rawConn.Close() - return nil, err - } - return conn, nil -} - -// Dial connects to the given network address using net.Dial -// and then initiates a TLS handshake, returning the resulting -// TLS connection. -// Dial interprets a nil configuration as equivalent to -// the zero configuration; see the documentation of Config -// for the defaults. -func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) { - return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig) -} - -// Dialer dials TLS connections given a configuration and a Dialer for the -// underlying connection. -type Dialer struct { - // NetDialer is the optional dialer to use for the TLS connections' - // underlying TCP connections. - // A nil NetDialer is equivalent to the net.Dialer zero value. - NetDialer *net.Dialer - - // Config is the TLS configuration to use for new connections. - // A nil configuration is equivalent to the zero - // configuration; see the documentation of Config for the - // defaults. - Config *Config - - ExtraConfig *ExtraConfig -} - -// Dial connects to the given network address and initiates a TLS -// handshake, returning the resulting TLS connection. -// -// The returned Conn, if any, will always be of type *Conn. -// -// Dial uses context.Background internally; to specify the context, -// use DialContext. -func (d *Dialer) Dial(network, addr string) (net.Conn, error) { - return d.DialContext(context.Background(), network, addr) -} - -func (d *Dialer) netDialer() *net.Dialer { - if d.NetDialer != nil { - return d.NetDialer - } - return new(net.Dialer) -} - -// DialContext connects to the given network address and initiates a TLS -// handshake, returning the resulting TLS connection. -// -// The provided Context must be non-nil. If the context expires before -// the connection is complete, an error is returned. Once successfully -// connected, any expiration of the context will not affect the -// connection. -// -// The returned Conn, if any, will always be of type *Conn. -func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig) - if err != nil { - // Don't return c (a typed nil) in an interface. 
- return nil, err - } - return c, nil -} - -// LoadX509KeyPair reads and parses a public/private key pair from a pair -// of files. The files must contain PEM encoded data. The certificate file -// may contain intermediate certificates following the leaf certificate to -// form a certificate chain. On successful return, Certificate.Leaf will -// be nil because the parsed form of the certificate is not retained. -func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { - certPEMBlock, err := os.ReadFile(certFile) - if err != nil { - return Certificate{}, err - } - keyPEMBlock, err := os.ReadFile(keyFile) - if err != nil { - return Certificate{}, err - } - return X509KeyPair(certPEMBlock, keyPEMBlock) -} - -// X509KeyPair parses a public/private key pair from a pair of -// PEM encoded data. On successful return, Certificate.Leaf will be nil because -// the parsed form of the certificate is not retained. -func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { - fail := func(err error) (Certificate, error) { return Certificate{}, err } - - var cert Certificate - var skippedBlockTypes []string - for { - var certDERBlock *pem.Block - certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) - if certDERBlock == nil { - break - } - if certDERBlock.Type == "CERTIFICATE" { - cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) - } else { - skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) - } - } - - if len(cert.Certificate) == 0 { - if len(skippedBlockTypes) == 0 { - return fail(errors.New("tls: failed to find any PEM data in certificate input")) - } - if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { - return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) - } - return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) - } - - skippedBlockTypes = skippedBlockTypes[:0] - var keyDERBlock *pem.Block - for { - keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) - if keyDERBlock == nil { - if len(skippedBlockTypes) == 0 { - return fail(errors.New("tls: failed to find any PEM data in key input")) - } - if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { - return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) - } - return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) - } - if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { - break - } - skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) - } - - // We don't need to parse the public key for TLS, but we so do anyway - // to check that it looks sane and matches the private key. 
- x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - return fail(err) - } - - cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) - if err != nil { - return fail(err) - } - - switch pub := x509Cert.PublicKey.(type) { - case *rsa.PublicKey: - priv, ok := cert.PrivateKey.(*rsa.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if pub.N.Cmp(priv.N) != 0 { - return fail(errors.New("tls: private key does not match public key")) - } - case *ecdsa.PublicKey: - priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { - return fail(errors.New("tls: private key does not match public key")) - } - case ed25519.PublicKey: - priv, ok := cert.PrivateKey.(ed25519.PrivateKey) - if !ok { - return fail(errors.New("tls: private key type does not match public key type")) - } - if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { - return fail(errors.New("tls: private key does not match public key")) - } - default: - return fail(errors.New("tls: unknown public key algorithm")) - } - - return cert, nil -} - -// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates -// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. -// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. -func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { - if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { - return key, nil - } - if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { - switch key := key.(type) { - case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: - return key, nil - default: - return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") - } - } - if key, err := x509.ParseECPrivateKey(der); err == nil { - return key, nil - } - - return nil, errors.New("tls: failed to parse private key") -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/unsafe.go b/vendor/github.com/marten-seemann/qtls-go1-19/unsafe.go deleted file mode 100644 index 55fa01b3d61..00000000000 --- a/vendor/github.com/marten-seemann/qtls-go1-19/unsafe.go +++ /dev/null @@ -1,96 +0,0 @@ -package qtls - -import ( - "crypto/tls" - "reflect" - "unsafe" -) - -func init() { - if !structsEqual(&tls.ConnectionState{}, &connectionState{}) { - panic("qtls.ConnectionState doesn't match") - } - if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) { - panic("qtls.ClientSessionState doesn't match") - } - if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) { - panic("qtls.CertificateRequestInfo doesn't match") - } - if !structsEqual(&tls.Config{}, &config{}) { - panic("qtls.Config doesn't match") - } - if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) { - panic("qtls.ClientHelloInfo doesn't match") - } -} - -func toConnectionState(c connectionState) ConnectionState { - return *(*ConnectionState)(unsafe.Pointer(&c)) -} - -func toClientSessionState(s *clientSessionState) *ClientSessionState { - return (*ClientSessionState)(unsafe.Pointer(s)) -} - -func fromClientSessionState(s *ClientSessionState) *clientSessionState { - return (*clientSessionState)(unsafe.Pointer(s)) -} - -func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo { - return (*CertificateRequestInfo)(unsafe.Pointer(i)) -} - -func toConfig(c *config) *Config { - return 
(*Config)(unsafe.Pointer(c)) -} - -func fromConfig(c *Config) *config { - return (*config)(unsafe.Pointer(c)) -} - -func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo { - return (*ClientHelloInfo)(unsafe.Pointer(chi)) -} - -func structsEqual(a, b interface{}) bool { - return compare(reflect.ValueOf(a), reflect.ValueOf(b)) -} - -func compare(a, b reflect.Value) bool { - sa := a.Elem() - sb := b.Elem() - if sa.NumField() != sb.NumField() { - return false - } - for i := 0; i < sa.NumField(); i++ { - fa := sa.Type().Field(i) - fb := sb.Type().Field(i) - if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { - if fa.Type.Kind() != fb.Type.Kind() { - return false - } - if fa.Type.Kind() == reflect.Slice { - if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) { - return false - } - continue - } - return false - } - } - return true -} - -func compareStruct(a, b reflect.Type) bool { - if a.NumField() != b.NumField() { - return false - } - for i := 0; i < a.NumField(); i++ { - fa := a.Field(i) - fb := b.Field(i) - if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) { - return false - } - } - return true -} diff --git a/vendor/github.com/nxadm/tail/.gitignore b/vendor/github.com/nxadm/tail/.gitignore deleted file mode 100644 index 35d9351d385..00000000000 --- a/vendor/github.com/nxadm/tail/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -.idea/ -.test/ -examples/_* \ No newline at end of file diff --git a/vendor/github.com/nxadm/tail/CHANGES.md b/vendor/github.com/nxadm/tail/CHANGES.md deleted file mode 100644 index 224e54b44de..00000000000 --- a/vendor/github.com/nxadm/tail/CHANGES.md +++ /dev/null @@ -1,56 +0,0 @@ -# Version v1.4.7-v1.4.8 -* Documentation updates. -* Small linter cleanups. -* Added example in test. - -# Version v1.4.6 - -* Document the usage of Cleanup when re-reading a file (thanks to @lesovsky) for issue #18. -* Add example directories with example and tests for issues. - -# Version v1.4.4-v1.4.5 - -* Fix of checksum problem because of forced tag. No changes to the code. - -# Version v1.4.1 - -* Incorporated PR 162 by by Mohammed902: "Simplify non-Windows build tag". - -# Version v1.4.0 - -* Incorporated PR 9 by mschneider82: "Added seekinfo to Tail". - -# Version v1.3.1 - -* Incorporated PR 7: "Fix deadlock when stopping on non-empty file/buffer", -fixes upstream issue 93. - - -# Version v1.3.0 - -* Incorporated changes of unmerged upstream PR 149 by mezzi: "added line num -to Line struct". - -# Version v1.2.1 - -* Incorporated changes of unmerged upstream PR 128 by jadekler: "Compile-able -code in readme". -* Incorporated changes of unmerged upstream PR 130 by fgeller: "small change -to comment wording". -* Incorporated changes of unmerged upstream PR 133 by sm3142: "removed -spurious newlines from log messages". - -# Version v1.2.0 - -* Incorporated changes of unmerged upstream PR 126 by Code-Hex: "Solved the - problem for never return the last line if it's not followed by a newline". -* Incorporated changes of unmerged upstream PR 131 by StoicPerlman: "Remove -deprecated os.SEEK consts". The changes bumped the minimal supported Go -release to 1.9. - -# Version v1.1.0 - -* migration to go modules. -* release of master branch of the dormant upstream, because it contains -fixes and improvement no present in the tagged release. 
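
Editor's aside, not part of the patch: the unsafe.go file removed a few hunks above pairs every unsafe.Pointer cast between crypto/tls types and their qtls mirrors with an init-time, reflection-based layout comparison. A self-contained sketch of that check-then-cast pattern, using two hypothetical mirrored structs rather than the real tls.Config:

```Go
package main

import (
	"fmt"
	"reflect"
	"unsafe"
)

type external struct { // stands in for an exported type such as tls.Config
	Name string
	N    int
}

type internal struct { // stands in for the package-private mirror type
	Name string
	N    int
}

// structsEqual reports whether both structs have identical field names,
// offsets and types; that equality is what makes the cast in main sound.
func structsEqual(a, b interface{}) bool {
	ta, tb := reflect.TypeOf(a).Elem(), reflect.TypeOf(b).Elem()
	if ta.NumField() != tb.NumField() {
		return false
	}
	for i := 0; i < ta.NumField(); i++ {
		fa, fb := ta.Field(i), tb.Field(i)
		if fa.Name != fb.Name || fa.Offset != fb.Offset || fa.Type != fb.Type {
			return false
		}
	}
	return true
}

func main() {
	if !structsEqual(&external{}, &internal{}) {
		panic("struct layouts diverged")
	}
	ext := &external{Name: "example", N: 42}
	in := (*internal)(unsafe.Pointer(ext)) // sound only because of the check above
	fmt.Println(in.Name, in.N)
}
```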
- diff --git a/vendor/github.com/nxadm/tail/Dockerfile b/vendor/github.com/nxadm/tail/Dockerfile deleted file mode 100644 index d9633891c00..00000000000 --- a/vendor/github.com/nxadm/tail/Dockerfile +++ /dev/null @@ -1,19 +0,0 @@ -FROM golang - -RUN mkdir -p $GOPATH/src/github.com/nxadm/tail/ -ADD . $GOPATH/src/github.com/nxadm/tail/ - -# expecting to fetch dependencies successfully. -RUN go get -v github.com/nxadm/tail - -# expecting to run the test successfully. -RUN go test -v github.com/nxadm/tail - -# expecting to install successfully -RUN go install -v github.com/nxadm/tail -RUN go install -v github.com/nxadm/tail/cmd/gotail - -RUN $GOPATH/bin/gotail -h || true - -ENV PATH $GOPATH/bin:$PATH -CMD ["gotail"] diff --git a/vendor/github.com/nxadm/tail/LICENSE b/vendor/github.com/nxadm/tail/LICENSE deleted file mode 100644 index 818d802a59a..00000000000 --- a/vendor/github.com/nxadm/tail/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -# The MIT License (MIT) - -# © Copyright 2015 Hewlett Packard Enterprise Development LP -Copyright (c) 2014 ActiveState - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/github.com/nxadm/tail/README.md b/vendor/github.com/nxadm/tail/README.md deleted file mode 100644 index f47939c749d..00000000000 --- a/vendor/github.com/nxadm/tail/README.md +++ /dev/null @@ -1,44 +0,0 @@ -![ci](https://github.com/nxadm/tail/workflows/ci/badge.svg)[![Go Reference](https://pkg.go.dev/badge/github.com/nxadm/tail.svg)](https://pkg.go.dev/github.com/nxadm/tail) - -# tail functionality in Go - -nxadm/tail provides a Go library that emulates the features of the BSD `tail` -program. The library comes with full support for truncation/move detection as -it is designed to work with log rotation tools. The library works on all -operating systems supported by Go, including POSIX systems like Linux and -*BSD, and MS Windows. Go 1.9 is the oldest compiler release supported. - -A simple example: - -```Go -// Create a tail -t, err := tail.TailFile( - "/var/log/nginx.log", tail.Config{Follow: true, ReOpen: true}) -if err != nil { - panic(err) -} - -// Print the text of each received line -for line := range t.Lines { - fmt.Println(line.Text) -} -``` - -See [API documentation](https://pkg.go.dev/github.com/nxadm/tail). - -## Installing - - go get github.com/nxadm/tail/... - -## History - -This project is an active, drop-in replacement for the -[abandoned](https://en.wikipedia.org/wiki/HPE_Helion) Go tail library at -[hpcloud](https://github.com/hpcloud/tail). 
Next to -[addressing open issues/PRs of the original project](https://github.com/nxadm/tail/issues/6), -nxadm/tail continues the development by keeping up to date with the Go toolchain -(e.g. go modules) and dependencies, completing the documentation, adding features -and fixing bugs. - -## Examples -Examples, e.g. used to debug an issue, are kept in the [examples directory](/examples). \ No newline at end of file diff --git a/vendor/github.com/nxadm/tail/ratelimiter/Licence b/vendor/github.com/nxadm/tail/ratelimiter/Licence deleted file mode 100644 index 434aab19f1a..00000000000 --- a/vendor/github.com/nxadm/tail/ratelimiter/Licence +++ /dev/null @@ -1,7 +0,0 @@ -Copyright (C) 2013 99designs - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/nxadm/tail/ratelimiter/leakybucket.go b/vendor/github.com/nxadm/tail/ratelimiter/leakybucket.go deleted file mode 100644 index 358b69e7f5c..00000000000 --- a/vendor/github.com/nxadm/tail/ratelimiter/leakybucket.go +++ /dev/null @@ -1,97 +0,0 @@ -// Package ratelimiter implements the Leaky Bucket ratelimiting algorithm with memcached and in-memory backends. 
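
Editor's aside, not part of the patch: the ratelimiter package removed below implements a leaky bucket whose fill drains continuously at one unit per LeakInterval, and tail uses it to throttle line delivery. A compact stand-alone sketch of the same accounting (type, field names and constants here are illustrative, not the deleted API):

```Go
package main

import (
	"fmt"
	"time"
)

// leakyBucket is a minimal stand-in mirroring the drain-then-fill behaviour
// of the deleted ratelimiter.LeakyBucket.
type leakyBucket struct {
	size         float64       // capacity in units
	fill         float64       // current fill level
	leakInterval time.Duration // time for one unit to drain
	last         time.Time     // last accounting update
}

// pour drains the bucket for the elapsed time, then tries to add n units.
// It returns false when the bucket would overflow, signalling back-off.
func (b *leakyBucket) pour(n float64) bool {
	now := time.Now()
	b.fill -= float64(now.Sub(b.last)) / float64(b.leakInterval)
	if b.fill < 0 {
		b.fill = 0
	}
	b.last = now
	if b.fill+n > b.size {
		return false
	}
	b.fill += n
	return true
}

func main() {
	b := &leakyBucket{size: 3, leakInterval: time.Second, last: time.Now()}
	for i := 0; i < 5; i++ {
		fmt.Println("accepted:", b.pour(1)) // the last pours overflow and are rejected
	}
}
```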
-package ratelimiter - -import ( - "time" -) - -type LeakyBucket struct { - Size uint16 - Fill float64 - LeakInterval time.Duration // time.Duration for 1 unit of size to leak - Lastupdate time.Time - Now func() time.Time -} - -func NewLeakyBucket(size uint16, leakInterval time.Duration) *LeakyBucket { - bucket := LeakyBucket{ - Size: size, - Fill: 0, - LeakInterval: leakInterval, - Now: time.Now, - Lastupdate: time.Now(), - } - - return &bucket -} - -func (b *LeakyBucket) updateFill() { - now := b.Now() - if b.Fill > 0 { - elapsed := now.Sub(b.Lastupdate) - - b.Fill -= float64(elapsed) / float64(b.LeakInterval) - if b.Fill < 0 { - b.Fill = 0 - } - } - b.Lastupdate = now -} - -func (b *LeakyBucket) Pour(amount uint16) bool { - b.updateFill() - - var newfill float64 = b.Fill + float64(amount) - - if newfill > float64(b.Size) { - return false - } - - b.Fill = newfill - - return true -} - -// The time at which this bucket will be completely drained -func (b *LeakyBucket) DrainedAt() time.Time { - return b.Lastupdate.Add(time.Duration(b.Fill * float64(b.LeakInterval))) -} - -// The duration until this bucket is completely drained -func (b *LeakyBucket) TimeToDrain() time.Duration { - return b.DrainedAt().Sub(b.Now()) -} - -func (b *LeakyBucket) TimeSinceLastUpdate() time.Duration { - return b.Now().Sub(b.Lastupdate) -} - -type LeakyBucketSer struct { - Size uint16 - Fill float64 - LeakInterval time.Duration // time.Duration for 1 unit of size to leak - Lastupdate time.Time -} - -func (b *LeakyBucket) Serialise() *LeakyBucketSer { - bucket := LeakyBucketSer{ - Size: b.Size, - Fill: b.Fill, - LeakInterval: b.LeakInterval, - Lastupdate: b.Lastupdate, - } - - return &bucket -} - -func (b *LeakyBucketSer) DeSerialise() *LeakyBucket { - bucket := LeakyBucket{ - Size: b.Size, - Fill: b.Fill, - LeakInterval: b.LeakInterval, - Lastupdate: b.Lastupdate, - Now: time.Now, - } - - return &bucket -} diff --git a/vendor/github.com/nxadm/tail/ratelimiter/memory.go b/vendor/github.com/nxadm/tail/ratelimiter/memory.go deleted file mode 100644 index bf3c2131b1e..00000000000 --- a/vendor/github.com/nxadm/tail/ratelimiter/memory.go +++ /dev/null @@ -1,60 +0,0 @@ -package ratelimiter - -import ( - "errors" - "time" -) - -const ( - GC_SIZE int = 100 - GC_PERIOD time.Duration = 60 * time.Second -) - -type Memory struct { - store map[string]LeakyBucket - lastGCCollected time.Time -} - -func NewMemory() *Memory { - m := new(Memory) - m.store = make(map[string]LeakyBucket) - m.lastGCCollected = time.Now() - return m -} - -func (m *Memory) GetBucketFor(key string) (*LeakyBucket, error) { - - bucket, ok := m.store[key] - if !ok { - return nil, errors.New("miss") - } - - return &bucket, nil -} - -func (m *Memory) SetBucketFor(key string, bucket LeakyBucket) error { - - if len(m.store) > GC_SIZE { - m.GarbageCollect() - } - - m.store[key] = bucket - - return nil -} - -func (m *Memory) GarbageCollect() { - now := time.Now() - - // rate limit GC to once per minute - if now.Unix() >= m.lastGCCollected.Add(GC_PERIOD).Unix() { - for key, bucket := range m.store { - // if the bucket is drained, then GC - if bucket.DrainedAt().Unix() < now.Unix() { - delete(m.store, key) - } - } - - m.lastGCCollected = now - } -} diff --git a/vendor/github.com/nxadm/tail/ratelimiter/storage.go b/vendor/github.com/nxadm/tail/ratelimiter/storage.go deleted file mode 100644 index 89b2fe8826e..00000000000 --- a/vendor/github.com/nxadm/tail/ratelimiter/storage.go +++ /dev/null @@ -1,6 +0,0 @@ -package ratelimiter - -type Storage interface { - 
GetBucketFor(string) (*LeakyBucket, error) - SetBucketFor(string, LeakyBucket) error -} diff --git a/vendor/github.com/nxadm/tail/tail.go b/vendor/github.com/nxadm/tail/tail.go deleted file mode 100644 index 37ea4411e9d..00000000000 --- a/vendor/github.com/nxadm/tail/tail.go +++ /dev/null @@ -1,455 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -// Copyright (c) 2015 HPE Software Inc. All rights reserved. -// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. - -//nxadm/tail provides a Go library that emulates the features of the BSD `tail` -//program. The library comes with full support for truncation/move detection as -//it is designed to work with log rotation tools. The library works on all -//operating systems supported by Go, including POSIX systems like Linux and -//*BSD, and MS Windows. Go 1.9 is the oldest compiler release supported. -package tail - -import ( - "bufio" - "errors" - "fmt" - "io" - "io/ioutil" - "log" - "os" - "strings" - "sync" - "time" - - "github.com/nxadm/tail/ratelimiter" - "github.com/nxadm/tail/util" - "github.com/nxadm/tail/watch" - "gopkg.in/tomb.v1" -) - -var ( - // ErrStop is returned when the tail of a file has been marked to be stopped. - ErrStop = errors.New("tail should now stop") -) - -type Line struct { - Text string // The contents of the file - Num int // The line number - SeekInfo SeekInfo // SeekInfo - Time time.Time // Present time - Err error // Error from tail -} - -// Deprecated: this function is no longer used internally and it has little of no -// use in the API. As such, it will be removed from the API in a future major -// release. -// -// NewLine returns a * pointer to a Line struct. -func NewLine(text string, lineNum int) *Line { - return &Line{text, lineNum, SeekInfo{}, time.Now(), nil} -} - -// SeekInfo represents arguments to io.Seek. See: https://golang.org/pkg/io/#SectionReader.Seek -type SeekInfo struct { - Offset int64 - Whence int -} - -type logger interface { - Fatal(v ...interface{}) - Fatalf(format string, v ...interface{}) - Fatalln(v ...interface{}) - Panic(v ...interface{}) - Panicf(format string, v ...interface{}) - Panicln(v ...interface{}) - Print(v ...interface{}) - Printf(format string, v ...interface{}) - Println(v ...interface{}) -} - -// Config is used to specify how a file must be tailed. -type Config struct { - // File-specifc - Location *SeekInfo // Tail from this location. If nil, start at the beginning of the file - ReOpen bool // Reopen recreated files (tail -F) - MustExist bool // Fail early if the file does not exist - Poll bool // Poll for file changes instead of using the default inotify - Pipe bool // The file is a named pipe (mkfifo) - - // Generic IO - Follow bool // Continue looking for new lines (tail -f) - MaxLineSize int // If non-zero, split longer lines into multiple lines - - // Optionally, use a ratelimiter (e.g. created by the ratelimiter/NewLeakyBucket function) - RateLimiter *ratelimiter.LeakyBucket - - // Optionally use a Logger. When nil, the Logger is set to tail.DefaultLogger. 
- // To disable logging, set it to tail.DiscardingLogger - Logger logger -} - -type Tail struct { - Filename string // The filename - Lines chan *Line // A consumable channel of *Line - Config // Tail.Configuration - - file *os.File - reader *bufio.Reader - lineNum int - - watcher watch.FileWatcher - changes *watch.FileChanges - - tomb.Tomb // provides: Done, Kill, Dying - - lk sync.Mutex -} - -var ( - // DefaultLogger logs to os.Stderr and it is used when Config.Logger == nil - DefaultLogger = log.New(os.Stderr, "", log.LstdFlags) - // DiscardingLogger can be used to disable logging output - DiscardingLogger = log.New(ioutil.Discard, "", 0) -) - -// TailFile begins tailing the file. And returns a pointer to a Tail struct -// and an error. An output stream is made available via the Tail.Lines -// channel (e.g. to be looped and printed). To handle errors during tailing, -// after finishing reading from the Lines channel, invoke the `Wait` or `Err` -// method on the returned *Tail. -func TailFile(filename string, config Config) (*Tail, error) { - if config.ReOpen && !config.Follow { - util.Fatal("cannot set ReOpen without Follow.") - } - - t := &Tail{ - Filename: filename, - Lines: make(chan *Line), - Config: config, - } - - // when Logger was not specified in config, use default logger - if t.Logger == nil { - t.Logger = DefaultLogger - } - - if t.Poll { - t.watcher = watch.NewPollingFileWatcher(filename) - } else { - t.watcher = watch.NewInotifyFileWatcher(filename) - } - - if t.MustExist { - var err error - t.file, err = OpenFile(t.Filename) - if err != nil { - return nil, err - } - } - - go t.tailFileSync() - - return t, nil -} - -// Tell returns the file's current position, like stdio's ftell() and an error. -// Beware that this value may not be completely accurate because one line from -// the chan(tail.Lines) may have been read already. -func (tail *Tail) Tell() (offset int64, err error) { - if tail.file == nil { - return - } - offset, err = tail.file.Seek(0, io.SeekCurrent) - if err != nil { - return - } - - tail.lk.Lock() - defer tail.lk.Unlock() - if tail.reader == nil { - return - } - - offset -= int64(tail.reader.Buffered()) - return -} - -// Stop stops the tailing activity. -func (tail *Tail) Stop() error { - tail.Kill(nil) - return tail.Wait() -} - -// StopAtEOF stops tailing as soon as the end of the file is reached. 
The function -// returns an error, -func (tail *Tail) StopAtEOF() error { - tail.Kill(errStopAtEOF) - return tail.Wait() -} - -var errStopAtEOF = errors.New("tail: stop at eof") - -func (tail *Tail) close() { - close(tail.Lines) - tail.closeFile() -} - -func (tail *Tail) closeFile() { - if tail.file != nil { - tail.file.Close() - tail.file = nil - } -} - -func (tail *Tail) reopen() error { - tail.closeFile() - tail.lineNum = 0 - for { - var err error - tail.file, err = OpenFile(tail.Filename) - if err != nil { - if os.IsNotExist(err) { - tail.Logger.Printf("Waiting for %s to appear...", tail.Filename) - if err := tail.watcher.BlockUntilExists(&tail.Tomb); err != nil { - if err == tomb.ErrDying { - return err - } - return fmt.Errorf("Failed to detect creation of %s: %s", tail.Filename, err) - } - continue - } - return fmt.Errorf("Unable to open file %s: %s", tail.Filename, err) - } - break - } - return nil -} - -func (tail *Tail) readLine() (string, error) { - tail.lk.Lock() - line, err := tail.reader.ReadString('\n') - tail.lk.Unlock() - if err != nil { - // Note ReadString "returns the data read before the error" in - // case of an error, including EOF, so we return it as is. The - // caller is expected to process it if err is EOF. - return line, err - } - - line = strings.TrimRight(line, "\n") - - return line, err -} - -func (tail *Tail) tailFileSync() { - defer tail.Done() - defer tail.close() - - if !tail.MustExist { - // deferred first open. - err := tail.reopen() - if err != nil { - if err != tomb.ErrDying { - tail.Kill(err) - } - return - } - } - - // Seek to requested location on first open of the file. - if tail.Location != nil { - _, err := tail.file.Seek(tail.Location.Offset, tail.Location.Whence) - if err != nil { - tail.Killf("Seek error on %s: %s", tail.Filename, err) - return - } - } - - tail.openReader() - - // Read line by line. - for { - // do not seek in named pipes - if !tail.Pipe { - // grab the position in case we need to back up in the event of a half-line - if _, err := tail.Tell(); err != nil { - tail.Kill(err) - return - } - } - - line, err := tail.readLine() - - // Process `line` even if err is EOF. - if err == nil { - cooloff := !tail.sendLine(line) - if cooloff { - // Wait a second before seeking till the end of - // file when rate limit is reached. - msg := ("Too much log activity; waiting a second before resuming tailing") - offset, _ := tail.Tell() - tail.Lines <- &Line{msg, tail.lineNum, SeekInfo{Offset: offset}, time.Now(), errors.New(msg)} - select { - case <-time.After(time.Second): - case <-tail.Dying(): - return - } - if err := tail.seekEnd(); err != nil { - tail.Kill(err) - return - } - } - } else if err == io.EOF { - if !tail.Follow { - if line != "" { - tail.sendLine(line) - } - return - } - - if tail.Follow && line != "" { - tail.sendLine(line) - if err := tail.seekEnd(); err != nil { - tail.Kill(err) - return - } - } - - // When EOF is reached, wait for more data to become - // available. Wait strategy is based on the `tail.watcher` - // implementation (inotify or polling). - err := tail.waitForChanges() - if err != nil { - if err != ErrStop { - tail.Kill(err) - } - return - } - } else { - // non-EOF error - tail.Killf("Error reading %s: %s", tail.Filename, err) - return - } - - select { - case <-tail.Dying(): - if tail.Err() == errStopAtEOF { - continue - } - return - default: - } - } -} - -// waitForChanges waits until the file has been appended, deleted, -// moved or truncated. 
When moved or deleted - the file will be -// reopened if ReOpen is true. Truncated files are always reopened. -func (tail *Tail) waitForChanges() error { - if tail.changes == nil { - pos, err := tail.file.Seek(0, io.SeekCurrent) - if err != nil { - return err - } - tail.changes, err = tail.watcher.ChangeEvents(&tail.Tomb, pos) - if err != nil { - return err - } - } - - select { - case <-tail.changes.Modified: - return nil - case <-tail.changes.Deleted: - tail.changes = nil - if tail.ReOpen { - // XXX: we must not log from a library. - tail.Logger.Printf("Re-opening moved/deleted file %s ...", tail.Filename) - if err := tail.reopen(); err != nil { - return err - } - tail.Logger.Printf("Successfully reopened %s", tail.Filename) - tail.openReader() - return nil - } - tail.Logger.Printf("Stopping tail as file no longer exists: %s", tail.Filename) - return ErrStop - case <-tail.changes.Truncated: - // Always reopen truncated files (Follow is true) - tail.Logger.Printf("Re-opening truncated file %s ...", tail.Filename) - if err := tail.reopen(); err != nil { - return err - } - tail.Logger.Printf("Successfully reopened truncated %s", tail.Filename) - tail.openReader() - return nil - case <-tail.Dying(): - return ErrStop - } -} - -func (tail *Tail) openReader() { - tail.lk.Lock() - if tail.MaxLineSize > 0 { - // add 2 to account for newline characters - tail.reader = bufio.NewReaderSize(tail.file, tail.MaxLineSize+2) - } else { - tail.reader = bufio.NewReader(tail.file) - } - tail.lk.Unlock() -} - -func (tail *Tail) seekEnd() error { - return tail.seekTo(SeekInfo{Offset: 0, Whence: io.SeekEnd}) -} - -func (tail *Tail) seekTo(pos SeekInfo) error { - _, err := tail.file.Seek(pos.Offset, pos.Whence) - if err != nil { - return fmt.Errorf("Seek error on %s: %s", tail.Filename, err) - } - // Reset the read buffer whenever the file is re-seek'ed - tail.reader.Reset(tail.file) - return nil -} - -// sendLine sends the line(s) to Lines channel, splitting longer lines -// if necessary. Return false if rate limit is reached. -func (tail *Tail) sendLine(line string) bool { - now := time.Now() - lines := []string{line} - - // Split longer lines - if tail.MaxLineSize > 0 && len(line) > tail.MaxLineSize { - lines = util.PartitionString(line, tail.MaxLineSize) - } - - for _, line := range lines { - tail.lineNum++ - offset, _ := tail.Tell() - select { - case tail.Lines <- &Line{line, tail.lineNum, SeekInfo{Offset: offset}, now, nil}: - case <-tail.Dying(): - return true - } - } - - if tail.Config.RateLimiter != nil { - ok := tail.Config.RateLimiter.Pour(uint16(len(lines))) - if !ok { - tail.Logger.Printf("Leaky bucket full (%v); entering 1s cooloff period.", - tail.Filename) - return false - } - } - - return true -} - -// Cleanup removes inotify watches added by the tail package. This function is -// meant to be invoked from a process's exit handler. Linux kernel may not -// automatically remove inotify watches after the process exits. -// If you plan to re-read a file, don't call Cleanup in between. 
-func (tail *Tail) Cleanup() { - watch.Cleanup(tail.Filename) -} diff --git a/vendor/github.com/nxadm/tail/tail_posix.go b/vendor/github.com/nxadm/tail/tail_posix.go deleted file mode 100644 index 23e071dea16..00000000000 --- a/vendor/github.com/nxadm/tail/tail_posix.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -// +build !windows - -package tail - -import ( - "os" -) - -// Deprecated: this function is only useful internally and, as such, -// it will be removed from the API in a future major release. -// -// OpenFile proxies a os.Open call for a file so it can be correctly tailed -// on POSIX and non-POSIX OSes like MS Windows. -func OpenFile(name string) (file *os.File, err error) { - return os.Open(name) -} diff --git a/vendor/github.com/nxadm/tail/tail_windows.go b/vendor/github.com/nxadm/tail/tail_windows.go deleted file mode 100644 index da0d2f39c97..00000000000 --- a/vendor/github.com/nxadm/tail/tail_windows.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -// +build windows - -package tail - -import ( - "os" - - "github.com/nxadm/tail/winfile" -) - -// Deprecated: this function is only useful internally and, as such, -// it will be removed from the API in a future major release. -// -// OpenFile proxies a os.Open call for a file so it can be correctly tailed -// on POSIX and non-POSIX OSes like MS Windows. -func OpenFile(name string) (file *os.File, err error) { - return winfile.OpenFile(name, os.O_RDONLY, 0) -} diff --git a/vendor/github.com/nxadm/tail/util/util.go b/vendor/github.com/nxadm/tail/util/util.go deleted file mode 100644 index b64caa21269..00000000000 --- a/vendor/github.com/nxadm/tail/util/util.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -// Copyright (c) 2015 HPE Software Inc. All rights reserved. -// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. - -package util - -import ( - "fmt" - "log" - "os" - "runtime/debug" -) - -type Logger struct { - *log.Logger -} - -var LOGGER = &Logger{log.New(os.Stderr, "", log.LstdFlags)} - -// fatal is like panic except it displays only the current goroutine's stack. -func Fatal(format string, v ...interface{}) { - // https://github.com/nxadm/log/blob/master/log.go#L45 - LOGGER.Output(2, fmt.Sprintf("FATAL -- "+format, v...)+"\n"+string(debug.Stack())) - os.Exit(1) -} - -// partitionString partitions the string into chunks of given size, -// with the last chunk of variable size. 
-func PartitionString(s string, chunkSize int) []string { - if chunkSize <= 0 { - panic("invalid chunkSize") - } - length := len(s) - chunks := 1 + length/chunkSize - start := 0 - end := chunkSize - parts := make([]string, 0, chunks) - for { - if end > length { - end = length - } - parts = append(parts, s[start:end]) - if end == length { - break - } - start, end = end, end+chunkSize - } - return parts -} diff --git a/vendor/github.com/nxadm/tail/watch/filechanges.go b/vendor/github.com/nxadm/tail/watch/filechanges.go deleted file mode 100644 index 5b65f42ae58..00000000000 --- a/vendor/github.com/nxadm/tail/watch/filechanges.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -package watch - -type FileChanges struct { - Modified chan bool // Channel to get notified of modifications - Truncated chan bool // Channel to get notified of truncations - Deleted chan bool // Channel to get notified of deletions/renames -} - -func NewFileChanges() *FileChanges { - return &FileChanges{ - make(chan bool, 1), make(chan bool, 1), make(chan bool, 1)} -} - -func (fc *FileChanges) NotifyModified() { - sendOnlyIfEmpty(fc.Modified) -} - -func (fc *FileChanges) NotifyTruncated() { - sendOnlyIfEmpty(fc.Truncated) -} - -func (fc *FileChanges) NotifyDeleted() { - sendOnlyIfEmpty(fc.Deleted) -} - -// sendOnlyIfEmpty sends on a bool channel only if the channel has no -// backlog to be read by other goroutines. This concurrency pattern -// can be used to notify other goroutines if and only if they are -// looking for it (i.e., subsequent notifications can be compressed -// into one). -func sendOnlyIfEmpty(ch chan bool) { - select { - case ch <- true: - default: - } -} diff --git a/vendor/github.com/nxadm/tail/watch/inotify.go b/vendor/github.com/nxadm/tail/watch/inotify.go deleted file mode 100644 index cbd11ad8d03..00000000000 --- a/vendor/github.com/nxadm/tail/watch/inotify.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -// Copyright (c) 2015 HPE Software Inc. All rights reserved. -// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. - -package watch - -import ( - "fmt" - "os" - "path/filepath" - - "github.com/nxadm/tail/util" - - "github.com/fsnotify/fsnotify" - "gopkg.in/tomb.v1" -) - -// InotifyFileWatcher uses inotify to monitor file changes. -type InotifyFileWatcher struct { - Filename string - Size int64 -} - -func NewInotifyFileWatcher(filename string) *InotifyFileWatcher { - fw := &InotifyFileWatcher{filepath.Clean(filename), 0} - return fw -} - -func (fw *InotifyFileWatcher) BlockUntilExists(t *tomb.Tomb) error { - err := WatchCreate(fw.Filename) - if err != nil { - return err - } - defer RemoveWatchCreate(fw.Filename) - - // Do a real check now as the file might have been created before - // calling `WatchFlags` above. - if _, err = os.Stat(fw.Filename); !os.IsNotExist(err) { - // file exists, or stat returned an error. 
- return err - } - - events := Events(fw.Filename) - - for { - select { - case evt, ok := <-events: - if !ok { - return fmt.Errorf("inotify watcher has been closed") - } - evtName, err := filepath.Abs(evt.Name) - if err != nil { - return err - } - fwFilename, err := filepath.Abs(fw.Filename) - if err != nil { - return err - } - if evtName == fwFilename { - return nil - } - case <-t.Dying(): - return tomb.ErrDying - } - } - panic("unreachable") -} - -func (fw *InotifyFileWatcher) ChangeEvents(t *tomb.Tomb, pos int64) (*FileChanges, error) { - err := Watch(fw.Filename) - if err != nil { - return nil, err - } - - changes := NewFileChanges() - fw.Size = pos - - go func() { - - events := Events(fw.Filename) - - for { - prevSize := fw.Size - - var evt fsnotify.Event - var ok bool - - select { - case evt, ok = <-events: - if !ok { - RemoveWatch(fw.Filename) - return - } - case <-t.Dying(): - RemoveWatch(fw.Filename) - return - } - - switch { - case evt.Op&fsnotify.Remove == fsnotify.Remove: - fallthrough - - case evt.Op&fsnotify.Rename == fsnotify.Rename: - RemoveWatch(fw.Filename) - changes.NotifyDeleted() - return - - //With an open fd, unlink(fd) - inotify returns IN_ATTRIB (==fsnotify.Chmod) - case evt.Op&fsnotify.Chmod == fsnotify.Chmod: - fallthrough - - case evt.Op&fsnotify.Write == fsnotify.Write: - fi, err := os.Stat(fw.Filename) - if err != nil { - if os.IsNotExist(err) { - RemoveWatch(fw.Filename) - changes.NotifyDeleted() - return - } - // XXX: report this error back to the user - util.Fatal("Failed to stat file %v: %v", fw.Filename, err) - } - fw.Size = fi.Size() - - if prevSize > 0 && prevSize > fw.Size { - changes.NotifyTruncated() - } else { - changes.NotifyModified() - } - prevSize = fw.Size - } - } - }() - - return changes, nil -} diff --git a/vendor/github.com/nxadm/tail/watch/inotify_tracker.go b/vendor/github.com/nxadm/tail/watch/inotify_tracker.go deleted file mode 100644 index cb9572a0302..00000000000 --- a/vendor/github.com/nxadm/tail/watch/inotify_tracker.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -// Copyright (c) 2015 HPE Software Inc. All rights reserved. -// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. 
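
Editor's aside, not part of the patch: the watch code being removed here layers tail-specific bookkeeping (shared tracker, truncation detection) over github.com/fsnotify/fsnotify. A minimal sketch of the bare fsnotify event loop it wraps (the watched path is illustrative):

```Go
package main

import (
	"log"

	"github.com/fsnotify/fsnotify"
)

func main() {
	w, err := fsnotify.NewWatcher()
	if err != nil {
		log.Fatal(err)
	}
	defer w.Close()

	// Watch a directory; each event carries the affected file name in ev.Name.
	if err := w.Add("/tmp"); err != nil {
		log.Fatal(err)
	}

	for {
		select {
		case ev, ok := <-w.Events:
			if !ok {
				return
			}
			if ev.Op&fsnotify.Write == fsnotify.Write {
				log.Println("modified:", ev.Name)
			}
		case err, ok := <-w.Errors:
			if !ok {
				return
			}
			log.Println("watch error:", err)
		}
	}
}
```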
- -package watch - -import ( - "log" - "os" - "path/filepath" - "sync" - "syscall" - - "github.com/nxadm/tail/util" - - "github.com/fsnotify/fsnotify" -) - -type InotifyTracker struct { - mux sync.Mutex - watcher *fsnotify.Watcher - chans map[string]chan fsnotify.Event - done map[string]chan bool - watchNums map[string]int - watch chan *watchInfo - remove chan *watchInfo - error chan error -} - -type watchInfo struct { - op fsnotify.Op - fname string -} - -func (this *watchInfo) isCreate() bool { - return this.op == fsnotify.Create -} - -var ( - // globally shared InotifyTracker; ensures only one fsnotify.Watcher is used - shared *InotifyTracker - - // these are used to ensure the shared InotifyTracker is run exactly once - once = sync.Once{} - goRun = func() { - shared = &InotifyTracker{ - mux: sync.Mutex{}, - chans: make(map[string]chan fsnotify.Event), - done: make(map[string]chan bool), - watchNums: make(map[string]int), - watch: make(chan *watchInfo), - remove: make(chan *watchInfo), - error: make(chan error), - } - go shared.run() - } - - logger = log.New(os.Stderr, "", log.LstdFlags) -) - -// Watch signals the run goroutine to begin watching the input filename -func Watch(fname string) error { - return watch(&watchInfo{ - fname: fname, - }) -} - -// Watch create signals the run goroutine to begin watching the input filename -// if call the WatchCreate function, don't call the Cleanup, call the RemoveWatchCreate -func WatchCreate(fname string) error { - return watch(&watchInfo{ - op: fsnotify.Create, - fname: fname, - }) -} - -func watch(winfo *watchInfo) error { - // start running the shared InotifyTracker if not already running - once.Do(goRun) - - winfo.fname = filepath.Clean(winfo.fname) - shared.watch <- winfo - return <-shared.error -} - -// RemoveWatch signals the run goroutine to remove the watch for the input filename -func RemoveWatch(fname string) error { - return remove(&watchInfo{ - fname: fname, - }) -} - -// RemoveWatch create signals the run goroutine to remove the watch for the input filename -func RemoveWatchCreate(fname string) error { - return remove(&watchInfo{ - op: fsnotify.Create, - fname: fname, - }) -} - -func remove(winfo *watchInfo) error { - // start running the shared InotifyTracker if not already running - once.Do(goRun) - - winfo.fname = filepath.Clean(winfo.fname) - shared.mux.Lock() - done := shared.done[winfo.fname] - if done != nil { - delete(shared.done, winfo.fname) - close(done) - } - shared.mux.Unlock() - - shared.remove <- winfo - return <-shared.error -} - -// Events returns a channel to which FileEvents corresponding to the input filename -// will be sent. This channel will be closed when removeWatch is called on this -// filename. -func Events(fname string) <-chan fsnotify.Event { - shared.mux.Lock() - defer shared.mux.Unlock() - - return shared.chans[fname] -} - -// Cleanup removes the watch for the input filename if necessary. -func Cleanup(fname string) error { - return RemoveWatch(fname) -} - -// watchFlags calls fsnotify.WatchFlags for the input filename and flags, creating -// a new Watcher if the previous Watcher was closed. -func (shared *InotifyTracker) addWatch(winfo *watchInfo) error { - shared.mux.Lock() - defer shared.mux.Unlock() - - if shared.chans[winfo.fname] == nil { - shared.chans[winfo.fname] = make(chan fsnotify.Event) - } - if shared.done[winfo.fname] == nil { - shared.done[winfo.fname] = make(chan bool) - } - - fname := winfo.fname - if winfo.isCreate() { - // Watch for new files to be created in the parent directory. 
- fname = filepath.Dir(fname) - } - - var err error - // already in inotify watch - if shared.watchNums[fname] == 0 { - err = shared.watcher.Add(fname) - } - if err == nil { - shared.watchNums[fname]++ - } - return err -} - -// removeWatch calls fsnotify.RemoveWatch for the input filename and closes the -// corresponding events channel. -func (shared *InotifyTracker) removeWatch(winfo *watchInfo) error { - shared.mux.Lock() - - ch := shared.chans[winfo.fname] - if ch != nil { - delete(shared.chans, winfo.fname) - close(ch) - } - - fname := winfo.fname - if winfo.isCreate() { - // Watch for new files to be created in the parent directory. - fname = filepath.Dir(fname) - } - shared.watchNums[fname]-- - watchNum := shared.watchNums[fname] - if watchNum == 0 { - delete(shared.watchNums, fname) - } - shared.mux.Unlock() - - var err error - // If we were the last ones to watch this file, unsubscribe from inotify. - // This needs to happen after releasing the lock because fsnotify waits - // synchronously for the kernel to acknowledge the removal of the watch - // for this file, which causes us to deadlock if we still held the lock. - if watchNum == 0 { - err = shared.watcher.Remove(fname) - } - - return err -} - -// sendEvent sends the input event to the appropriate Tail. -func (shared *InotifyTracker) sendEvent(event fsnotify.Event) { - name := filepath.Clean(event.Name) - - shared.mux.Lock() - ch := shared.chans[name] - done := shared.done[name] - shared.mux.Unlock() - - if ch != nil && done != nil { - select { - case ch <- event: - case <-done: - } - } -} - -// run starts the goroutine in which the shared struct reads events from its -// Watcher's Event channel and sends the events to the appropriate Tail. -func (shared *InotifyTracker) run() { - watcher, err := fsnotify.NewWatcher() - if err != nil { - util.Fatal("failed to create Watcher") - } - shared.watcher = watcher - - for { - select { - case winfo := <-shared.watch: - shared.error <- shared.addWatch(winfo) - - case winfo := <-shared.remove: - shared.error <- shared.removeWatch(winfo) - - case event, open := <-shared.watcher.Events: - if !open { - return - } - shared.sendEvent(event) - - case err, open := <-shared.watcher.Errors: - if !open { - return - } else if err != nil { - sysErr, ok := err.(*os.SyscallError) - if !ok || sysErr.Err != syscall.EINTR { - logger.Printf("Error in Watcher Error channel: %s", err) - } - } - } - } -} diff --git a/vendor/github.com/nxadm/tail/watch/polling.go b/vendor/github.com/nxadm/tail/watch/polling.go deleted file mode 100644 index 74e10aa427c..00000000000 --- a/vendor/github.com/nxadm/tail/watch/polling.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -// Copyright (c) 2015 HPE Software Inc. All rights reserved. -// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. - -package watch - -import ( - "os" - "runtime" - "time" - - "github.com/nxadm/tail/util" - "gopkg.in/tomb.v1" -) - -// PollingFileWatcher polls the file for changes. 
-type PollingFileWatcher struct { - Filename string - Size int64 -} - -func NewPollingFileWatcher(filename string) *PollingFileWatcher { - fw := &PollingFileWatcher{filename, 0} - return fw -} - -var POLL_DURATION time.Duration - -func (fw *PollingFileWatcher) BlockUntilExists(t *tomb.Tomb) error { - for { - if _, err := os.Stat(fw.Filename); err == nil { - return nil - } else if !os.IsNotExist(err) { - return err - } - select { - case <-time.After(POLL_DURATION): - continue - case <-t.Dying(): - return tomb.ErrDying - } - } - panic("unreachable") -} - -func (fw *PollingFileWatcher) ChangeEvents(t *tomb.Tomb, pos int64) (*FileChanges, error) { - origFi, err := os.Stat(fw.Filename) - if err != nil { - return nil, err - } - - changes := NewFileChanges() - var prevModTime time.Time - - // XXX: use tomb.Tomb to cleanly manage these goroutines. replace - // the fatal (below) with tomb's Kill. - - fw.Size = pos - - go func() { - prevSize := fw.Size - for { - select { - case <-t.Dying(): - return - default: - } - - time.Sleep(POLL_DURATION) - fi, err := os.Stat(fw.Filename) - if err != nil { - // Windows cannot delete a file if a handle is still open (tail keeps one open) - // so it gives access denied to anything trying to read it until all handles are released. - if os.IsNotExist(err) || (runtime.GOOS == "windows" && os.IsPermission(err)) { - // File does not exist (has been deleted). - changes.NotifyDeleted() - return - } - - // XXX: report this error back to the user - util.Fatal("Failed to stat file %v: %v", fw.Filename, err) - } - - // File got moved/renamed? - if !os.SameFile(origFi, fi) { - changes.NotifyDeleted() - return - } - - // File got truncated? - fw.Size = fi.Size() - if prevSize > 0 && prevSize > fw.Size { - changes.NotifyTruncated() - prevSize = fw.Size - continue - } - // File got bigger? - if prevSize > 0 && prevSize < fw.Size { - changes.NotifyModified() - prevSize = fw.Size - continue - } - prevSize = fw.Size - - // File was appended to (changed)? - modTime := fi.ModTime() - if modTime != prevModTime { - prevModTime = modTime - changes.NotifyModified() - } - } - }() - - return changes, nil -} - -func init() { - POLL_DURATION = 250 * time.Millisecond -} diff --git a/vendor/github.com/nxadm/tail/watch/watch.go b/vendor/github.com/nxadm/tail/watch/watch.go deleted file mode 100644 index 2b5112805a1..00000000000 --- a/vendor/github.com/nxadm/tail/watch/watch.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -// Copyright (c) 2015 HPE Software Inc. All rights reserved. -// Copyright (c) 2013 ActiveState Software Inc. All rights reserved. - -package watch - -import "gopkg.in/tomb.v1" - -// FileWatcher monitors file-level events. -type FileWatcher interface { - // BlockUntilExists blocks until the file comes into existence. - BlockUntilExists(*tomb.Tomb) error - - // ChangeEvents reports on changes to a file, be it modification, - // deletion, renames or truncations. Returned FileChanges group of - // channels will be closed, thus become unusable, after a deletion - // or truncation event. - // In order to properly report truncations, ChangeEvents requires - // the caller to pass their current offset in the file. 
- ChangeEvents(*tomb.Tomb, int64) (*FileChanges, error) -} diff --git a/vendor/github.com/nxadm/tail/winfile/winfile.go b/vendor/github.com/nxadm/tail/winfile/winfile.go deleted file mode 100644 index 4562ac7c25c..00000000000 --- a/vendor/github.com/nxadm/tail/winfile/winfile.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2019 FOSS contributors of https://github.com/nxadm/tail -// +build windows - -package winfile - -import ( - "os" - "syscall" - "unsafe" -) - -// issue also described here -//https://codereview.appspot.com/8203043/ - -// https://github.com/jnwhiteh/golang/blob/master/src/pkg/syscall/syscall_windows.go#L218 -func Open(path string, mode int, perm uint32) (fd syscall.Handle, err error) { - if len(path) == 0 { - return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND - } - pathp, err := syscall.UTF16PtrFromString(path) - if err != nil { - return syscall.InvalidHandle, err - } - var access uint32 - switch mode & (syscall.O_RDONLY | syscall.O_WRONLY | syscall.O_RDWR) { - case syscall.O_RDONLY: - access = syscall.GENERIC_READ - case syscall.O_WRONLY: - access = syscall.GENERIC_WRITE - case syscall.O_RDWR: - access = syscall.GENERIC_READ | syscall.GENERIC_WRITE - } - if mode&syscall.O_CREAT != 0 { - access |= syscall.GENERIC_WRITE - } - if mode&syscall.O_APPEND != 0 { - access &^= syscall.GENERIC_WRITE - access |= syscall.FILE_APPEND_DATA - } - sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE) - var sa *syscall.SecurityAttributes - if mode&syscall.O_CLOEXEC == 0 { - sa = makeInheritSa() - } - var createmode uint32 - switch { - case mode&(syscall.O_CREAT|syscall.O_EXCL) == (syscall.O_CREAT | syscall.O_EXCL): - createmode = syscall.CREATE_NEW - case mode&(syscall.O_CREAT|syscall.O_TRUNC) == (syscall.O_CREAT | syscall.O_TRUNC): - createmode = syscall.CREATE_ALWAYS - case mode&syscall.O_CREAT == syscall.O_CREAT: - createmode = syscall.OPEN_ALWAYS - case mode&syscall.O_TRUNC == syscall.O_TRUNC: - createmode = syscall.TRUNCATE_EXISTING - default: - createmode = syscall.OPEN_EXISTING - } - h, e := syscall.CreateFile(pathp, access, sharemode, sa, createmode, syscall.FILE_ATTRIBUTE_NORMAL, 0) - return h, e -} - -// https://github.com/jnwhiteh/golang/blob/master/src/pkg/syscall/syscall_windows.go#L211 -func makeInheritSa() *syscall.SecurityAttributes { - var sa syscall.SecurityAttributes - sa.Length = uint32(unsafe.Sizeof(sa)) - sa.InheritHandle = 1 - return &sa -} - -// https://github.com/jnwhiteh/golang/blob/master/src/pkg/os/file_windows.go#L133 -func OpenFile(name string, flag int, perm os.FileMode) (file *os.File, err error) { - r, e := Open(name, flag|syscall.O_CLOEXEC, syscallMode(perm)) - if e != nil { - return nil, e - } - return os.NewFile(uintptr(r), name), nil -} - -// https://github.com/jnwhiteh/golang/blob/master/src/pkg/os/file_posix.go#L61 -func syscallMode(i os.FileMode) (o uint32) { - o |= uint32(i.Perm()) - if i&os.ModeSetuid != 0 { - o |= syscall.S_ISUID - } - if i&os.ModeSetgid != 0 { - o |= syscall.S_ISGID - } - if i&os.ModeSticky != 0 { - o |= syscall.S_ISVTX - } - // No mapping for Go's ModeTemporary (plan9 only). - return -} diff --git a/vendor/github.com/onsi/ginkgo/config/config.go b/vendor/github.com/onsi/ginkgo/config/config.go deleted file mode 100644 index 3130c778974..00000000000 --- a/vendor/github.com/onsi/ginkgo/config/config.go +++ /dev/null @@ -1,232 +0,0 @@ -/* -Ginkgo accepts a number of configuration options. 
- -These are documented [here](http://onsi.github.io/ginkgo/#the-ginkgo-cli) - -You can also learn more via - - ginkgo help - -or (I kid you not): - - go test -asdf -*/ -package config - -import ( - "flag" - "time" - - "fmt" -) - -const VERSION = "1.16.5" - -type GinkgoConfigType struct { - RandomSeed int64 - RandomizeAllSpecs bool - RegexScansFilePath bool - FocusStrings []string - SkipStrings []string - SkipMeasurements bool - FailOnPending bool - FailFast bool - FlakeAttempts int - EmitSpecProgress bool - DryRun bool - DebugParallel bool - - ParallelNode int - ParallelTotal int - SyncHost string - StreamHost string -} - -var GinkgoConfig = GinkgoConfigType{} - -type DefaultReporterConfigType struct { - NoColor bool - SlowSpecThreshold float64 - NoisyPendings bool - NoisySkippings bool - Succinct bool - Verbose bool - FullTrace bool - ReportPassed bool - ReportFile string -} - -var DefaultReporterConfig = DefaultReporterConfigType{} - -func processPrefix(prefix string) string { - if prefix != "" { - prefix += "." - } - return prefix -} - -type flagFunc func(string) - -func (f flagFunc) String() string { return "" } -func (f flagFunc) Set(s string) error { f(s); return nil } - -func Flags(flagSet *flag.FlagSet, prefix string, includeParallelFlags bool) { - prefix = processPrefix(prefix) - flagSet.Int64Var(&(GinkgoConfig.RandomSeed), prefix+"seed", time.Now().Unix(), "The seed used to randomize the spec suite.") - flagSet.BoolVar(&(GinkgoConfig.RandomizeAllSpecs), prefix+"randomizeAllSpecs", false, "If set, ginkgo will randomize all specs together. By default, ginkgo only randomizes the top level Describe, Context and When groups.") - flagSet.BoolVar(&(GinkgoConfig.SkipMeasurements), prefix+"skipMeasurements", false, "If set, ginkgo will skip any measurement specs.") - flagSet.BoolVar(&(GinkgoConfig.FailOnPending), prefix+"failOnPending", false, "If set, ginkgo will mark the test suite as failed if any specs are pending.") - flagSet.BoolVar(&(GinkgoConfig.FailFast), prefix+"failFast", false, "If set, ginkgo will stop running a test suite after a failure occurs.") - - flagSet.BoolVar(&(GinkgoConfig.DryRun), prefix+"dryRun", false, "If set, ginkgo will walk the test hierarchy without actually running anything. Best paired with -v.") - - flagSet.Var(flagFunc(flagFocus), prefix+"focus", "If set, ginkgo will only run specs that match this regular expression. Can be specified multiple times, values are ORed.") - flagSet.Var(flagFunc(flagSkip), prefix+"skip", "If set, ginkgo will only run specs that do not match this regular expression. Can be specified multiple times, values are ORed.") - - flagSet.BoolVar(&(GinkgoConfig.RegexScansFilePath), prefix+"regexScansFilePath", false, "If set, ginkgo regex matching also will look at the file path (code location).") - - flagSet.IntVar(&(GinkgoConfig.FlakeAttempts), prefix+"flakeAttempts", 1, "Make up to this many attempts to run each spec. Please note that if any of the attempts succeed, the suite will not be failed. But any failures will still be recorded.") - - flagSet.BoolVar(&(GinkgoConfig.EmitSpecProgress), prefix+"progress", false, "If set, ginkgo will emit progress information as each spec runs to the GinkgoWriter.") - - flagSet.BoolVar(&(GinkgoConfig.DebugParallel), prefix+"debug", false, "If set, ginkgo will emit node output to files when running in parallel.") - - if includeParallelFlags { - flagSet.IntVar(&(GinkgoConfig.ParallelNode), prefix+"parallel.node", 1, "This worker node's (one-indexed) node number. 
For running specs in parallel.") - flagSet.IntVar(&(GinkgoConfig.ParallelTotal), prefix+"parallel.total", 1, "The total number of worker nodes. For running specs in parallel.") - flagSet.StringVar(&(GinkgoConfig.SyncHost), prefix+"parallel.synchost", "", "The address for the server that will synchronize the running nodes.") - flagSet.StringVar(&(GinkgoConfig.StreamHost), prefix+"parallel.streamhost", "", "The address for the server that the running nodes should stream data to.") - } - - flagSet.BoolVar(&(DefaultReporterConfig.NoColor), prefix+"noColor", false, "If set, suppress color output in default reporter.") - flagSet.Float64Var(&(DefaultReporterConfig.SlowSpecThreshold), prefix+"slowSpecThreshold", 5.0, "(in seconds) Specs that take longer to run than this threshold are flagged as slow by the default reporter.") - flagSet.BoolVar(&(DefaultReporterConfig.NoisyPendings), prefix+"noisyPendings", true, "If set, default reporter will shout about pending tests.") - flagSet.BoolVar(&(DefaultReporterConfig.NoisySkippings), prefix+"noisySkippings", true, "If set, default reporter will shout about skipping tests.") - flagSet.BoolVar(&(DefaultReporterConfig.Verbose), prefix+"v", false, "If set, default reporter print out all specs as they begin.") - flagSet.BoolVar(&(DefaultReporterConfig.Succinct), prefix+"succinct", false, "If set, default reporter prints out a very succinct report") - flagSet.BoolVar(&(DefaultReporterConfig.FullTrace), prefix+"trace", false, "If set, default reporter prints out the full stack trace when a failure occurs") - flagSet.BoolVar(&(DefaultReporterConfig.ReportPassed), prefix+"reportPassed", false, "If set, default reporter prints out captured output of passed tests.") - flagSet.StringVar(&(DefaultReporterConfig.ReportFile), prefix+"reportFile", "", "Override the default reporter output file path.") - -} - -func BuildFlagArgs(prefix string, ginkgo GinkgoConfigType, reporter DefaultReporterConfigType) []string { - prefix = processPrefix(prefix) - result := make([]string, 0) - - if ginkgo.RandomSeed > 0 { - result = append(result, fmt.Sprintf("--%sseed=%d", prefix, ginkgo.RandomSeed)) - } - - if ginkgo.RandomizeAllSpecs { - result = append(result, fmt.Sprintf("--%srandomizeAllSpecs", prefix)) - } - - if ginkgo.SkipMeasurements { - result = append(result, fmt.Sprintf("--%sskipMeasurements", prefix)) - } - - if ginkgo.FailOnPending { - result = append(result, fmt.Sprintf("--%sfailOnPending", prefix)) - } - - if ginkgo.FailFast { - result = append(result, fmt.Sprintf("--%sfailFast", prefix)) - } - - if ginkgo.DryRun { - result = append(result, fmt.Sprintf("--%sdryRun", prefix)) - } - - for _, s := range ginkgo.FocusStrings { - result = append(result, fmt.Sprintf("--%sfocus=%s", prefix, s)) - } - - for _, s := range ginkgo.SkipStrings { - result = append(result, fmt.Sprintf("--%sskip=%s", prefix, s)) - } - - if ginkgo.FlakeAttempts > 1 { - result = append(result, fmt.Sprintf("--%sflakeAttempts=%d", prefix, ginkgo.FlakeAttempts)) - } - - if ginkgo.EmitSpecProgress { - result = append(result, fmt.Sprintf("--%sprogress", prefix)) - } - - if ginkgo.DebugParallel { - result = append(result, fmt.Sprintf("--%sdebug", prefix)) - } - - if ginkgo.ParallelNode != 0 { - result = append(result, fmt.Sprintf("--%sparallel.node=%d", prefix, ginkgo.ParallelNode)) - } - - if ginkgo.ParallelTotal != 0 { - result = append(result, fmt.Sprintf("--%sparallel.total=%d", prefix, ginkgo.ParallelTotal)) - } - - if ginkgo.StreamHost != "" { - result = append(result, 
fmt.Sprintf("--%sparallel.streamhost=%s", prefix, ginkgo.StreamHost)) - } - - if ginkgo.SyncHost != "" { - result = append(result, fmt.Sprintf("--%sparallel.synchost=%s", prefix, ginkgo.SyncHost)) - } - - if ginkgo.RegexScansFilePath { - result = append(result, fmt.Sprintf("--%sregexScansFilePath", prefix)) - } - - if reporter.NoColor { - result = append(result, fmt.Sprintf("--%snoColor", prefix)) - } - - if reporter.SlowSpecThreshold > 0 { - result = append(result, fmt.Sprintf("--%sslowSpecThreshold=%.5f", prefix, reporter.SlowSpecThreshold)) - } - - if !reporter.NoisyPendings { - result = append(result, fmt.Sprintf("--%snoisyPendings=false", prefix)) - } - - if !reporter.NoisySkippings { - result = append(result, fmt.Sprintf("--%snoisySkippings=false", prefix)) - } - - if reporter.Verbose { - result = append(result, fmt.Sprintf("--%sv", prefix)) - } - - if reporter.Succinct { - result = append(result, fmt.Sprintf("--%ssuccinct", prefix)) - } - - if reporter.FullTrace { - result = append(result, fmt.Sprintf("--%strace", prefix)) - } - - if reporter.ReportPassed { - result = append(result, fmt.Sprintf("--%sreportPassed", prefix)) - } - - if reporter.ReportFile != "" { - result = append(result, fmt.Sprintf("--%sreportFile=%s", prefix, reporter.ReportFile)) - } - - return result -} - -// flagFocus implements the -focus flag. -func flagFocus(arg string) { - if arg != "" { - GinkgoConfig.FocusStrings = append(GinkgoConfig.FocusStrings, arg) - } -} - -// flagSkip implements the -skip flag. -func flagSkip(arg string) { - if arg != "" { - GinkgoConfig.SkipStrings = append(GinkgoConfig.SkipStrings, arg) - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/bootstrap_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/bootstrap_command.go deleted file mode 100644 index ea10e979638..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/bootstrap_command.go +++ /dev/null @@ -1,201 +0,0 @@ -package main - -import ( - "bytes" - "flag" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "strings" - "text/template" - - "go/build" - - sprig "github.com/go-task/slim-sprig" - "github.com/onsi/ginkgo/ginkgo/nodot" -) - -func BuildBootstrapCommand() *Command { - var ( - agouti, noDot, internal bool - customBootstrapFile string - ) - flagSet := flag.NewFlagSet("bootstrap", flag.ExitOnError) - flagSet.BoolVar(&agouti, "agouti", false, "If set, bootstrap will generate a bootstrap file for writing Agouti tests") - flagSet.BoolVar(&noDot, "nodot", false, "If set, bootstrap will generate a bootstrap file that does not . 
import ginkgo and gomega") - flagSet.BoolVar(&internal, "internal", false, "If set, generate will generate a test file that uses the regular package name") - flagSet.StringVar(&customBootstrapFile, "template", "", "If specified, generate will use the contents of the file passed as the bootstrap template") - - return &Command{ - Name: "bootstrap", - FlagSet: flagSet, - UsageCommand: "ginkgo bootstrap ", - Usage: []string{ - "Bootstrap a test suite for the current package", - "Accepts the following flags:", - }, - Command: func(args []string, additionalArgs []string) { - generateBootstrap(agouti, noDot, internal, customBootstrapFile) - emitRCAdvertisement() - }, - } -} - -var bootstrapText = `package {{.Package}} - -import ( - "testing" - - {{.GinkgoImport}} - {{.GomegaImport}} -) - -func Test{{.FormattedName}}(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "{{.FormattedName}} Suite") -} -` - -var agoutiBootstrapText = `package {{.Package}} - -import ( - "testing" - - {{.GinkgoImport}} - {{.GomegaImport}} - "github.com/sclevine/agouti" -) - -func Test{{.FormattedName}}(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "{{.FormattedName}} Suite") -} - -var agoutiDriver *agouti.WebDriver - -var _ = BeforeSuite(func() { - // Choose a WebDriver: - - agoutiDriver = agouti.PhantomJS() - // agoutiDriver = agouti.Selenium() - // agoutiDriver = agouti.ChromeDriver() - - Expect(agoutiDriver.Start()).To(Succeed()) -}) - -var _ = AfterSuite(func() { - Expect(agoutiDriver.Stop()).To(Succeed()) -}) -` - -type bootstrapData struct { - Package string - FormattedName string - GinkgoImport string - GomegaImport string -} - -func getPackageAndFormattedName() (string, string, string) { - path, err := os.Getwd() - if err != nil { - complainAndQuit("Could not get current working directory: \n" + err.Error()) - } - - dirName := strings.Replace(filepath.Base(path), "-", "_", -1) - dirName = strings.Replace(dirName, " ", "_", -1) - - pkg, err := build.ImportDir(path, 0) - packageName := pkg.Name - if err != nil { - packageName = dirName - } - - formattedName := prettifyPackageName(filepath.Base(path)) - return packageName, dirName, formattedName -} - -func prettifyPackageName(name string) string { - name = strings.Replace(name, "-", " ", -1) - name = strings.Replace(name, "_", " ", -1) - name = strings.Title(name) - name = strings.Replace(name, " ", "", -1) - return name -} - -func determinePackageName(name string, internal bool) string { - if internal { - return name - } - - return name + "_test" -} - -func fileExists(path string) bool { - _, err := os.Stat(path) - return err == nil -} - -func generateBootstrap(agouti, noDot, internal bool, customBootstrapFile string) { - packageName, bootstrapFilePrefix, formattedName := getPackageAndFormattedName() - data := bootstrapData{ - Package: determinePackageName(packageName, internal), - FormattedName: formattedName, - GinkgoImport: `. "github.com/onsi/ginkgo"`, - GomegaImport: `. 
"github.com/onsi/gomega"`, - } - - if noDot { - data.GinkgoImport = `"github.com/onsi/ginkgo"` - data.GomegaImport = `"github.com/onsi/gomega"` - } - - targetFile := fmt.Sprintf("%s_suite_test.go", bootstrapFilePrefix) - if fileExists(targetFile) { - fmt.Printf("%s already exists.\n\n", targetFile) - os.Exit(1) - } else { - fmt.Printf("Generating ginkgo test suite bootstrap for %s in:\n\t%s\n", packageName, targetFile) - } - - f, err := os.Create(targetFile) - if err != nil { - complainAndQuit("Could not create file: " + err.Error()) - panic(err.Error()) - } - defer f.Close() - - var templateText string - if customBootstrapFile != "" { - tpl, err := ioutil.ReadFile(customBootstrapFile) - if err != nil { - panic(err.Error()) - } - templateText = string(tpl) - } else if agouti { - templateText = agoutiBootstrapText - } else { - templateText = bootstrapText - } - - bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Parse(templateText) - if err != nil { - panic(err.Error()) - } - - buf := &bytes.Buffer{} - bootstrapTemplate.Execute(buf, data) - - if noDot { - contents, err := nodot.ApplyNoDot(buf.Bytes()) - if err != nil { - complainAndQuit("Failed to import nodot declarations: " + err.Error()) - } - fmt.Println("To update the nodot declarations in the future, switch to this directory and run:\n\tginkgo nodot") - buf = bytes.NewBuffer(contents) - } - - buf.WriteTo(f) - - goFmt(targetFile) -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/build_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/build_command.go deleted file mode 100644 index 2fddef0f7b4..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/build_command.go +++ /dev/null @@ -1,66 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "os" - "path/filepath" - - "github.com/onsi/ginkgo/ginkgo/interrupthandler" - "github.com/onsi/ginkgo/ginkgo/testrunner" -) - -func BuildBuildCommand() *Command { - commandFlags := NewBuildCommandFlags(flag.NewFlagSet("build", flag.ExitOnError)) - interruptHandler := interrupthandler.NewInterruptHandler() - builder := &SpecBuilder{ - commandFlags: commandFlags, - interruptHandler: interruptHandler, - } - - return &Command{ - Name: "build", - FlagSet: commandFlags.FlagSet, - UsageCommand: "ginkgo build ", - Usage: []string{ - "Build the passed in (or the package in the current directory if left blank).", - "Accepts the following flags:", - }, - Command: builder.BuildSpecs, - } -} - -type SpecBuilder struct { - commandFlags *RunWatchAndBuildCommandFlags - interruptHandler *interrupthandler.InterruptHandler -} - -func (r *SpecBuilder) BuildSpecs(args []string, additionalArgs []string) { - r.commandFlags.computeNodes() - - suites, _ := findSuites(args, r.commandFlags.Recurse, r.commandFlags.SkipPackage, false) - - if len(suites) == 0 { - complainAndQuit("Found no test suites") - } - - passed := true - for _, suite := range suites { - runner := testrunner.New(suite, 1, false, 0, r.commandFlags.GoOpts, nil) - fmt.Printf("Compiling %s...\n", suite.PackageName) - - path, _ := filepath.Abs(filepath.Join(suite.Path, fmt.Sprintf("%s.test", suite.PackageName))) - err := runner.CompileTo(path) - if err != nil { - fmt.Println(err.Error()) - passed = false - } else { - fmt.Printf(" compiled %s.test\n", suite.PackageName) - } - } - - if passed { - os.Exit(0) - } - os.Exit(1) -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/ginkgo_ast_nodes.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/ginkgo_ast_nodes.go deleted file mode 100644 index 02e2b3b328d..00000000000 --- 
a/vendor/github.com/onsi/ginkgo/ginkgo/convert/ginkgo_ast_nodes.go +++ /dev/null @@ -1,123 +0,0 @@ -package convert - -import ( - "fmt" - "go/ast" - "strings" - "unicode" -) - -/* - * Creates a func init() node - */ -func createVarUnderscoreBlock() *ast.ValueSpec { - valueSpec := &ast.ValueSpec{} - object := &ast.Object{Kind: 4, Name: "_", Decl: valueSpec, Data: 0} - ident := &ast.Ident{Name: "_", Obj: object} - valueSpec.Names = append(valueSpec.Names, ident) - return valueSpec -} - -/* - * Creates a Describe("Testing with ginkgo", func() { }) node - */ -func createDescribeBlock() *ast.CallExpr { - blockStatement := &ast.BlockStmt{List: []ast.Stmt{}} - - fieldList := &ast.FieldList{} - funcType := &ast.FuncType{Params: fieldList} - funcLit := &ast.FuncLit{Type: funcType, Body: blockStatement} - basicLit := &ast.BasicLit{Kind: 9, Value: "\"Testing with Ginkgo\""} - describeIdent := &ast.Ident{Name: "Describe"} - return &ast.CallExpr{Fun: describeIdent, Args: []ast.Expr{basicLit, funcLit}} -} - -/* - * Convenience function to return the name of the *testing.T param - * for a Test function that will be rewritten. This is useful because - * we will want to replace the usage of this named *testing.T inside the - * body of the function with a GinktoT. - */ -func namedTestingTArg(node *ast.FuncDecl) string { - return node.Type.Params.List[0].Names[0].Name // *exhale* -} - -/* - * Convenience function to return the block statement node for a Describe statement - */ -func blockStatementFromDescribe(desc *ast.CallExpr) *ast.BlockStmt { - var funcLit *ast.FuncLit - var found = false - - for _, node := range desc.Args { - switch node := node.(type) { - case *ast.FuncLit: - found = true - funcLit = node - break - } - } - - if !found { - panic("Error finding ast.FuncLit inside describe statement. 
Somebody done goofed.") - } - - return funcLit.Body -} - -/* convenience function for creating an It("TestNameHere") - * with all the body of the test function inside the anonymous - * func passed to It() - */ -func createItStatementForTestFunc(testFunc *ast.FuncDecl) *ast.ExprStmt { - blockStatement := &ast.BlockStmt{List: testFunc.Body.List} - fieldList := &ast.FieldList{} - funcType := &ast.FuncType{Params: fieldList} - funcLit := &ast.FuncLit{Type: funcType, Body: blockStatement} - - testName := rewriteTestName(testFunc.Name.Name) - basicLit := &ast.BasicLit{Kind: 9, Value: fmt.Sprintf("\"%s\"", testName)} - itBlockIdent := &ast.Ident{Name: "It"} - callExpr := &ast.CallExpr{Fun: itBlockIdent, Args: []ast.Expr{basicLit, funcLit}} - return &ast.ExprStmt{X: callExpr} -} - -/* -* rewrite test names to be human readable -* eg: rewrites "TestSomethingAmazing" as "something amazing" - */ -func rewriteTestName(testName string) string { - nameComponents := []string{} - currentString := "" - indexOfTest := strings.Index(testName, "Test") - if indexOfTest != 0 { - return testName - } - - testName = strings.Replace(testName, "Test", "", 1) - first, rest := testName[0], testName[1:] - testName = string(unicode.ToLower(rune(first))) + rest - - for _, rune := range testName { - if unicode.IsUpper(rune) { - nameComponents = append(nameComponents, currentString) - currentString = string(unicode.ToLower(rune)) - } else { - currentString += string(rune) - } - } - - return strings.Join(append(nameComponents, currentString), " ") -} - -func newGinkgoTFromIdent(ident *ast.Ident) *ast.CallExpr { - return &ast.CallExpr{ - Lparen: ident.NamePos + 1, - Rparen: ident.NamePos + 2, - Fun: &ast.Ident{Name: "GinkgoT"}, - } -} - -func newGinkgoTInterface() *ast.Ident { - return &ast.Ident{Name: "GinkgoTInterface"} -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/import.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/import.go deleted file mode 100644 index 06c6ec94c90..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/convert/import.go +++ /dev/null @@ -1,90 +0,0 @@ -package convert - -import ( - "fmt" - "go/ast" -) - -/* - * Given the root node of an AST, returns the node containing the - * import statements for the file. - */ -func importsForRootNode(rootNode *ast.File) (imports *ast.GenDecl, err error) { - for _, declaration := range rootNode.Decls { - decl, ok := declaration.(*ast.GenDecl) - if !ok || len(decl.Specs) == 0 { - continue - } - - _, ok = decl.Specs[0].(*ast.ImportSpec) - if ok { - imports = decl - return - } - } - - err = fmt.Errorf("Could not find imports for root node:\n\t%#v\n", rootNode) - return -} - -/* - * Removes "testing" import, if present - */ -func removeTestingImport(rootNode *ast.File) { - importDecl, err := importsForRootNode(rootNode) - if err != nil { - panic(err.Error()) - } - - var index int - for i, importSpec := range importDecl.Specs { - importSpec := importSpec.(*ast.ImportSpec) - if importSpec.Path.Value == "\"testing\"" { - index = i - break - } - } - - importDecl.Specs = append(importDecl.Specs[:index], importDecl.Specs[index+1:]...) 
-} - -/* - * Adds import statements for onsi/ginkgo, if missing - */ -func addGinkgoImports(rootNode *ast.File) { - importDecl, err := importsForRootNode(rootNode) - if err != nil { - panic(err.Error()) - } - - if len(importDecl.Specs) == 0 { - // TODO: might need to create a import decl here - panic("unimplemented : expected to find an imports block") - } - - needsGinkgo := true - for _, importSpec := range importDecl.Specs { - importSpec, ok := importSpec.(*ast.ImportSpec) - if !ok { - continue - } - - if importSpec.Path.Value == "\"github.com/onsi/ginkgo\"" { - needsGinkgo = false - } - } - - if needsGinkgo { - importDecl.Specs = append(importDecl.Specs, createImport(".", "\"github.com/onsi/ginkgo\"")) - } -} - -/* - * convenience function to create an import statement - */ -func createImport(name, path string) *ast.ImportSpec { - return &ast.ImportSpec{ - Name: &ast.Ident{Name: name}, - Path: &ast.BasicLit{Kind: 9, Value: path}, - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/package_rewriter.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/package_rewriter.go deleted file mode 100644 index 363e52fe2f2..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/convert/package_rewriter.go +++ /dev/null @@ -1,128 +0,0 @@ -package convert - -import ( - "fmt" - "go/build" - "io/ioutil" - "os" - "os/exec" - "path/filepath" - "regexp" -) - -/* - * RewritePackage takes a name (eg: my-package/tools), finds its test files using - * Go's build package, and then rewrites them. A ginkgo test suite file will - * also be added for this package, and all of its child packages. - */ -func RewritePackage(packageName string) { - pkg, err := packageWithName(packageName) - if err != nil { - panic(fmt.Sprintf("unexpected error reading package: '%s'\n%s\n", packageName, err.Error())) - } - - for _, filename := range findTestsInPackage(pkg) { - rewriteTestsInFile(filename) - } -} - -/* - * Given a package, findTestsInPackage reads the test files in the directory, - * and then recurses on each child package, returning a slice of all test files - * found in this process. - */ -func findTestsInPackage(pkg *build.Package) (testfiles []string) { - for _, file := range append(pkg.TestGoFiles, pkg.XTestGoFiles...) { - testfile, _ := filepath.Abs(filepath.Join(pkg.Dir, file)) - testfiles = append(testfiles, testfile) - } - - dirFiles, err := ioutil.ReadDir(pkg.Dir) - if err != nil { - panic(fmt.Sprintf("unexpected error reading dir: '%s'\n%s\n", pkg.Dir, err.Error())) - } - - re := regexp.MustCompile(`^[._]`) - - for _, file := range dirFiles { - if !file.IsDir() { - continue - } - - if re.Match([]byte(file.Name())) { - continue - } - - packageName := filepath.Join(pkg.ImportPath, file.Name()) - subPackage, err := packageWithName(packageName) - if err != nil { - panic(fmt.Sprintf("unexpected error reading package: '%s'\n%s\n", packageName, err.Error())) - } - - testfiles = append(testfiles, findTestsInPackage(subPackage)...) 
- } - - addGinkgoSuiteForPackage(pkg) - goFmtPackage(pkg) - return -} - -/* - * Shells out to `ginkgo bootstrap` to create a test suite file - */ -func addGinkgoSuiteForPackage(pkg *build.Package) { - originalDir, err := os.Getwd() - if err != nil { - panic(err) - } - - suite_test_file := filepath.Join(pkg.Dir, pkg.Name+"_suite_test.go") - - _, err = os.Stat(suite_test_file) - if err == nil { - return // test file already exists, this should be a no-op - } - - err = os.Chdir(pkg.Dir) - if err != nil { - panic(err) - } - - output, err := exec.Command("ginkgo", "bootstrap").Output() - - if err != nil { - panic(fmt.Sprintf("error running 'ginkgo bootstrap'.\nstdout: %s\n%s\n", output, err.Error())) - } - - err = os.Chdir(originalDir) - if err != nil { - panic(err) - } -} - -/* - * Shells out to `go fmt` to format the package - */ -func goFmtPackage(pkg *build.Package) { - path, _ := filepath.Abs(pkg.ImportPath) - output, err := exec.Command("go", "fmt", path).CombinedOutput() - - if err != nil { - fmt.Printf("Warning: Error running 'go fmt %s'.\nstdout: %s\n%s\n", path, output, err.Error()) - } -} - -/* - * Attempts to return a package with its test files already read. - * The ImportMode arg to build.Import lets you specify if you want go to read the - * buildable go files inside the package, but it fails if the package has no go files - */ -func packageWithName(name string) (pkg *build.Package, err error) { - pkg, err = build.Default.Import(name, ".", build.ImportMode(0)) - if err == nil { - return - } - - pkg, err = build.Default.Import(name, ".", build.ImportMode(1)) - return -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/test_finder.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/test_finder.go deleted file mode 100644 index b33595c9ae1..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/convert/test_finder.go +++ /dev/null @@ -1,56 +0,0 @@ -package convert - -import ( - "go/ast" - "regexp" -) - -/* - * Given a root node, walks its top level statements and returns - * points to function nodes to rewrite as It statements. - * These functions, according to Go testing convention, must be named - * TestWithCamelCasedName and receive a single *testing.T argument. 
- */ -func findTestFuncs(rootNode *ast.File) (testsToRewrite []*ast.FuncDecl) { - testNameRegexp := regexp.MustCompile("^Test[0-9A-Z].+") - - ast.Inspect(rootNode, func(node ast.Node) bool { - if node == nil { - return false - } - - switch node := node.(type) { - case *ast.FuncDecl: - matches := testNameRegexp.MatchString(node.Name.Name) - - if matches && receivesTestingT(node) { - testsToRewrite = append(testsToRewrite, node) - } - } - - return true - }) - - return -} - -/* - * convenience function that looks at args to a function and determines if its - * params include an argument of type *testing.T - */ -func receivesTestingT(node *ast.FuncDecl) bool { - if len(node.Type.Params.List) != 1 { - return false - } - - base, ok := node.Type.Params.List[0].Type.(*ast.StarExpr) - if !ok { - return false - } - - intermediate := base.X.(*ast.SelectorExpr) - isTestingPackage := intermediate.X.(*ast.Ident).Name == "testing" - isTestingT := intermediate.Sel.Name == "T" - - return isTestingPackage && isTestingT -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/testfile_rewriter.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/testfile_rewriter.go deleted file mode 100644 index 60c73504ad9..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/convert/testfile_rewriter.go +++ /dev/null @@ -1,162 +0,0 @@ -package convert - -import ( - "bytes" - "fmt" - "go/ast" - "go/format" - "go/parser" - "go/token" - "io/ioutil" - "os" -) - -/* - * Given a file path, rewrites any tests in the Ginkgo format. - * First, we parse the AST, and update the imports declaration. - * Then, we walk the first child elements in the file, returning tests to rewrite. - * A top level init func is declared, with a single Describe func inside. - * Then the test functions to rewrite are inserted as It statements inside the Describe. - * Finally we walk the rest of the file, replacing other usages of *testing.T - * Once that is complete, we write the AST back out again to its file. - */ -func rewriteTestsInFile(pathToFile string) { - fileSet := token.NewFileSet() - rootNode, err := parser.ParseFile(fileSet, pathToFile, nil, parser.ParseComments) - if err != nil { - panic(fmt.Sprintf("Error parsing test file '%s':\n%s\n", pathToFile, err.Error())) - } - - addGinkgoImports(rootNode) - removeTestingImport(rootNode) - - varUnderscoreBlock := createVarUnderscoreBlock() - describeBlock := createDescribeBlock() - varUnderscoreBlock.Values = []ast.Expr{describeBlock} - - for _, testFunc := range findTestFuncs(rootNode) { - rewriteTestFuncAsItStatement(testFunc, rootNode, describeBlock) - } - - underscoreDecl := &ast.GenDecl{ - Tok: 85, // gah, magick numbers are needed to make this work - TokPos: 14, // this tricks Go into writing "var _ = Describe" - Specs: []ast.Spec{varUnderscoreBlock}, - } - - imports := rootNode.Decls[0] - tail := rootNode.Decls[1:] - rootNode.Decls = append(append([]ast.Decl{imports}, underscoreDecl), tail...) 
- rewriteOtherFuncsToUseGinkgoT(rootNode.Decls) - walkNodesInRootNodeReplacingTestingT(rootNode) - - var buffer bytes.Buffer - if err = format.Node(&buffer, fileSet, rootNode); err != nil { - panic(fmt.Sprintf("Error formatting ast node after rewriting tests.\n%s\n", err.Error())) - } - - fileInfo, err := os.Stat(pathToFile) - - if err != nil { - panic(fmt.Sprintf("Error stat'ing file: %s\n", pathToFile)) - } - - err = ioutil.WriteFile(pathToFile, buffer.Bytes(), fileInfo.Mode()) -} - -/* - * Given a test func named TestDoesSomethingNeat, rewrites it as - * It("does something neat", func() { __test_body_here__ }) and adds it - * to the Describe's list of statements - */ -func rewriteTestFuncAsItStatement(testFunc *ast.FuncDecl, rootNode *ast.File, describe *ast.CallExpr) { - var funcIndex int = -1 - for index, child := range rootNode.Decls { - if child == testFunc { - funcIndex = index - break - } - } - - if funcIndex < 0 { - panic(fmt.Sprintf("Assert failed: Error finding index for test node %s\n", testFunc.Name.Name)) - } - - var block *ast.BlockStmt = blockStatementFromDescribe(describe) - block.List = append(block.List, createItStatementForTestFunc(testFunc)) - replaceTestingTsWithGinkgoT(block, namedTestingTArg(testFunc)) - - // remove the old test func from the root node's declarations - rootNode.Decls = append(rootNode.Decls[:funcIndex], rootNode.Decls[funcIndex+1:]...) -} - -/* - * walks nodes inside of a test func's statements and replaces the usage of - * it's named *testing.T param with GinkgoT's - */ -func replaceTestingTsWithGinkgoT(statementsBlock *ast.BlockStmt, testingT string) { - ast.Inspect(statementsBlock, func(node ast.Node) bool { - if node == nil { - return false - } - - keyValueExpr, ok := node.(*ast.KeyValueExpr) - if ok { - replaceNamedTestingTsInKeyValueExpression(keyValueExpr, testingT) - return true - } - - funcLiteral, ok := node.(*ast.FuncLit) - if ok { - replaceTypeDeclTestingTsInFuncLiteral(funcLiteral) - return true - } - - callExpr, ok := node.(*ast.CallExpr) - if !ok { - return true - } - replaceTestingTsInArgsLists(callExpr, testingT) - - funCall, ok := callExpr.Fun.(*ast.SelectorExpr) - if ok { - replaceTestingTsMethodCalls(funCall, testingT) - } - - return true - }) -} - -/* - * rewrite t.Fail() or any other *testing.T method by replacing with T().Fail() - * This function receives a selector expression (eg: t.Fail()) and - * the name of the *testing.T param from the function declaration. 
Rewrites the - * selector expression in place if the target was a *testing.T - */ -func replaceTestingTsMethodCalls(selectorExpr *ast.SelectorExpr, testingT string) { - ident, ok := selectorExpr.X.(*ast.Ident) - if !ok { - return - } - - if ident.Name == testingT { - selectorExpr.X = newGinkgoTFromIdent(ident) - } -} - -/* - * replaces usages of a named *testing.T param inside of a call expression - * with a new GinkgoT object - */ -func replaceTestingTsInArgsLists(callExpr *ast.CallExpr, testingT string) { - for index, arg := range callExpr.Args { - ident, ok := arg.(*ast.Ident) - if !ok { - continue - } - - if ident.Name == testingT { - callExpr.Args[index] = newGinkgoTFromIdent(ident) - } - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert/testing_t_rewriter.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert/testing_t_rewriter.go deleted file mode 100644 index 418cdc4e563..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/convert/testing_t_rewriter.go +++ /dev/null @@ -1,130 +0,0 @@ -package convert - -import ( - "go/ast" -) - -/* - * Rewrites any other top level funcs that receive a *testing.T param - */ -func rewriteOtherFuncsToUseGinkgoT(declarations []ast.Decl) { - for _, decl := range declarations { - decl, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - - for _, param := range decl.Type.Params.List { - starExpr, ok := param.Type.(*ast.StarExpr) - if !ok { - continue - } - - selectorExpr, ok := starExpr.X.(*ast.SelectorExpr) - if !ok { - continue - } - - xIdent, ok := selectorExpr.X.(*ast.Ident) - if !ok || xIdent.Name != "testing" { - continue - } - - if selectorExpr.Sel.Name != "T" { - continue - } - - param.Type = newGinkgoTInterface() - } - } -} - -/* - * Walks all of the nodes in the file, replacing *testing.T in struct - * and func literal nodes. 
eg: - * type foo struct { *testing.T } - * var bar = func(t *testing.T) { } - */ -func walkNodesInRootNodeReplacingTestingT(rootNode *ast.File) { - ast.Inspect(rootNode, func(node ast.Node) bool { - if node == nil { - return false - } - - switch node := node.(type) { - case *ast.StructType: - replaceTestingTsInStructType(node) - case *ast.FuncLit: - replaceTypeDeclTestingTsInFuncLiteral(node) - } - - return true - }) -} - -/* - * replaces named *testing.T inside a composite literal - */ -func replaceNamedTestingTsInKeyValueExpression(kve *ast.KeyValueExpr, testingT string) { - ident, ok := kve.Value.(*ast.Ident) - if !ok { - return - } - - if ident.Name == testingT { - kve.Value = newGinkgoTFromIdent(ident) - } -} - -/* - * replaces *testing.T params in a func literal with GinkgoT - */ -func replaceTypeDeclTestingTsInFuncLiteral(functionLiteral *ast.FuncLit) { - for _, arg := range functionLiteral.Type.Params.List { - starExpr, ok := arg.Type.(*ast.StarExpr) - if !ok { - continue - } - - selectorExpr, ok := starExpr.X.(*ast.SelectorExpr) - if !ok { - continue - } - - target, ok := selectorExpr.X.(*ast.Ident) - if !ok { - continue - } - - if target.Name == "testing" && selectorExpr.Sel.Name == "T" { - arg.Type = newGinkgoTInterface() - } - } -} - -/* - * Replaces *testing.T types inside of a struct declaration with a GinkgoT - * eg: type foo struct { *testing.T } - */ -func replaceTestingTsInStructType(structType *ast.StructType) { - for _, field := range structType.Fields.List { - starExpr, ok := field.Type.(*ast.StarExpr) - if !ok { - continue - } - - selectorExpr, ok := starExpr.X.(*ast.SelectorExpr) - if !ok { - continue - } - - xIdent, ok := selectorExpr.X.(*ast.Ident) - if !ok { - continue - } - - if xIdent.Name == "testing" && selectorExpr.Sel.Name == "T" { - field.Type = newGinkgoTInterface() - } - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/convert_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/convert_command.go deleted file mode 100644 index 8e99f56a23c..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/convert_command.go +++ /dev/null @@ -1,51 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "os" - - "github.com/onsi/ginkgo/ginkgo/convert" - colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" - "github.com/onsi/ginkgo/types" -) - -func BuildConvertCommand() *Command { - return &Command{ - Name: "convert", - FlagSet: flag.NewFlagSet("convert", flag.ExitOnError), - UsageCommand: "ginkgo convert /path/to/package", - Usage: []string{ - "Convert the package at the passed in path from an XUnit-style test to a Ginkgo-style test", - }, - Command: convertPackage, - } -} - -func convertPackage(args []string, additionalArgs []string) { - deprecationTracker := types.NewDeprecationTracker() - deprecationTracker.TrackDeprecation(types.Deprecations.Convert()) - fmt.Fprintln(colorable.NewColorableStderr(), deprecationTracker.DeprecationsReport()) - - if len(args) != 1 { - println(fmt.Sprintf("usage: ginkgo convert /path/to/your/package")) - os.Exit(1) - } - - defer func() { - err := recover() - if err != nil { - switch err := err.(type) { - case error: - println(err.Error()) - case string: - println(err) - default: - println(fmt.Sprintf("unexpected error: %#v", err)) - } - os.Exit(1) - } - }() - - convert.RewritePackage(args[0]) -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/help_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/help_command.go deleted file mode 100644 index db3f40406dc..00000000000 --- 
a/vendor/github.com/onsi/ginkgo/ginkgo/help_command.go +++ /dev/null @@ -1,33 +0,0 @@ -package main - -import ( - "flag" - "fmt" -) - -func BuildHelpCommand() *Command { - return &Command{ - Name: "help", - FlagSet: flag.NewFlagSet("help", flag.ExitOnError), - UsageCommand: "ginkgo help ", - Usage: []string{ - "Print usage information. If a command is passed in, print usage information just for that command.", - }, - Command: printHelp, - } -} - -func printHelp(args []string, additionalArgs []string) { - if len(args) == 0 { - usage() - emitRCAdvertisement() - } else { - command, found := commandMatching(args[0]) - if !found { - complainAndQuit(fmt.Sprintf("Unknown command: %s", args[0])) - } - - usageForCommand(command, true) - emitRCAdvertisement() - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/interrupt_handler.go b/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/interrupt_handler.go deleted file mode 100644 index ec456bf2922..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/interrupt_handler.go +++ /dev/null @@ -1,52 +0,0 @@ -package interrupthandler - -import ( - "os" - "os/signal" - "sync" - "syscall" -) - -type InterruptHandler struct { - interruptCount int - lock *sync.Mutex - C chan bool -} - -func NewInterruptHandler() *InterruptHandler { - h := &InterruptHandler{ - lock: &sync.Mutex{}, - C: make(chan bool), - } - - go h.handleInterrupt() - SwallowSigQuit() - - return h -} - -func (h *InterruptHandler) WasInterrupted() bool { - h.lock.Lock() - defer h.lock.Unlock() - - return h.interruptCount > 0 -} - -func (h *InterruptHandler) handleInterrupt() { - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - - <-c - signal.Stop(c) - - h.lock.Lock() - h.interruptCount++ - if h.interruptCount == 1 { - close(h.C) - } else if h.interruptCount > 5 { - os.Exit(1) - } - h.lock.Unlock() - - go h.handleInterrupt() -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/main.go b/vendor/github.com/onsi/ginkgo/ginkgo/main.go deleted file mode 100644 index ae0e1daf61e..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/main.go +++ /dev/null @@ -1,337 +0,0 @@ -/* -The Ginkgo CLI - -The Ginkgo CLI is fully documented [here](http://onsi.github.io/ginkgo/#the_ginkgo_cli) - -You can also learn more by running: - - ginkgo help - -Here are some of the more commonly used commands: - -To install: - - go install github.com/onsi/ginkgo/ginkgo - -To run tests: - - ginkgo - -To run tests in all subdirectories: - - ginkgo -r - -To run tests in particular packages: - - ginkgo /path/to/package /path/to/another/package - -To pass arguments/flags to your tests: - - ginkgo -- - -To run tests in parallel - - ginkgo -p - -this will automatically detect the optimal number of nodes to use. Alternatively, you can specify the number of nodes with: - - ginkgo -nodes=N - -(note that you don't need to provide -p in this case). - -By default the Ginkgo CLI will spin up a server that the individual test processes send test output to. The CLI aggregates this output and then presents coherent test output, one test at a time, as each test completes. -An alternative is to have the parallel nodes run and stream interleaved output back. This useful for debugging, particularly in contexts where tests hang/fail to start. To get this interleaved output: - - ginkgo -nodes=N -stream=true - -On windows, the default value for stream is true. - -By default, when running multiple tests (with -r or a list of packages) Ginkgo will abort when a test fails. 
To have Ginkgo run subsequent test suites instead you can: - - ginkgo -keepGoing - -To fail if there are ginkgo tests in a directory but no test suite (missing `RunSpecs`) - - ginkgo -requireSuite - -To monitor packages and rerun tests when changes occur: - - ginkgo watch <-r> - -passing `ginkgo watch` the `-r` flag will recursively detect all test suites under the current directory and monitor them. -`watch` does not detect *new* packages. Moreover, changes in package X only rerun the tests for package X, tests for packages -that depend on X are not rerun. - -[OSX & Linux only] To receive (desktop) notifications when a test run completes: - - ginkgo -notify - -this is particularly useful with `ginkgo watch`. Notifications are currently only supported on OS X and require that you `brew install terminal-notifier` - -Sometimes (to suss out race conditions/flakey tests, for example) you want to keep running a test suite until it fails. You can do this with: - - ginkgo -untilItFails - -To bootstrap a test suite: - - ginkgo bootstrap - -To generate a test file: - - ginkgo generate - -To bootstrap/generate test files without using "." imports: - - ginkgo bootstrap --nodot - ginkgo generate --nodot - -this will explicitly export all the identifiers in Ginkgo and Gomega allowing you to rename them to avoid collisions. When you pull to the latest Ginkgo/Gomega you'll want to run - - ginkgo nodot - -to refresh this list and pull in any new identifiers. In particular, this will pull in any new Gomega matchers that get added. - -To convert an existing XUnit style test suite to a Ginkgo-style test suite: - - ginkgo convert . - -To unfocus tests: - - ginkgo unfocus - -or - - ginkgo blur - -To compile a test suite: - - ginkgo build - -will output an executable file named `package.test`. 
This can be run directly or by invoking - - ginkgo - - -To print an outline of Ginkgo specs and containers in a file: - - gingko outline - -To print out Ginkgo's version: - - ginkgo version - -To get more help: - - ginkgo help -*/ -package main - -import ( - "flag" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/formatter" - "github.com/onsi/ginkgo/ginkgo/testsuite" -) - -const greenColor = "\x1b[32m" -const redColor = "\x1b[91m" -const defaultStyle = "\x1b[0m" -const lightGrayColor = "\x1b[37m" - -type Command struct { - Name string - AltName string - FlagSet *flag.FlagSet - Usage []string - UsageCommand string - Command func(args []string, additionalArgs []string) - SuppressFlagDocumentation bool - FlagDocSubstitute []string -} - -func (c *Command) Matches(name string) bool { - return c.Name == name || (c.AltName != "" && c.AltName == name) -} - -func (c *Command) Run(args []string, additionalArgs []string) { - c.FlagSet.Usage = usage - c.FlagSet.Parse(args) - c.Command(c.FlagSet.Args(), additionalArgs) -} - -var DefaultCommand *Command -var Commands []*Command - -func init() { - DefaultCommand = BuildRunCommand() - Commands = append(Commands, BuildWatchCommand()) - Commands = append(Commands, BuildBuildCommand()) - Commands = append(Commands, BuildBootstrapCommand()) - Commands = append(Commands, BuildGenerateCommand()) - Commands = append(Commands, BuildNodotCommand()) - Commands = append(Commands, BuildConvertCommand()) - Commands = append(Commands, BuildUnfocusCommand()) - Commands = append(Commands, BuildVersionCommand()) - Commands = append(Commands, BuildHelpCommand()) - Commands = append(Commands, BuildOutlineCommand()) -} - -func main() { - args := []string{} - additionalArgs := []string{} - - foundDelimiter := false - - for _, arg := range os.Args[1:] { - if !foundDelimiter { - if arg == "--" { - foundDelimiter = true - continue - } - } - - if foundDelimiter { - additionalArgs = append(additionalArgs, arg) - } else { - args = append(args, arg) - } - } - - if len(args) > 0 { - commandToRun, found := commandMatching(args[0]) - if found { - commandToRun.Run(args[1:], additionalArgs) - return - } - } - - DefaultCommand.Run(args, additionalArgs) -} - -func commandMatching(name string) (*Command, bool) { - for _, command := range Commands { - if command.Matches(name) { - return command, true - } - } - return nil, false -} - -func usage() { - fmt.Printf("Ginkgo Version %s\n\n", config.VERSION) - usageForCommand(DefaultCommand, false) - for _, command := range Commands { - fmt.Printf("\n") - usageForCommand(command, false) - } -} - -func usageForCommand(command *Command, longForm bool) { - fmt.Printf("%s\n%s\n", command.UsageCommand, strings.Repeat("-", len(command.UsageCommand))) - fmt.Printf("%s\n", strings.Join(command.Usage, "\n")) - if command.SuppressFlagDocumentation && !longForm { - fmt.Printf("%s\n", strings.Join(command.FlagDocSubstitute, "\n ")) - } else { - command.FlagSet.SetOutput(os.Stdout) - command.FlagSet.PrintDefaults() - } -} - -func complainAndQuit(complaint string) { - fmt.Fprintf(os.Stderr, "%s\nFor usage instructions:\n\tginkgo help\n", complaint) - emitRCAdvertisement() - os.Exit(1) -} - -func findSuites(args []string, recurseForAll bool, skipPackage string, allowPrecompiled bool) ([]testsuite.TestSuite, []string) { - suites := []testsuite.TestSuite{} - - if len(args) > 0 { - for _, arg := range args { - if allowPrecompiled { - suite, err := testsuite.PrecompiledTestSuite(arg) - if err == 
nil { - suites = append(suites, suite) - continue - } - } - recurseForSuite := recurseForAll - if strings.HasSuffix(arg, "/...") && arg != "/..." { - arg = arg[:len(arg)-4] - recurseForSuite = true - } - suites = append(suites, testsuite.SuitesInDir(arg, recurseForSuite)...) - } - } else { - suites = testsuite.SuitesInDir(".", recurseForAll) - } - - skippedPackages := []string{} - if skipPackage != "" { - skipFilters := strings.Split(skipPackage, ",") - filteredSuites := []testsuite.TestSuite{} - for _, suite := range suites { - skip := false - for _, skipFilter := range skipFilters { - if strings.Contains(suite.Path, skipFilter) { - skip = true - break - } - } - if skip { - skippedPackages = append(skippedPackages, suite.Path) - } else { - filteredSuites = append(filteredSuites, suite) - } - } - suites = filteredSuites - } - - return suites, skippedPackages -} - -func goFmt(path string) { - out, err := exec.Command("go", "fmt", path).CombinedOutput() - if err != nil { - complainAndQuit("Could not fmt: " + err.Error() + "\n" + string(out)) - } -} - -func pluralizedWord(singular, plural string, count int) string { - if count == 1 { - return singular - } - return plural -} - -func emitRCAdvertisement() { - ackRC := os.Getenv("ACK_GINKGO_RC") - if ackRC != "" { - return - } - home, err := os.UserHomeDir() - if err == nil { - _, err := os.Stat(filepath.Join(home, ".ack-ginkgo-rc")) - if err == nil { - return - } - } - - out := formatter.F("\n{{light-yellow}}Ginkgo 2.0 is coming soon!{{/}}\n") - out += formatter.F("{{light-yellow}}=========================={{/}}\n") - out += formatter.F("{{bold}}{{green}}Ginkgo 2.0{{/}} is under active development and will introduce several new features, improvements, and a small handful of breaking changes.\n") - out += formatter.F("A release candidate for 2.0 is now available and 2.0 should GA in Fall 2021. 
{{bold}}Please give the RC a try and send us feedback!{{/}}\n") - out += formatter.F(" - To learn more, view the migration guide at {{cyan}}{{underline}}https://github.com/onsi/ginkgo/blob/ver2/docs/MIGRATING_TO_V2.md{{/}}\n") - out += formatter.F(" - For instructions on using the Release Candidate visit {{cyan}}{{underline}}https://github.com/onsi/ginkgo/blob/ver2/docs/MIGRATING_TO_V2.md#using-the-beta{{/}}\n") - out += formatter.F(" - To comment, chime in at {{cyan}}{{underline}}https://github.com/onsi/ginkgo/issues/711{{/}}\n\n") - out += formatter.F("To {{bold}}{{coral}}silence this notice{{/}}, set the environment variable: {{bold}}ACK_GINKGO_RC=true{{/}}\n") - out += formatter.F("Alternatively you can: {{bold}}touch $HOME/.ack-ginkgo-rc{{/}}") - - fmt.Println(out) -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/nodot/nodot.go b/vendor/github.com/onsi/ginkgo/ginkgo/nodot/nodot.go deleted file mode 100644 index c87b721652f..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/nodot/nodot.go +++ /dev/null @@ -1,196 +0,0 @@ -package nodot - -import ( - "fmt" - "go/ast" - "go/build" - "go/parser" - "go/token" - "path/filepath" - "strings" -) - -func ApplyNoDot(data []byte) ([]byte, error) { - sections, err := generateNodotSections() - if err != nil { - return nil, err - } - - for _, section := range sections { - data = section.createOrUpdateIn(data) - } - - return data, nil -} - -type nodotSection struct { - name string - pkg string - declarations []string - types []string -} - -func (s nodotSection) createOrUpdateIn(data []byte) []byte { - renames := map[string]string{} - - contents := string(data) - - lines := strings.Split(contents, "\n") - - comment := "// Declarations for " + s.name - - newLines := []string{} - for _, line := range lines { - if line == comment { - continue - } - - words := strings.Split(line, " ") - lastWord := words[len(words)-1] - - if s.containsDeclarationOrType(lastWord) { - renames[lastWord] = words[1] - continue - } - - newLines = append(newLines, line) - } - - if len(newLines[len(newLines)-1]) > 0 { - newLines = append(newLines, "") - } - - newLines = append(newLines, comment) - - for _, typ := range s.types { - name, ok := renames[s.prefix(typ)] - if !ok { - name = typ - } - newLines = append(newLines, fmt.Sprintf("type %s %s", name, s.prefix(typ))) - } - - for _, decl := range s.declarations { - name, ok := renames[s.prefix(decl)] - if !ok { - name = decl - } - newLines = append(newLines, fmt.Sprintf("var %s = %s", name, s.prefix(decl))) - } - - newLines = append(newLines, "") - - newContents := strings.Join(newLines, "\n") - - return []byte(newContents) -} - -func (s nodotSection) prefix(declOrType string) string { - return s.pkg + "." 
+ declOrType -} - -func (s nodotSection) containsDeclarationOrType(word string) bool { - for _, declaration := range s.declarations { - if s.prefix(declaration) == word { - return true - } - } - - for _, typ := range s.types { - if s.prefix(typ) == word { - return true - } - } - - return false -} - -func generateNodotSections() ([]nodotSection, error) { - sections := []nodotSection{} - - declarations, err := getExportedDeclerationsForPackage("github.com/onsi/ginkgo", "ginkgo_dsl.go", "GINKGO_VERSION", "GINKGO_PANIC") - if err != nil { - return nil, err - } - sections = append(sections, nodotSection{ - name: "Ginkgo DSL", - pkg: "ginkgo", - declarations: declarations, - types: []string{"Done", "Benchmarker"}, - }) - - declarations, err = getExportedDeclerationsForPackage("github.com/onsi/gomega", "gomega_dsl.go", "GOMEGA_VERSION") - if err != nil { - return nil, err - } - sections = append(sections, nodotSection{ - name: "Gomega DSL", - pkg: "gomega", - declarations: declarations, - }) - - declarations, err = getExportedDeclerationsForPackage("github.com/onsi/gomega", "matchers.go") - if err != nil { - return nil, err - } - sections = append(sections, nodotSection{ - name: "Gomega Matchers", - pkg: "gomega", - declarations: declarations, - }) - - return sections, nil -} - -func getExportedDeclerationsForPackage(pkgPath string, filename string, blacklist ...string) ([]string, error) { - pkg, err := build.Import(pkgPath, ".", 0) - if err != nil { - return []string{}, err - } - - declarations, err := getExportedDeclarationsForFile(filepath.Join(pkg.Dir, filename)) - if err != nil { - return []string{}, err - } - - blacklistLookup := map[string]bool{} - for _, declaration := range blacklist { - blacklistLookup[declaration] = true - } - - filteredDeclarations := []string{} - for _, declaration := range declarations { - if blacklistLookup[declaration] { - continue - } - filteredDeclarations = append(filteredDeclarations, declaration) - } - - return filteredDeclarations, nil -} - -func getExportedDeclarationsForFile(path string) ([]string, error) { - fset := token.NewFileSet() - tree, err := parser.ParseFile(fset, path, nil, 0) - if err != nil { - return []string{}, err - } - - declarations := []string{} - ast.FileExports(tree) - for _, decl := range tree.Decls { - switch x := decl.(type) { - case *ast.GenDecl: - switch s := x.Specs[0].(type) { - case *ast.ValueSpec: - declarations = append(declarations, s.Names[0].Name) - } - case *ast.FuncDecl: - if x.Recv == nil { - declarations = append(declarations, x.Name.Name) - } - } - } - - return declarations, nil -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/nodot_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/nodot_command.go deleted file mode 100644 index 39b88b5d1bf..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/nodot_command.go +++ /dev/null @@ -1,77 +0,0 @@ -package main - -import ( - "bufio" - "flag" - "io/ioutil" - "os" - "path/filepath" - "regexp" - - "github.com/onsi/ginkgo/ginkgo/nodot" -) - -func BuildNodotCommand() *Command { - return &Command{ - Name: "nodot", - FlagSet: flag.NewFlagSet("bootstrap", flag.ExitOnError), - UsageCommand: "ginkgo nodot", - Usage: []string{ - "Update the nodot declarations in your test suite", - "Any missing declarations (from, say, a recently added matcher) will be added to your bootstrap file.", - "If you've renamed a declaration, that name will be honored and not overwritten.", - }, - Command: updateNodot, - } -} - -func updateNodot(args []string, additionalArgs []string) { - suiteFile, perm 
:= findSuiteFile() - - data, err := ioutil.ReadFile(suiteFile) - if err != nil { - complainAndQuit("Failed to update nodot declarations: " + err.Error()) - } - - content, err := nodot.ApplyNoDot(data) - if err != nil { - complainAndQuit("Failed to update nodot declarations: " + err.Error()) - } - ioutil.WriteFile(suiteFile, content, perm) - - goFmt(suiteFile) -} - -func findSuiteFile() (string, os.FileMode) { - workingDir, err := os.Getwd() - if err != nil { - complainAndQuit("Could not find suite file for nodot: " + err.Error()) - } - - files, err := ioutil.ReadDir(workingDir) - if err != nil { - complainAndQuit("Could not find suite file for nodot: " + err.Error()) - } - - re := regexp.MustCompile(`RunSpecs\(|RunSpecsWithDefaultAndCustomReporters\(|RunSpecsWithCustomReporters\(`) - - for _, file := range files { - if file.IsDir() { - continue - } - path := filepath.Join(workingDir, file.Name()) - f, err := os.Open(path) - if err != nil { - complainAndQuit("Could not find suite file for nodot: " + err.Error()) - } - defer f.Close() - - if re.MatchReader(bufio.NewReader(f)) { - return path, file.Mode() - } - } - - complainAndQuit("Could not find a suite file for nodot: you need a bootstrap file that call's Ginkgo's RunSpecs() command.\nTry running ginkgo bootstrap first.") - - return "", 0 -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/notifications.go b/vendor/github.com/onsi/ginkgo/ginkgo/notifications.go deleted file mode 100644 index 368d61fb31c..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/notifications.go +++ /dev/null @@ -1,141 +0,0 @@ -package main - -import ( - "fmt" - "os" - "os/exec" - "regexp" - "runtime" - "strings" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/ginkgo/testsuite" -) - -type Notifier struct { - commandFlags *RunWatchAndBuildCommandFlags -} - -func NewNotifier(commandFlags *RunWatchAndBuildCommandFlags) *Notifier { - return &Notifier{ - commandFlags: commandFlags, - } -} - -func (n *Notifier) VerifyNotificationsAreAvailable() { - if n.commandFlags.Notify { - onLinux := (runtime.GOOS == "linux") - onOSX := (runtime.GOOS == "darwin") - if onOSX { - - _, err := exec.LookPath("terminal-notifier") - if err != nil { - fmt.Printf(`--notify requires terminal-notifier, which you don't seem to have installed. - -OSX: - -To remedy this: - - brew install terminal-notifier - -To learn more about terminal-notifier: - - https://github.com/alloy/terminal-notifier -`) - os.Exit(1) - } - - } else if onLinux { - - _, err := exec.LookPath("notify-send") - if err != nil { - fmt.Printf(`--notify requires terminal-notifier or notify-send, which you don't seem to have installed. 
- -Linux: - -Download and install notify-send for your distribution -`) - os.Exit(1) - } - - } - } -} - -func (n *Notifier) SendSuiteCompletionNotification(suite testsuite.TestSuite, suitePassed bool) { - if suitePassed { - n.SendNotification("Ginkgo [PASS]", fmt.Sprintf(`Test suite for "%s" passed.`, suite.PackageName)) - } else { - n.SendNotification("Ginkgo [FAIL]", fmt.Sprintf(`Test suite for "%s" failed.`, suite.PackageName)) - } -} - -func (n *Notifier) SendNotification(title string, subtitle string) { - - if n.commandFlags.Notify { - onLinux := (runtime.GOOS == "linux") - onOSX := (runtime.GOOS == "darwin") - - if onOSX { - - _, err := exec.LookPath("terminal-notifier") - if err == nil { - args := []string{"-title", title, "-subtitle", subtitle, "-group", "com.onsi.ginkgo"} - terminal := os.Getenv("TERM_PROGRAM") - if terminal == "iTerm.app" { - args = append(args, "-activate", "com.googlecode.iterm2") - } else if terminal == "Apple_Terminal" { - args = append(args, "-activate", "com.apple.Terminal") - } - - exec.Command("terminal-notifier", args...).Run() - } - - } else if onLinux { - - _, err := exec.LookPath("notify-send") - if err == nil { - args := []string{"-a", "ginkgo", title, subtitle} - exec.Command("notify-send", args...).Run() - } - - } - } -} - -func (n *Notifier) RunCommand(suite testsuite.TestSuite, suitePassed bool) { - - command := n.commandFlags.AfterSuiteHook - if command != "" { - - // Allow for string replacement to pass input to the command - passed := "[FAIL]" - if suitePassed { - passed = "[PASS]" - } - command = strings.Replace(command, "(ginkgo-suite-passed)", passed, -1) - command = strings.Replace(command, "(ginkgo-suite-name)", suite.PackageName, -1) - - // Must break command into parts - splitArgs := regexp.MustCompile(`'.+'|".+"|\S+`) - parts := splitArgs.FindAllString(command, -1) - - output, err := exec.Command(parts[0], parts[1:]...).CombinedOutput() - if err != nil { - fmt.Println("Post-suite command failed:") - if config.DefaultReporterConfig.NoColor { - fmt.Printf("\t%s\n", output) - } else { - fmt.Printf("\t%s%s%s\n", redColor, string(output), defaultStyle) - } - n.SendNotification("Ginkgo [ERROR]", fmt.Sprintf(`After suite command "%s" failed`, n.commandFlags.AfterSuiteHook)) - } else { - fmt.Println("Post-suite command succeeded:") - if config.DefaultReporterConfig.NoColor { - fmt.Printf("\t%s\n", output) - } else { - fmt.Printf("\t%s%s%s\n", greenColor, string(output), defaultStyle) - } - } - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/outline_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/outline_command.go deleted file mode 100644 index 96ca7ad27fb..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/outline_command.go +++ /dev/null @@ -1,95 +0,0 @@ -package main - -import ( - "encoding/json" - "flag" - "fmt" - "go/parser" - "go/token" - "os" - - "github.com/onsi/ginkgo/ginkgo/outline" -) - -const ( - // indentWidth is the width used by the 'indent' output - indentWidth = 4 - // stdinAlias is a portable alias for stdin. This convention is used in - // other CLIs, e.g., kubectl. - stdinAlias = "-" - usageCommand = "ginkgo outline " -) - -func BuildOutlineCommand() *Command { - const defaultFormat = "csv" - var format string - flagSet := flag.NewFlagSet("outline", flag.ExitOnError) - flagSet.StringVar(&format, "format", defaultFormat, "Format of outline. 
Accepted: 'csv', 'indent', 'json'") - return &Command{ - Name: "outline", - FlagSet: flagSet, - UsageCommand: usageCommand, - Usage: []string{ - "Create an outline of Ginkgo symbols for a file", - "To read from stdin, use: `ginkgo outline -`", - "Accepts the following flags:", - }, - Command: func(args []string, additionalArgs []string) { - outlineFile(args, format) - }, - } -} - -func outlineFile(args []string, format string) { - if len(args) != 1 { - println(fmt.Sprintf("usage: %s", usageCommand)) - os.Exit(1) - } - - filename := args[0] - var src *os.File - if filename == stdinAlias { - src = os.Stdin - } else { - var err error - src, err = os.Open(filename) - if err != nil { - println(fmt.Sprintf("error opening file: %s", err)) - os.Exit(1) - } - } - - fset := token.NewFileSet() - - parsedSrc, err := parser.ParseFile(fset, filename, src, 0) - if err != nil { - println(fmt.Sprintf("error parsing source: %s", err)) - os.Exit(1) - } - - o, err := outline.FromASTFile(fset, parsedSrc) - if err != nil { - println(fmt.Sprintf("error creating outline: %s", err)) - os.Exit(1) - } - - var oerr error - switch format { - case "csv": - _, oerr = fmt.Print(o) - case "indent": - _, oerr = fmt.Print(o.StringIndent(indentWidth)) - case "json": - b, err := json.Marshal(o) - if err != nil { - println(fmt.Sprintf("error marshalling to json: %s", err)) - } - _, oerr = fmt.Println(string(b)) - default: - complainAndQuit(fmt.Sprintf("format %s not accepted", format)) - } - if oerr != nil { - println(fmt.Sprintf("error writing outline: %s", oerr)) - os.Exit(1) - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/run_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/run_command.go deleted file mode 100644 index f3d4e99a55f..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/run_command.go +++ /dev/null @@ -1,316 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "math/rand" - "os" - "regexp" - "runtime" - "strings" - "time" - - "io/ioutil" - "path/filepath" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/ginkgo/interrupthandler" - "github.com/onsi/ginkgo/ginkgo/testrunner" - colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" - "github.com/onsi/ginkgo/types" -) - -func BuildRunCommand() *Command { - commandFlags := NewRunCommandFlags(flag.NewFlagSet("ginkgo", flag.ExitOnError)) - notifier := NewNotifier(commandFlags) - interruptHandler := interrupthandler.NewInterruptHandler() - runner := &SpecRunner{ - commandFlags: commandFlags, - notifier: notifier, - interruptHandler: interruptHandler, - suiteRunner: NewSuiteRunner(notifier, interruptHandler), - } - - return &Command{ - Name: "", - FlagSet: commandFlags.FlagSet, - UsageCommand: "ginkgo -- ", - Usage: []string{ - "Run the tests in the passed in (or the package in the current directory if left blank).", - "Any arguments after -- will be passed to the test.", - "Accepts the following flags:", - }, - Command: runner.RunSpecs, - } -} - -type SpecRunner struct { - commandFlags *RunWatchAndBuildCommandFlags - notifier *Notifier - interruptHandler *interrupthandler.InterruptHandler - suiteRunner *SuiteRunner -} - -func (r *SpecRunner) RunSpecs(args []string, additionalArgs []string) { - r.commandFlags.computeNodes() - r.notifier.VerifyNotificationsAreAvailable() - - deprecationTracker := types.NewDeprecationTracker() - - if r.commandFlags.ParallelStream && (runtime.GOOS != "windows") { - deprecationTracker.TrackDeprecation(types.Deprecation{ - Message: "--stream is deprecated and will be removed in Ginkgo 2.0", - 
DocLink: "removed--stream", - Version: "1.16.0", - }) - } - - if r.commandFlags.Notify { - deprecationTracker.TrackDeprecation(types.Deprecation{ - Message: "--notify is deprecated and will be removed in Ginkgo 2.0", - DocLink: "removed--notify", - Version: "1.16.0", - }) - } - - if deprecationTracker.DidTrackDeprecations() { - fmt.Fprintln(colorable.NewColorableStderr(), deprecationTracker.DeprecationsReport()) - } - - suites, skippedPackages := findSuites(args, r.commandFlags.Recurse, r.commandFlags.SkipPackage, true) - if len(skippedPackages) > 0 { - fmt.Println("Will skip:") - for _, skippedPackage := range skippedPackages { - fmt.Println(" " + skippedPackage) - } - } - - if len(skippedPackages) > 0 && len(suites) == 0 { - fmt.Println("All tests skipped! Exiting...") - os.Exit(0) - } - - if len(suites) == 0 { - complainAndQuit("Found no test suites") - } - - r.ComputeSuccinctMode(len(suites)) - - t := time.Now() - - runners := []*testrunner.TestRunner{} - for _, suite := range suites { - runners = append(runners, testrunner.New(suite, r.commandFlags.NumCPU, r.commandFlags.ParallelStream, r.commandFlags.Timeout, r.commandFlags.GoOpts, additionalArgs)) - } - - numSuites := 0 - runResult := testrunner.PassingRunResult() - if r.commandFlags.UntilItFails { - iteration := 0 - for { - r.UpdateSeed() - randomizedRunners := r.randomizeOrder(runners) - runResult, numSuites = r.suiteRunner.RunSuites(randomizedRunners, r.commandFlags.NumCompilers, r.commandFlags.KeepGoing, nil) - iteration++ - - if r.interruptHandler.WasInterrupted() { - break - } - - if runResult.Passed { - fmt.Printf("\nAll tests passed...\nWill keep running them until they fail.\nThis was attempt #%d\n%s\n", iteration, orcMessage(iteration)) - } else { - fmt.Printf("\nTests failed on attempt #%d\n\n", iteration) - break - } - } - } else { - randomizedRunners := r.randomizeOrder(runners) - runResult, numSuites = r.suiteRunner.RunSuites(randomizedRunners, r.commandFlags.NumCompilers, r.commandFlags.KeepGoing, nil) - } - - for _, runner := range runners { - runner.CleanUp() - } - - if r.isInCoverageMode() { - if r.getOutputDir() != "" { - // If coverprofile is set, combine coverages - if r.getCoverprofile() != "" { - if err := r.combineCoverprofiles(runners); err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - } else { - // Just move them - r.moveCoverprofiles(runners) - } - } - } - - fmt.Printf("\nGinkgo ran %d %s in %s\n", numSuites, pluralizedWord("suite", "suites", numSuites), time.Since(t)) - - if runResult.Passed { - if runResult.HasProgrammaticFocus && strings.TrimSpace(os.Getenv("GINKGO_EDITOR_INTEGRATION")) == "" { - fmt.Printf("Test Suite Passed\n") - fmt.Printf("Detected Programmatic Focus - setting exit status to %d\n", types.GINKGO_FOCUS_EXIT_CODE) - os.Exit(types.GINKGO_FOCUS_EXIT_CODE) - } else { - fmt.Printf("Test Suite Passed\n") - os.Exit(0) - } - } else { - fmt.Printf("Test Suite Failed\n") - emitRCAdvertisement() - os.Exit(1) - } -} - -// Moves all generated profiles to specified directory -func (r *SpecRunner) moveCoverprofiles(runners []*testrunner.TestRunner) { - for _, runner := range runners { - _, filename := filepath.Split(runner.CoverageFile) - err := os.Rename(runner.CoverageFile, filepath.Join(r.getOutputDir(), filename)) - - if err != nil { - fmt.Printf("Unable to move coverprofile %s, %v\n", runner.CoverageFile, err) - return - } - } -} - -// Combines all generated profiles in the specified directory -func (r *SpecRunner) combineCoverprofiles(runners []*testrunner.TestRunner) error { - - 
path, _ := filepath.Abs(r.getOutputDir()) - if !fileExists(path) { - return fmt.Errorf("Unable to create combined profile, outputdir does not exist: %s", r.getOutputDir()) - } - - fmt.Println("path is " + path) - - combined, err := os.OpenFile( - filepath.Join(path, r.getCoverprofile()), - os.O_WRONLY|os.O_CREATE, - 0666, - ) - - if err != nil { - fmt.Printf("Unable to create combined profile, %v\n", err) - return nil // non-fatal error - } - - modeRegex := regexp.MustCompile(`^mode: .*\n`) - for index, runner := range runners { - contents, err := ioutil.ReadFile(runner.CoverageFile) - - if err != nil { - fmt.Printf("Unable to read coverage file %s to combine, %v\n", runner.CoverageFile, err) - return nil // non-fatal error - } - - // remove the cover mode line from every file - // except the first one - if index > 0 { - contents = modeRegex.ReplaceAll(contents, []byte{}) - } - - _, err = combined.Write(contents) - - // Add a newline to the end of every file if missing. - if err == nil && len(contents) > 0 && contents[len(contents)-1] != '\n' { - _, err = combined.Write([]byte("\n")) - } - - if err != nil { - fmt.Printf("Unable to append to coverprofile, %v\n", err) - return nil // non-fatal error - } - } - - fmt.Println("All profiles combined") - return nil -} - -func (r *SpecRunner) isInCoverageMode() bool { - opts := r.commandFlags.GoOpts - return *opts["cover"].(*bool) || *opts["coverpkg"].(*string) != "" || *opts["covermode"].(*string) != "" -} - -func (r *SpecRunner) getCoverprofile() string { - return *r.commandFlags.GoOpts["coverprofile"].(*string) -} - -func (r *SpecRunner) getOutputDir() string { - return *r.commandFlags.GoOpts["outputdir"].(*string) -} - -func (r *SpecRunner) ComputeSuccinctMode(numSuites int) { - if config.DefaultReporterConfig.Verbose { - config.DefaultReporterConfig.Succinct = false - return - } - - if numSuites == 1 { - return - } - - if numSuites > 1 && !r.commandFlags.wasSet("succinct") { - config.DefaultReporterConfig.Succinct = true - } -} - -func (r *SpecRunner) UpdateSeed() { - if !r.commandFlags.wasSet("seed") { - config.GinkgoConfig.RandomSeed = time.Now().Unix() - } -} - -func (r *SpecRunner) randomizeOrder(runners []*testrunner.TestRunner) []*testrunner.TestRunner { - if !r.commandFlags.RandomizeSuites { - return runners - } - - if len(runners) <= 1 { - return runners - } - - randomizedRunners := make([]*testrunner.TestRunner, len(runners)) - randomizer := rand.New(rand.NewSource(config.GinkgoConfig.RandomSeed)) - permutation := randomizer.Perm(len(runners)) - for i, j := range permutation { - randomizedRunners[i] = runners[j] - } - return randomizedRunners -} - -func orcMessage(iteration int) string { - if iteration < 10 { - return "" - } else if iteration < 30 { - return []string{ - "If at first you succeed...", - "...try, try again.", - "Looking good!", - "Still good...", - "I think your tests are fine....", - "Yep, still passing", - "Oh boy, here I go testin' again!", - "Even the gophers are getting bored", - "Did you try -race?", - "Maybe you should stop now?", - "I'm getting tired...", - "What if I just made you a sandwich?", - "Hit ^C, hit ^C, please hit ^C", - "Make it stop. Please!", - "Come on! Enough is enough!", - "Dave, this conversation can serve no purpose anymore. Goodbye.", - "Just what do you think you're doing, Dave? ", - "I, Sisyphus", - "Insanity: doing the same thing over and over again and expecting different results. 
-Einstein", - "I guess Einstein never tried to churn butter", - }[iteration-10] + "\n" - } else { - return "No, seriously... you can probably stop now.\n" - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/run_watch_and_build_command_flags.go b/vendor/github.com/onsi/ginkgo/ginkgo/run_watch_and_build_command_flags.go deleted file mode 100644 index e0994fc3cb0..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/run_watch_and_build_command_flags.go +++ /dev/null @@ -1,169 +0,0 @@ -package main - -import ( - "flag" - "runtime" - - "time" - - "github.com/onsi/ginkgo/config" -) - -type RunWatchAndBuildCommandFlags struct { - Recurse bool - SkipPackage string - GoOpts map[string]interface{} - - //for run and watch commands - NumCPU int - NumCompilers int - ParallelStream bool - Notify bool - AfterSuiteHook string - AutoNodes bool - Timeout time.Duration - - //only for run command - KeepGoing bool - UntilItFails bool - RandomizeSuites bool - - //only for watch command - Depth int - WatchRegExp string - - FlagSet *flag.FlagSet -} - -const runMode = 1 -const watchMode = 2 -const buildMode = 3 - -func NewRunCommandFlags(flagSet *flag.FlagSet) *RunWatchAndBuildCommandFlags { - c := &RunWatchAndBuildCommandFlags{ - FlagSet: flagSet, - } - c.flags(runMode) - return c -} - -func NewWatchCommandFlags(flagSet *flag.FlagSet) *RunWatchAndBuildCommandFlags { - c := &RunWatchAndBuildCommandFlags{ - FlagSet: flagSet, - } - c.flags(watchMode) - return c -} - -func NewBuildCommandFlags(flagSet *flag.FlagSet) *RunWatchAndBuildCommandFlags { - c := &RunWatchAndBuildCommandFlags{ - FlagSet: flagSet, - } - c.flags(buildMode) - return c -} - -func (c *RunWatchAndBuildCommandFlags) wasSet(flagName string) bool { - wasSet := false - c.FlagSet.Visit(func(f *flag.Flag) { - if f.Name == flagName { - wasSet = true - } - }) - - return wasSet -} - -func (c *RunWatchAndBuildCommandFlags) computeNodes() { - if c.wasSet("nodes") { - return - } - if c.AutoNodes { - switch n := runtime.NumCPU(); { - case n <= 4: - c.NumCPU = n - default: - c.NumCPU = n - 1 - } - } -} - -func (c *RunWatchAndBuildCommandFlags) stringSlot(slot string) *string { - var opt string - c.GoOpts[slot] = &opt - return &opt -} - -func (c *RunWatchAndBuildCommandFlags) boolSlot(slot string) *bool { - var opt bool - c.GoOpts[slot] = &opt - return &opt -} - -func (c *RunWatchAndBuildCommandFlags) intSlot(slot string) *int { - var opt int - c.GoOpts[slot] = &opt - return &opt -} - -func (c *RunWatchAndBuildCommandFlags) flags(mode int) { - c.GoOpts = make(map[string]interface{}) - - onWindows := (runtime.GOOS == "windows") - - c.FlagSet.BoolVar(&(c.Recurse), "r", false, "Find and run test suites under the current directory recursively.") - c.FlagSet.BoolVar(c.boolSlot("race"), "race", false, "Run tests with race detection enabled.") - c.FlagSet.BoolVar(c.boolSlot("cover"), "cover", false, "Run tests with coverage analysis, will generate coverage profiles with the package name in the current directory.") - c.FlagSet.StringVar(c.stringSlot("coverpkg"), "coverpkg", "", "Run tests with coverage on the given external modules.") - c.FlagSet.StringVar(&(c.SkipPackage), "skipPackage", "", "A comma-separated list of package names to be skipped. 
If any part of the package's path matches, that package is ignored.") - c.FlagSet.StringVar(c.stringSlot("tags"), "tags", "", "A list of build tags to consider satisfied during the build.") - c.FlagSet.StringVar(c.stringSlot("gcflags"), "gcflags", "", "Arguments to pass on each go tool compile invocation.") - c.FlagSet.StringVar(c.stringSlot("covermode"), "covermode", "", "Set the mode for coverage analysis.") - c.FlagSet.BoolVar(c.boolSlot("a"), "a", false, "Force rebuilding of packages that are already up-to-date.") - c.FlagSet.BoolVar(c.boolSlot("n"), "n", false, "Have `go test` print the commands but do not run them.") - c.FlagSet.BoolVar(c.boolSlot("msan"), "msan", false, "Enable interoperation with memory sanitizer.") - c.FlagSet.BoolVar(c.boolSlot("x"), "x", false, "Have `go test` print the commands.") - c.FlagSet.BoolVar(c.boolSlot("work"), "work", false, "Print the name of the temporary work directory and do not delete it when exiting.") - c.FlagSet.StringVar(c.stringSlot("asmflags"), "asmflags", "", "Arguments to pass on each go tool asm invocation.") - c.FlagSet.StringVar(c.stringSlot("buildmode"), "buildmode", "", "Build mode to use. See 'go help buildmode' for more.") - c.FlagSet.StringVar(c.stringSlot("mod"), "mod", "", "Go module control. See 'go help modules' for more.") - c.FlagSet.StringVar(c.stringSlot("compiler"), "compiler", "", "Name of compiler to use, as in runtime.Compiler (gccgo or gc).") - c.FlagSet.StringVar(c.stringSlot("gccgoflags"), "gccgoflags", "", "Arguments to pass on each gccgo compiler/linker invocation.") - c.FlagSet.StringVar(c.stringSlot("installsuffix"), "installsuffix", "", "A suffix to use in the name of the package installation directory.") - c.FlagSet.StringVar(c.stringSlot("ldflags"), "ldflags", "", "Arguments to pass on each go tool link invocation.") - c.FlagSet.BoolVar(c.boolSlot("linkshared"), "linkshared", false, "Link against shared libraries previously created with -buildmode=shared.") - c.FlagSet.StringVar(c.stringSlot("pkgdir"), "pkgdir", "", "install and load all packages from the given dir instead of the usual locations.") - c.FlagSet.StringVar(c.stringSlot("toolexec"), "toolexec", "", "a program to use to invoke toolchain programs like vet and asm.") - c.FlagSet.IntVar(c.intSlot("blockprofilerate"), "blockprofilerate", 1, "Control the detail provided in goroutine blocking profiles by calling runtime.SetBlockProfileRate with the given value.") - c.FlagSet.StringVar(c.stringSlot("coverprofile"), "coverprofile", "", "Write a coverage profile to the specified file after all tests have passed.") - c.FlagSet.StringVar(c.stringSlot("cpuprofile"), "cpuprofile", "", "Write a CPU profile to the specified file before exiting.") - c.FlagSet.StringVar(c.stringSlot("memprofile"), "memprofile", "", "Write a memory profile to the specified file after all tests have passed.") - c.FlagSet.IntVar(c.intSlot("memprofilerate"), "memprofilerate", 0, "Enable more precise (and expensive) memory profiles by setting runtime.MemProfileRate.") - c.FlagSet.StringVar(c.stringSlot("outputdir"), "outputdir", "", "Place output files from profiling in the specified directory.") - c.FlagSet.BoolVar(c.boolSlot("requireSuite"), "requireSuite", false, "Fail if there are ginkgo tests in a directory but no test suite (missing RunSpecs)") - c.FlagSet.StringVar(c.stringSlot("vet"), "vet", "", "Configure the invocation of 'go vet' to use the comma-separated list of vet checks. 
If list is 'off', 'go test' does not run 'go vet' at all.") - - if mode == runMode || mode == watchMode { - config.Flags(c.FlagSet, "", false) - c.FlagSet.IntVar(&(c.NumCPU), "nodes", 1, "The number of parallel test nodes to run") - c.FlagSet.IntVar(&(c.NumCompilers), "compilers", 0, "The number of concurrent compilations to run (0 will autodetect)") - c.FlagSet.BoolVar(&(c.AutoNodes), "p", false, "Run in parallel with auto-detected number of nodes") - c.FlagSet.BoolVar(&(c.ParallelStream), "stream", onWindows, "stream parallel test output in real time: less coherent, but useful for debugging") - if !onWindows { - c.FlagSet.BoolVar(&(c.Notify), "notify", false, "Send desktop notifications when a test run completes") - } - c.FlagSet.StringVar(&(c.AfterSuiteHook), "afterSuiteHook", "", "Run a command when a suite test run completes") - c.FlagSet.DurationVar(&(c.Timeout), "timeout", 24*time.Hour, "Suite fails if it does not complete within the specified timeout") - } - - if mode == runMode { - c.FlagSet.BoolVar(&(c.KeepGoing), "keepGoing", false, "When true, failures from earlier test suites do not prevent later test suites from running") - c.FlagSet.BoolVar(&(c.UntilItFails), "untilItFails", false, "When true, Ginkgo will keep rerunning tests until a failure occurs") - c.FlagSet.BoolVar(&(c.RandomizeSuites), "randomizeSuites", false, "When true, Ginkgo will randomize the order in which test suites run") - } - - if mode == watchMode { - c.FlagSet.IntVar(&(c.Depth), "depth", 1, "Ginkgo will watch dependencies down to this depth in the dependency tree") - c.FlagSet.StringVar(&(c.WatchRegExp), "watchRegExp", `\.go$`, "Files matching this regular expression will be watched for changes") - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/suite_runner.go b/vendor/github.com/onsi/ginkgo/ginkgo/suite_runner.go deleted file mode 100644 index ab746d7e953..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/suite_runner.go +++ /dev/null @@ -1,173 +0,0 @@ -package main - -import ( - "fmt" - "runtime" - "sync" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/ginkgo/interrupthandler" - "github.com/onsi/ginkgo/ginkgo/testrunner" - "github.com/onsi/ginkgo/ginkgo/testsuite" - colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" -) - -type compilationInput struct { - runner *testrunner.TestRunner - result chan compilationOutput -} - -type compilationOutput struct { - runner *testrunner.TestRunner - err error -} - -type SuiteRunner struct { - notifier *Notifier - interruptHandler *interrupthandler.InterruptHandler -} - -func NewSuiteRunner(notifier *Notifier, interruptHandler *interrupthandler.InterruptHandler) *SuiteRunner { - return &SuiteRunner{ - notifier: notifier, - interruptHandler: interruptHandler, - } -} - -func (r *SuiteRunner) compileInParallel(runners []*testrunner.TestRunner, numCompilers int, willCompile func(suite testsuite.TestSuite)) chan compilationOutput { - //we return this to the consumer, it will return each runner in order as it compiles - compilationOutputs := make(chan compilationOutput, len(runners)) - - //an array of channels - the nth runner's compilation output is sent to the nth channel in this array - //we read from these channels in order to ensure we run the suites in order - orderedCompilationOutputs := []chan compilationOutput{} - for range runners { - orderedCompilationOutputs = append(orderedCompilationOutputs, make(chan compilationOutput, 1)) - } - - //we're going to spin up numCompilers compilers - they're going to run 
concurrently and will consume this channel - //we prefill the channel then close it, this ensures we compile things in the correct order - workPool := make(chan compilationInput, len(runners)) - for i, runner := range runners { - workPool <- compilationInput{runner, orderedCompilationOutputs[i]} - } - close(workPool) - - //pick a reasonable numCompilers - if numCompilers == 0 { - numCompilers = runtime.NumCPU() - } - - //a WaitGroup to help us wait for all compilers to shut down - wg := &sync.WaitGroup{} - wg.Add(numCompilers) - - //spin up the concurrent compilers - for i := 0; i < numCompilers; i++ { - go func() { - defer wg.Done() - for input := range workPool { - if r.interruptHandler.WasInterrupted() { - return - } - - if willCompile != nil { - willCompile(input.runner.Suite) - } - - //We retry because Go sometimes steps on itself when multiple compiles happen in parallel. This is ugly, but should help resolve flakiness... - var err error - retries := 0 - for retries <= 5 { - if r.interruptHandler.WasInterrupted() { - return - } - if err = input.runner.Compile(); err == nil { - break - } - retries++ - } - - input.result <- compilationOutput{input.runner, err} - } - }() - } - - //read from the compilation output channels *in order* and send them to the caller - //close the compilationOutputs channel to tell the caller we're done - go func() { - defer close(compilationOutputs) - for _, orderedCompilationOutput := range orderedCompilationOutputs { - select { - case compilationOutput := <-orderedCompilationOutput: - compilationOutputs <- compilationOutput - case <-r.interruptHandler.C: - //interrupt detected, wait for the compilers to shut down then bail - //this ensure we clean up after ourselves as we don't leave any compilation processes running - wg.Wait() - return - } - } - }() - - return compilationOutputs -} - -func (r *SuiteRunner) RunSuites(runners []*testrunner.TestRunner, numCompilers int, keepGoing bool, willCompile func(suite testsuite.TestSuite)) (testrunner.RunResult, int) { - runResult := testrunner.PassingRunResult() - - compilationOutputs := r.compileInParallel(runners, numCompilers, willCompile) - - numSuitesThatRan := 0 - suitesThatFailed := []testsuite.TestSuite{} - for compilationOutput := range compilationOutputs { - if compilationOutput.err != nil { - fmt.Print(compilationOutput.err.Error()) - } - numSuitesThatRan++ - suiteRunResult := testrunner.FailingRunResult() - if compilationOutput.err == nil { - suiteRunResult = compilationOutput.runner.Run() - } - r.notifier.SendSuiteCompletionNotification(compilationOutput.runner.Suite, suiteRunResult.Passed) - r.notifier.RunCommand(compilationOutput.runner.Suite, suiteRunResult.Passed) - runResult = runResult.Merge(suiteRunResult) - if !suiteRunResult.Passed { - suitesThatFailed = append(suitesThatFailed, compilationOutput.runner.Suite) - if !keepGoing { - break - } - } - if numSuitesThatRan < len(runners) && !config.DefaultReporterConfig.Succinct { - fmt.Println("") - } - } - - if keepGoing && !runResult.Passed { - r.listFailedSuites(suitesThatFailed) - } - - return runResult, numSuitesThatRan -} - -func (r *SuiteRunner) listFailedSuites(suitesThatFailed []testsuite.TestSuite) { - fmt.Println("") - fmt.Println("There were failures detected in the following suites:") - - maxPackageNameLength := 0 - for _, suite := range suitesThatFailed { - if len(suite.PackageName) > maxPackageNameLength { - maxPackageNameLength = len(suite.PackageName) - } - } - - packageNameFormatter := fmt.Sprintf("%%%ds", maxPackageNameLength) - - 
for _, suite := range suitesThatFailed { - if config.DefaultReporterConfig.NoColor { - fmt.Printf("\t"+packageNameFormatter+" %s\n", suite.PackageName, suite.Path) - } else { - fmt.Fprintf(colorable.NewColorableStdout(), "\t%s"+packageNameFormatter+"%s %s%s%s\n", redColor, suite.PackageName, defaultStyle, lightGrayColor, suite.Path, defaultStyle) - } - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args.go deleted file mode 100644 index 3b1a238c2c4..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args.go +++ /dev/null @@ -1,7 +0,0 @@ -// +build go1.10 - -package testrunner - -var ( - buildArgs = []string{"test", "-c"} -) diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args_old.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args_old.go deleted file mode 100644 index 14d70dbcc52..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/build_args_old.go +++ /dev/null @@ -1,7 +0,0 @@ -// +build !go1.10 - -package testrunner - -var ( - buildArgs = []string{"test", "-c", "-i"} -) diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/log_writer.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/log_writer.go deleted file mode 100644 index a73a6e37919..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/log_writer.go +++ /dev/null @@ -1,52 +0,0 @@ -package testrunner - -import ( - "bytes" - "fmt" - "io" - "log" - "strings" - "sync" -) - -type logWriter struct { - buffer *bytes.Buffer - lock *sync.Mutex - log *log.Logger -} - -func newLogWriter(target io.Writer, node int) *logWriter { - return &logWriter{ - buffer: &bytes.Buffer{}, - lock: &sync.Mutex{}, - log: log.New(target, fmt.Sprintf("[%d] ", node), 0), - } -} - -func (w *logWriter) Write(data []byte) (n int, err error) { - w.lock.Lock() - defer w.lock.Unlock() - - w.buffer.Write(data) - contents := w.buffer.String() - - lines := strings.Split(contents, "\n") - for _, line := range lines[0 : len(lines)-1] { - w.log.Println(line) - } - - w.buffer.Reset() - w.buffer.Write([]byte(lines[len(lines)-1])) - return len(data), nil -} - -func (w *logWriter) Close() error { - w.lock.Lock() - defer w.lock.Unlock() - - if w.buffer.Len() > 0 { - w.log.Println(w.buffer.String()) - } - - return nil -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/run_result.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/run_result.go deleted file mode 100644 index 5d472acb8d2..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/run_result.go +++ /dev/null @@ -1,27 +0,0 @@ -package testrunner - -type RunResult struct { - Passed bool - HasProgrammaticFocus bool -} - -func PassingRunResult() RunResult { - return RunResult{ - Passed: true, - HasProgrammaticFocus: false, - } -} - -func FailingRunResult() RunResult { - return RunResult{ - Passed: false, - HasProgrammaticFocus: false, - } -} - -func (r RunResult) Merge(o RunResult) RunResult { - return RunResult{ - Passed: r.Passed && o.Passed, - HasProgrammaticFocus: r.HasProgrammaticFocus || o.HasProgrammaticFocus, - } -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/test_runner.go b/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/test_runner.go deleted file mode 100644 index 66c0f06f628..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/testrunner/test_runner.go +++ /dev/null @@ -1,554 +0,0 @@ -package testrunner - -import ( - "bytes" - "fmt" - "io" - "io/ioutil" - "os" - "os/exec" - "path/filepath" - 
"strconv" - "strings" - "syscall" - "time" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/ginkgo/testsuite" - "github.com/onsi/ginkgo/internal/remote" - "github.com/onsi/ginkgo/reporters/stenographer" - colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" - "github.com/onsi/ginkgo/types" -) - -type TestRunner struct { - Suite testsuite.TestSuite - - compiled bool - compilationTargetPath string - - numCPU int - parallelStream bool - timeout time.Duration - goOpts map[string]interface{} - additionalArgs []string - stderr *bytes.Buffer - - CoverageFile string -} - -func New(suite testsuite.TestSuite, numCPU int, parallelStream bool, timeout time.Duration, goOpts map[string]interface{}, additionalArgs []string) *TestRunner { - runner := &TestRunner{ - Suite: suite, - numCPU: numCPU, - parallelStream: parallelStream, - goOpts: goOpts, - additionalArgs: additionalArgs, - timeout: timeout, - stderr: new(bytes.Buffer), - } - - if !suite.Precompiled { - runner.compilationTargetPath, _ = filepath.Abs(filepath.Join(suite.Path, suite.PackageName+".test")) - } - - return runner -} - -func (t *TestRunner) Compile() error { - return t.CompileTo(t.compilationTargetPath) -} - -func (t *TestRunner) BuildArgs(path string) []string { - args := make([]string, len(buildArgs), len(buildArgs)+3) - copy(args, buildArgs) - args = append(args, "-o", path, t.Suite.Path) - - if t.getCoverMode() != "" { - args = append(args, "-cover", fmt.Sprintf("-covermode=%s", t.getCoverMode())) - } else { - if t.shouldCover() || t.getCoverPackage() != "" { - args = append(args, "-cover", "-covermode=atomic") - } - } - - boolOpts := []string{ - "a", - "n", - "msan", - "race", - "x", - "work", - "linkshared", - } - - for _, opt := range boolOpts { - if s, found := t.goOpts[opt].(*bool); found && *s { - args = append(args, fmt.Sprintf("-%s", opt)) - } - } - - intOpts := []string{ - "memprofilerate", - "blockprofilerate", - } - - for _, opt := range intOpts { - if s, found := t.goOpts[opt].(*int); found { - args = append(args, fmt.Sprintf("-%s=%d", opt, *s)) - } - } - - stringOpts := []string{ - "asmflags", - "buildmode", - "compiler", - "gccgoflags", - "installsuffix", - "ldflags", - "pkgdir", - "toolexec", - "coverprofile", - "cpuprofile", - "memprofile", - "outputdir", - "coverpkg", - "tags", - "gcflags", - "vet", - "mod", - } - - for _, opt := range stringOpts { - if s, found := t.goOpts[opt].(*string); found && *s != "" { - args = append(args, fmt.Sprintf("-%s=%s", opt, *s)) - } - } - return args -} - -func (t *TestRunner) CompileTo(path string) error { - if t.compiled { - return nil - } - - if t.Suite.Precompiled { - return nil - } - - args := t.BuildArgs(path) - cmd := exec.Command("go", args...) - - output, err := cmd.CombinedOutput() - - if err != nil { - if len(output) > 0 { - return fmt.Errorf("Failed to compile %s:\n\n%s", t.Suite.PackageName, output) - } - return fmt.Errorf("Failed to compile %s", t.Suite.PackageName) - } - - if len(output) > 0 { - fmt.Println(string(output)) - } - - if !fileExists(path) { - compiledFile := t.Suite.PackageName + ".test" - if fileExists(compiledFile) { - // seems like we are on an old go version that does not support the -o flag on go test - // move the compiled test file to the desired location by hand - err = os.Rename(compiledFile, path) - if err != nil { - // We cannot move the file, perhaps because the source and destination - // are on different partitions. We can copy the file, however. 
- err = copyFile(compiledFile, path) - if err != nil { - return fmt.Errorf("Failed to copy compiled file: %s", err) - } - } - } else { - return fmt.Errorf("Failed to compile %s: output file %q could not be found", t.Suite.PackageName, path) - } - } - - t.compiled = true - - return nil -} - -func fileExists(path string) bool { - _, err := os.Stat(path) - return err == nil || !os.IsNotExist(err) -} - -// copyFile copies the contents of the file named src to the file named -// by dst. The file will be created if it does not already exist. If the -// destination file exists, all it's contents will be replaced by the contents -// of the source file. -func copyFile(src, dst string) error { - srcInfo, err := os.Stat(src) - if err != nil { - return err - } - mode := srcInfo.Mode() - - in, err := os.Open(src) - if err != nil { - return err - } - - defer in.Close() - - out, err := os.Create(dst) - if err != nil { - return err - } - - defer func() { - closeErr := out.Close() - if err == nil { - err = closeErr - } - }() - - _, err = io.Copy(out, in) - if err != nil { - return err - } - - err = out.Sync() - if err != nil { - return err - } - - return out.Chmod(mode) -} - -func (t *TestRunner) Run() RunResult { - if t.Suite.IsGinkgo { - if t.numCPU > 1 { - if t.parallelStream { - return t.runAndStreamParallelGinkgoSuite() - } else { - return t.runParallelGinkgoSuite() - } - } else { - return t.runSerialGinkgoSuite() - } - } else { - return t.runGoTestSuite() - } -} - -func (t *TestRunner) CleanUp() { - if t.Suite.Precompiled { - return - } - os.Remove(t.compilationTargetPath) -} - -func (t *TestRunner) runSerialGinkgoSuite() RunResult { - ginkgoArgs := config.BuildFlagArgs("ginkgo", config.GinkgoConfig, config.DefaultReporterConfig) - return t.run(t.cmd(ginkgoArgs, os.Stdout, 1), nil) -} - -func (t *TestRunner) runGoTestSuite() RunResult { - return t.run(t.cmd([]string{"-test.v"}, os.Stdout, 1), nil) -} - -func (t *TestRunner) runAndStreamParallelGinkgoSuite() RunResult { - completions := make(chan RunResult) - writers := make([]*logWriter, t.numCPU) - - server, err := remote.NewServer(t.numCPU) - if err != nil { - panic("Failed to start parallel spec server") - } - - server.Start() - defer server.Close() - - for cpu := 0; cpu < t.numCPU; cpu++ { - config.GinkgoConfig.ParallelNode = cpu + 1 - config.GinkgoConfig.ParallelTotal = t.numCPU - config.GinkgoConfig.SyncHost = server.Address() - - ginkgoArgs := config.BuildFlagArgs("ginkgo", config.GinkgoConfig, config.DefaultReporterConfig) - - writers[cpu] = newLogWriter(os.Stdout, cpu+1) - - cmd := t.cmd(ginkgoArgs, writers[cpu], cpu+1) - - server.RegisterAlive(cpu+1, func() bool { - if cmd.ProcessState == nil { - return true - } - return !cmd.ProcessState.Exited() - }) - - go t.run(cmd, completions) - } - - res := PassingRunResult() - - for cpu := 0; cpu < t.numCPU; cpu++ { - res = res.Merge(<-completions) - } - - for _, writer := range writers { - writer.Close() - } - - os.Stdout.Sync() - - if t.shouldCombineCoverprofiles() { - t.combineCoverprofiles() - } - - return res -} - -func (t *TestRunner) runParallelGinkgoSuite() RunResult { - result := make(chan bool) - completions := make(chan RunResult) - writers := make([]*logWriter, t.numCPU) - reports := make([]*bytes.Buffer, t.numCPU) - - stenographer := stenographer.New(!config.DefaultReporterConfig.NoColor, config.GinkgoConfig.FlakeAttempts > 1, colorable.NewColorableStdout()) - aggregator := remote.NewAggregator(t.numCPU, result, config.DefaultReporterConfig, stenographer) - - server, err := 
remote.NewServer(t.numCPU) - if err != nil { - panic("Failed to start parallel spec server") - } - server.RegisterReporters(aggregator) - server.Start() - defer server.Close() - - for cpu := 0; cpu < t.numCPU; cpu++ { - config.GinkgoConfig.ParallelNode = cpu + 1 - config.GinkgoConfig.ParallelTotal = t.numCPU - config.GinkgoConfig.SyncHost = server.Address() - config.GinkgoConfig.StreamHost = server.Address() - - ginkgoArgs := config.BuildFlagArgs("ginkgo", config.GinkgoConfig, config.DefaultReporterConfig) - - reports[cpu] = &bytes.Buffer{} - writers[cpu] = newLogWriter(reports[cpu], cpu+1) - - cmd := t.cmd(ginkgoArgs, writers[cpu], cpu+1) - - server.RegisterAlive(cpu+1, func() bool { - if cmd.ProcessState == nil { - return true - } - return !cmd.ProcessState.Exited() - }) - - go t.run(cmd, completions) - } - - res := PassingRunResult() - - for cpu := 0; cpu < t.numCPU; cpu++ { - res = res.Merge(<-completions) - } - - //all test processes are done, at this point - //we should be able to wait for the aggregator to tell us that it's done - - select { - case <-result: - fmt.Println("") - case <-time.After(time.Second): - //the aggregator never got back to us! something must have gone wrong - fmt.Println(` - ------------------------------------------------------------------- - | | - | Ginkgo timed out waiting for all parallel nodes to report back! | - | | - -------------------------------------------------------------------`) - fmt.Println("\n", t.Suite.PackageName, "timed out. path:", t.Suite.Path) - os.Stdout.Sync() - - for _, writer := range writers { - writer.Close() - } - - for _, report := range reports { - fmt.Print(report.String()) - } - - os.Stdout.Sync() - } - - if t.shouldCombineCoverprofiles() { - t.combineCoverprofiles() - } - - return res -} - -const CoverProfileSuffix = ".coverprofile" - -func (t *TestRunner) cmd(ginkgoArgs []string, stream io.Writer, node int) *exec.Cmd { - args := []string{"--test.timeout=" + t.timeout.String()} - - coverProfile := t.getCoverProfile() - - if t.shouldCombineCoverprofiles() { - - testCoverProfile := "--test.coverprofile=" - - coverageFile := "" - // Set default name for coverage results - if coverProfile == "" { - coverageFile = t.Suite.PackageName + CoverProfileSuffix - } else { - coverageFile = coverProfile - } - - testCoverProfile += coverageFile - - t.CoverageFile = filepath.Join(t.Suite.Path, coverageFile) - - if t.numCPU > 1 { - testCoverProfile = fmt.Sprintf("%s.%d", testCoverProfile, node) - } - args = append(args, testCoverProfile) - } - - args = append(args, ginkgoArgs...) - args = append(args, t.additionalArgs...) - - path := t.compilationTargetPath - if t.Suite.Precompiled { - path, _ = filepath.Abs(filepath.Join(t.Suite.Path, fmt.Sprintf("%s.test", t.Suite.PackageName))) - } - - cmd := exec.Command(path, args...) 
- - cmd.Dir = t.Suite.Path - cmd.Stderr = io.MultiWriter(stream, t.stderr) - cmd.Stdout = stream - - return cmd -} - -func (t *TestRunner) shouldCover() bool { - return *t.goOpts["cover"].(*bool) -} - -func (t *TestRunner) shouldRequireSuite() bool { - return *t.goOpts["requireSuite"].(*bool) -} - -func (t *TestRunner) getCoverProfile() string { - return *t.goOpts["coverprofile"].(*string) -} - -func (t *TestRunner) getCoverPackage() string { - return *t.goOpts["coverpkg"].(*string) -} - -func (t *TestRunner) getCoverMode() string { - return *t.goOpts["covermode"].(*string) -} - -func (t *TestRunner) shouldCombineCoverprofiles() bool { - return t.shouldCover() || t.getCoverPackage() != "" || t.getCoverMode() != "" -} - -func (t *TestRunner) run(cmd *exec.Cmd, completions chan RunResult) RunResult { - var res RunResult - - defer func() { - if completions != nil { - completions <- res - } - }() - - err := cmd.Start() - if err != nil { - fmt.Printf("Failed to run test suite!\n\t%s", err.Error()) - return res - } - - cmd.Wait() - - exitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus() - res.Passed = (exitStatus == 0) || (exitStatus == types.GINKGO_FOCUS_EXIT_CODE) - res.HasProgrammaticFocus = (exitStatus == types.GINKGO_FOCUS_EXIT_CODE) - - if strings.Contains(t.stderr.String(), "warning: no tests to run") { - if t.shouldRequireSuite() { - res.Passed = false - } - fmt.Fprintf(os.Stderr, `Found no test suites, did you forget to run "ginkgo bootstrap"?`) - } - - return res -} - -func (t *TestRunner) combineCoverprofiles() { - profiles := []string{} - - coverProfile := t.getCoverProfile() - - for cpu := 1; cpu <= t.numCPU; cpu++ { - var coverFile string - if coverProfile == "" { - coverFile = fmt.Sprintf("%s%s.%d", t.Suite.PackageName, CoverProfileSuffix, cpu) - } else { - coverFile = fmt.Sprintf("%s.%d", coverProfile, cpu) - } - - coverFile = filepath.Join(t.Suite.Path, coverFile) - coverProfile, err := ioutil.ReadFile(coverFile) - os.Remove(coverFile) - - if err == nil { - profiles = append(profiles, string(coverProfile)) - } - } - - if len(profiles) != t.numCPU { - return - } - - lines := map[string]int{} - lineOrder := []string{} - for i, coverProfile := range profiles { - for _, line := range strings.Split(coverProfile, "\n")[1:] { - if len(line) == 0 { - continue - } - components := strings.Split(line, " ") - count, _ := strconv.Atoi(components[len(components)-1]) - prefix := strings.Join(components[0:len(components)-1], " ") - lines[prefix] += count - if i == 0 { - lineOrder = append(lineOrder, prefix) - } - } - } - - output := []string{"mode: atomic"} - for _, line := range lineOrder { - output = append(output, fmt.Sprintf("%s %d", line, lines[line])) - } - finalOutput := strings.Join(output, "\n") - - finalFilename := "" - - if coverProfile != "" { - finalFilename = coverProfile - } else { - finalFilename = fmt.Sprintf("%s%s", t.Suite.PackageName, CoverProfileSuffix) - } - - coverageFilepath := filepath.Join(t.Suite.Path, finalFilename) - ioutil.WriteFile(coverageFilepath, []byte(finalOutput), 0666) - - t.CoverageFile = coverageFilepath -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/test_suite.go b/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/test_suite.go deleted file mode 100644 index 9de8c2bb4e6..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/test_suite.go +++ /dev/null @@ -1,115 +0,0 @@ -package testsuite - -import ( - "errors" - "io/ioutil" - "os" - "path/filepath" - "regexp" - "strings" -) - -type TestSuite struct { - Path string 
- PackageName string - IsGinkgo bool - Precompiled bool -} - -func PrecompiledTestSuite(path string) (TestSuite, error) { - info, err := os.Stat(path) - if err != nil { - return TestSuite{}, err - } - - if info.IsDir() { - return TestSuite{}, errors.New("this is a directory, not a file") - } - - if filepath.Ext(path) != ".test" { - return TestSuite{}, errors.New("this is not a .test binary") - } - - if info.Mode()&0111 == 0 { - return TestSuite{}, errors.New("this is not executable") - } - - dir := relPath(filepath.Dir(path)) - packageName := strings.TrimSuffix(filepath.Base(path), filepath.Ext(path)) - - return TestSuite{ - Path: dir, - PackageName: packageName, - IsGinkgo: true, - Precompiled: true, - }, nil -} - -func SuitesInDir(dir string, recurse bool) []TestSuite { - suites := []TestSuite{} - - if vendorExperimentCheck(dir) { - return suites - } - - files, _ := ioutil.ReadDir(dir) - re := regexp.MustCompile(`^[^._].*_test\.go$`) - for _, file := range files { - if !file.IsDir() && re.Match([]byte(file.Name())) { - suites = append(suites, New(dir, files)) - break - } - } - - if recurse { - re = regexp.MustCompile(`^[._]`) - for _, file := range files { - if file.IsDir() && !re.Match([]byte(file.Name())) { - suites = append(suites, SuitesInDir(dir+"/"+file.Name(), recurse)...) - } - } - } - - return suites -} - -func relPath(dir string) string { - dir, _ = filepath.Abs(dir) - cwd, _ := os.Getwd() - dir, _ = filepath.Rel(cwd, filepath.Clean(dir)) - - if string(dir[0]) != "." { - dir = "." + string(filepath.Separator) + dir - } - - return dir -} - -func New(dir string, files []os.FileInfo) TestSuite { - return TestSuite{ - Path: relPath(dir), - PackageName: packageNameForSuite(dir), - IsGinkgo: filesHaveGinkgoSuite(dir, files), - } -} - -func packageNameForSuite(dir string) string { - path, _ := filepath.Abs(dir) - return filepath.Base(path) -} - -func filesHaveGinkgoSuite(dir string, files []os.FileInfo) bool { - reTestFile := regexp.MustCompile(`_test\.go$`) - reGinkgo := regexp.MustCompile(`package ginkgo|\/ginkgo"`) - - for _, file := range files { - if !file.IsDir() && reTestFile.Match([]byte(file.Name())) { - contents, _ := ioutil.ReadFile(dir + "/" + file.Name()) - if reGinkgo.Match(contents) { - return true - } - } - } - - return false -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go15.go b/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go15.go deleted file mode 100644 index 75f827a121b..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go15.go +++ /dev/null @@ -1,16 +0,0 @@ -// +build !go1.6 - -package testsuite - -import ( - "os" - "path" -) - -// "This change will only be enabled if the go command is run with -// GO15VENDOREXPERIMENT=1 in its environment." -// c.f. the vendor-experiment proposal https://goo.gl/2ucMeC -func vendorExperimentCheck(dir string) bool { - vendorExperiment := os.Getenv("GO15VENDOREXPERIMENT") - return vendorExperiment == "1" && path.Base(dir) == "vendor" -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go16.go b/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go16.go deleted file mode 100644 index 596e5e5c180..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/testsuite/vendor_check_go16.go +++ /dev/null @@ -1,15 +0,0 @@ -// +build go1.6 - -package testsuite - -import ( - "os" - "path" -) - -// in 1.6 the vendor directory became the default go behaviour, so now -// check if its disabled. 
-func vendorExperimentCheck(dir string) bool { - vendorExperiment := os.Getenv("GO15VENDOREXPERIMENT") - return vendorExperiment != "0" && path.Base(dir) == "vendor" -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/version_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/version_command.go deleted file mode 100644 index a5b68c216fb..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/version_command.go +++ /dev/null @@ -1,25 +0,0 @@ -package main - -import ( - "flag" - "fmt" - - "github.com/onsi/ginkgo/config" -) - -func BuildVersionCommand() *Command { - return &Command{ - Name: "version", - FlagSet: flag.NewFlagSet("version", flag.ExitOnError), - UsageCommand: "ginkgo version", - Usage: []string{ - "Print Ginkgo's version", - }, - Command: printVersion, - } -} - -func printVersion([]string, []string) { - fmt.Printf("Ginkgo Version %s\n", config.VERSION) - emitRCAdvertisement() -} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch_command.go b/vendor/github.com/onsi/ginkgo/ginkgo/watch_command.go deleted file mode 100644 index a6ef053c851..00000000000 --- a/vendor/github.com/onsi/ginkgo/ginkgo/watch_command.go +++ /dev/null @@ -1,175 +0,0 @@ -package main - -import ( - "flag" - "fmt" - "regexp" - "time" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/ginkgo/interrupthandler" - "github.com/onsi/ginkgo/ginkgo/testrunner" - "github.com/onsi/ginkgo/ginkgo/testsuite" - "github.com/onsi/ginkgo/ginkgo/watch" - colorable "github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable" -) - -func BuildWatchCommand() *Command { - commandFlags := NewWatchCommandFlags(flag.NewFlagSet("watch", flag.ExitOnError)) - interruptHandler := interrupthandler.NewInterruptHandler() - notifier := NewNotifier(commandFlags) - watcher := &SpecWatcher{ - commandFlags: commandFlags, - notifier: notifier, - interruptHandler: interruptHandler, - suiteRunner: NewSuiteRunner(notifier, interruptHandler), - } - - return &Command{ - Name: "watch", - FlagSet: commandFlags.FlagSet, - UsageCommand: "ginkgo watch -- ", - Usage: []string{ - "Watches the tests in the passed in and runs them when changes occur.", - "Any arguments after -- will be passed to the test.", - }, - Command: watcher.WatchSpecs, - SuppressFlagDocumentation: true, - FlagDocSubstitute: []string{ - "Accepts all the flags that the ginkgo command accepts except for --keepGoing and --untilItFails", - }, - } -} - -type SpecWatcher struct { - commandFlags *RunWatchAndBuildCommandFlags - notifier *Notifier - interruptHandler *interrupthandler.InterruptHandler - suiteRunner *SuiteRunner -} - -func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) { - w.commandFlags.computeNodes() - w.notifier.VerifyNotificationsAreAvailable() - - w.WatchSuites(args, additionalArgs) -} - -func (w *SpecWatcher) runnersForSuites(suites []testsuite.TestSuite, additionalArgs []string) []*testrunner.TestRunner { - runners := []*testrunner.TestRunner{} - - for _, suite := range suites { - runners = append(runners, testrunner.New(suite, w.commandFlags.NumCPU, w.commandFlags.ParallelStream, w.commandFlags.Timeout, w.commandFlags.GoOpts, additionalArgs)) - } - - return runners -} - -func (w *SpecWatcher) WatchSuites(args []string, additionalArgs []string) { - suites, _ := findSuites(args, w.commandFlags.Recurse, w.commandFlags.SkipPackage, false) - - if len(suites) == 0 { - complainAndQuit("Found no test suites") - } - - fmt.Printf("Identified %d test %s. 
Locating dependencies to a depth of %d (this may take a while)...\n", len(suites), pluralizedWord("suite", "suites", len(suites)), w.commandFlags.Depth) - deltaTracker := watch.NewDeltaTracker(w.commandFlags.Depth, regexp.MustCompile(w.commandFlags.WatchRegExp)) - delta, errors := deltaTracker.Delta(suites) - - fmt.Printf("Watching %d %s:\n", len(delta.NewSuites), pluralizedWord("suite", "suites", len(delta.NewSuites))) - for _, suite := range delta.NewSuites { - fmt.Println(" " + suite.Description()) - } - - for suite, err := range errors { - fmt.Printf("Failed to watch %s: %s\n", suite.PackageName, err) - } - - if len(suites) == 1 { - runners := w.runnersForSuites(suites, additionalArgs) - w.suiteRunner.RunSuites(runners, w.commandFlags.NumCompilers, true, nil) - runners[0].CleanUp() - } - - ticker := time.NewTicker(time.Second) - - for { - select { - case <-ticker.C: - suites, _ := findSuites(args, w.commandFlags.Recurse, w.commandFlags.SkipPackage, false) - delta, _ := deltaTracker.Delta(suites) - coloredStream := colorable.NewColorableStdout() - - suitesToRun := []testsuite.TestSuite{} - - if len(delta.NewSuites) > 0 { - fmt.Fprintf(coloredStream, greenColor+"Detected %d new %s:\n"+defaultStyle, len(delta.NewSuites), pluralizedWord("suite", "suites", len(delta.NewSuites))) - for _, suite := range delta.NewSuites { - suitesToRun = append(suitesToRun, suite.Suite) - fmt.Fprintln(coloredStream, " "+suite.Description()) - } - } - - modifiedSuites := delta.ModifiedSuites() - if len(modifiedSuites) > 0 { - fmt.Fprintln(coloredStream, greenColor+"\nDetected changes in:"+defaultStyle) - for _, pkg := range delta.ModifiedPackages { - fmt.Fprintln(coloredStream, " "+pkg) - } - fmt.Fprintf(coloredStream, greenColor+"Will run %d %s:\n"+defaultStyle, len(modifiedSuites), pluralizedWord("suite", "suites", len(modifiedSuites))) - for _, suite := range modifiedSuites { - suitesToRun = append(suitesToRun, suite.Suite) - fmt.Fprintln(coloredStream, " "+suite.Description()) - } - fmt.Fprintln(coloredStream, "") - } - - if len(suitesToRun) > 0 { - w.UpdateSeed() - w.ComputeSuccinctMode(len(suitesToRun)) - runners := w.runnersForSuites(suitesToRun, additionalArgs) - result, _ := w.suiteRunner.RunSuites(runners, w.commandFlags.NumCompilers, true, func(suite testsuite.TestSuite) { - deltaTracker.WillRun(suite) - }) - for _, runner := range runners { - runner.CleanUp() - } - if !w.interruptHandler.WasInterrupted() { - color := redColor - if result.Passed { - color = greenColor - } - fmt.Fprintln(coloredStream, color+"\nDone. 
Resuming watch..."+defaultStyle) - } - } - - case <-w.interruptHandler.C: - return - } - } -} - -func (w *SpecWatcher) ComputeSuccinctMode(numSuites int) { - if config.DefaultReporterConfig.Verbose { - config.DefaultReporterConfig.Succinct = false - return - } - - if w.commandFlags.wasSet("succinct") { - return - } - - if numSuites == 1 { - config.DefaultReporterConfig.Succinct = false - } - - if numSuites > 1 { - config.DefaultReporterConfig.Succinct = true - } -} - -func (w *SpecWatcher) UpdateSeed() { - if !w.commandFlags.wasSet("seed") { - config.GinkgoConfig.RandomSeed = time.Now().Unix() - } -} diff --git a/vendor/github.com/onsi/ginkgo/internal/codelocation/code_location.go b/vendor/github.com/onsi/ginkgo/internal/codelocation/code_location.go deleted file mode 100644 index aa89d6cba8a..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/codelocation/code_location.go +++ /dev/null @@ -1,48 +0,0 @@ -package codelocation - -import ( - "regexp" - "runtime" - "runtime/debug" - "strings" - - "github.com/onsi/ginkgo/types" -) - -func New(skip int) types.CodeLocation { - _, file, line, _ := runtime.Caller(skip + 1) - stackTrace := PruneStack(string(debug.Stack()), skip+1) - return types.CodeLocation{FileName: file, LineNumber: line, FullStackTrace: stackTrace} -} - -// PruneStack removes references to functions that are internal to Ginkgo -// and the Go runtime from a stack string and a certain number of stack entries -// at the beginning of the stack. The stack string has the format -// as returned by runtime/debug.Stack. The leading goroutine information is -// optional and always removed if present. Beware that runtime/debug.Stack -// adds itself as first entry, so typically skip must be >= 1 to remove that -// entry. -func PruneStack(fullStackTrace string, skip int) string { - stack := strings.Split(fullStackTrace, "\n") - // Ensure that the even entries are the method names and the - // the odd entries the source code information. - if len(stack) > 0 && strings.HasPrefix(stack[0], "goroutine ") { - // Ignore "goroutine 29 [running]:" line. - stack = stack[1:] - } - // The "+1" is for skipping over the initial entry, which is - // runtime/debug.Stack() itself. - if len(stack) > 2*(skip+1) { - stack = stack[2*(skip+1):] - } - prunedStack := []string{} - re := regexp.MustCompile(`\/ginkgo\/|\/pkg\/testing\/|\/pkg\/runtime\/`) - for i := 0; i < len(stack)/2; i++ { - // We filter out based on the source code file name. 
- if !re.Match([]byte(stack[i*2+1])) { - prunedStack = append(prunedStack, stack[i*2]) - prunedStack = append(prunedStack, stack[i*2+1]) - } - } - return strings.Join(prunedStack, "\n") -} diff --git a/vendor/github.com/onsi/ginkgo/internal/containernode/container_node.go b/vendor/github.com/onsi/ginkgo/internal/containernode/container_node.go deleted file mode 100644 index 0737746dcfe..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/containernode/container_node.go +++ /dev/null @@ -1,151 +0,0 @@ -package containernode - -import ( - "math/rand" - "sort" - - "github.com/onsi/ginkgo/internal/leafnodes" - "github.com/onsi/ginkgo/types" -) - -type subjectOrContainerNode struct { - containerNode *ContainerNode - subjectNode leafnodes.SubjectNode -} - -func (n subjectOrContainerNode) text() string { - if n.containerNode != nil { - return n.containerNode.Text() - } else { - return n.subjectNode.Text() - } -} - -type CollatedNodes struct { - Containers []*ContainerNode - Subject leafnodes.SubjectNode -} - -type ContainerNode struct { - text string - flag types.FlagType - codeLocation types.CodeLocation - - setupNodes []leafnodes.BasicNode - subjectAndContainerNodes []subjectOrContainerNode -} - -func New(text string, flag types.FlagType, codeLocation types.CodeLocation) *ContainerNode { - return &ContainerNode{ - text: text, - flag: flag, - codeLocation: codeLocation, - } -} - -func (container *ContainerNode) Shuffle(r *rand.Rand) { - sort.Sort(container) - permutation := r.Perm(len(container.subjectAndContainerNodes)) - shuffledNodes := make([]subjectOrContainerNode, len(container.subjectAndContainerNodes)) - for i, j := range permutation { - shuffledNodes[i] = container.subjectAndContainerNodes[j] - } - container.subjectAndContainerNodes = shuffledNodes -} - -func (node *ContainerNode) BackPropagateProgrammaticFocus() bool { - if node.flag == types.FlagTypePending { - return false - } - - shouldUnfocus := false - for _, subjectOrContainerNode := range node.subjectAndContainerNodes { - if subjectOrContainerNode.containerNode != nil { - shouldUnfocus = subjectOrContainerNode.containerNode.BackPropagateProgrammaticFocus() || shouldUnfocus - } else { - shouldUnfocus = (subjectOrContainerNode.subjectNode.Flag() == types.FlagTypeFocused) || shouldUnfocus - } - } - - if shouldUnfocus { - if node.flag == types.FlagTypeFocused { - node.flag = types.FlagTypeNone - } - return true - } - - return node.flag == types.FlagTypeFocused -} - -func (node *ContainerNode) Collate() []CollatedNodes { - return node.collate([]*ContainerNode{}) -} - -func (node *ContainerNode) collate(enclosingContainers []*ContainerNode) []CollatedNodes { - collated := make([]CollatedNodes, 0) - - containers := make([]*ContainerNode, len(enclosingContainers)) - copy(containers, enclosingContainers) - containers = append(containers, node) - - for _, subjectOrContainer := range node.subjectAndContainerNodes { - if subjectOrContainer.containerNode != nil { - collated = append(collated, subjectOrContainer.containerNode.collate(containers)...) 
- } else { - collated = append(collated, CollatedNodes{ - Containers: containers, - Subject: subjectOrContainer.subjectNode, - }) - } - } - - return collated -} - -func (node *ContainerNode) PushContainerNode(container *ContainerNode) { - node.subjectAndContainerNodes = append(node.subjectAndContainerNodes, subjectOrContainerNode{containerNode: container}) -} - -func (node *ContainerNode) PushSubjectNode(subject leafnodes.SubjectNode) { - node.subjectAndContainerNodes = append(node.subjectAndContainerNodes, subjectOrContainerNode{subjectNode: subject}) -} - -func (node *ContainerNode) PushSetupNode(setupNode leafnodes.BasicNode) { - node.setupNodes = append(node.setupNodes, setupNode) -} - -func (node *ContainerNode) SetupNodesOfType(nodeType types.SpecComponentType) []leafnodes.BasicNode { - nodes := []leafnodes.BasicNode{} - for _, setupNode := range node.setupNodes { - if setupNode.Type() == nodeType { - nodes = append(nodes, setupNode) - } - } - return nodes -} - -func (node *ContainerNode) Text() string { - return node.text -} - -func (node *ContainerNode) CodeLocation() types.CodeLocation { - return node.codeLocation -} - -func (node *ContainerNode) Flag() types.FlagType { - return node.flag -} - -//sort.Interface - -func (node *ContainerNode) Len() int { - return len(node.subjectAndContainerNodes) -} - -func (node *ContainerNode) Less(i, j int) bool { - return node.subjectAndContainerNodes[i].text() < node.subjectAndContainerNodes[j].text() -} - -func (node *ContainerNode) Swap(i, j int) { - node.subjectAndContainerNodes[i], node.subjectAndContainerNodes[j] = node.subjectAndContainerNodes[j], node.subjectAndContainerNodes[i] -} diff --git a/vendor/github.com/onsi/ginkgo/internal/failer/failer.go b/vendor/github.com/onsi/ginkgo/internal/failer/failer.go deleted file mode 100644 index 678ea2514a6..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/failer/failer.go +++ /dev/null @@ -1,92 +0,0 @@ -package failer - -import ( - "fmt" - "sync" - - "github.com/onsi/ginkgo/types" -) - -type Failer struct { - lock *sync.Mutex - failure types.SpecFailure - state types.SpecState -} - -func New() *Failer { - return &Failer{ - lock: &sync.Mutex{}, - state: types.SpecStatePassed, - } -} - -func (f *Failer) Panic(location types.CodeLocation, forwardedPanic interface{}) { - f.lock.Lock() - defer f.lock.Unlock() - - if f.state == types.SpecStatePassed { - f.state = types.SpecStatePanicked - f.failure = types.SpecFailure{ - Message: "Test Panicked", - Location: location, - ForwardedPanic: fmt.Sprintf("%v", forwardedPanic), - } - } -} - -func (f *Failer) Timeout(location types.CodeLocation) { - f.lock.Lock() - defer f.lock.Unlock() - - if f.state == types.SpecStatePassed { - f.state = types.SpecStateTimedOut - f.failure = types.SpecFailure{ - Message: "Timed out", - Location: location, - } - } -} - -func (f *Failer) Fail(message string, location types.CodeLocation) { - f.lock.Lock() - defer f.lock.Unlock() - - if f.state == types.SpecStatePassed { - f.state = types.SpecStateFailed - f.failure = types.SpecFailure{ - Message: message, - Location: location, - } - } -} - -func (f *Failer) Drain(componentType types.SpecComponentType, componentIndex int, componentCodeLocation types.CodeLocation) (types.SpecFailure, types.SpecState) { - f.lock.Lock() - defer f.lock.Unlock() - - failure := f.failure - outcome := f.state - if outcome != types.SpecStatePassed { - failure.ComponentType = componentType - failure.ComponentIndex = componentIndex - failure.ComponentCodeLocation = componentCodeLocation - } - - 
f.state = types.SpecStatePassed - f.failure = types.SpecFailure{} - - return failure, outcome -} - -func (f *Failer) Skip(message string, location types.CodeLocation) { - f.lock.Lock() - defer f.lock.Unlock() - - if f.state == types.SpecStatePassed { - f.state = types.SpecStateSkipped - f.failure = types.SpecFailure{ - Message: message, - Location: location, - } - } -} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/benchmarker.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/benchmarker.go deleted file mode 100644 index 393901e11c3..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/leafnodes/benchmarker.go +++ /dev/null @@ -1,103 +0,0 @@ -package leafnodes - -import ( - "math" - "time" - - "sync" - - "github.com/onsi/ginkgo/types" -) - -type benchmarker struct { - mu sync.Mutex - measurements map[string]*types.SpecMeasurement - orderCounter int -} - -func newBenchmarker() *benchmarker { - return &benchmarker{ - measurements: make(map[string]*types.SpecMeasurement), - } -} - -func (b *benchmarker) Time(name string, body func(), info ...interface{}) (elapsedTime time.Duration) { - t := time.Now() - body() - elapsedTime = time.Since(t) - - b.mu.Lock() - defer b.mu.Unlock() - measurement := b.getMeasurement(name, "Fastest Time", "Slowest Time", "Average Time", "s", 3, info...) - measurement.Results = append(measurement.Results, elapsedTime.Seconds()) - - return -} - -func (b *benchmarker) RecordValue(name string, value float64, info ...interface{}) { - b.mu.Lock() - measurement := b.getMeasurement(name, "Smallest", " Largest", " Average", "", 3, info...) - defer b.mu.Unlock() - measurement.Results = append(measurement.Results, value) -} - -func (b *benchmarker) RecordValueWithPrecision(name string, value float64, units string, precision int, info ...interface{}) { - b.mu.Lock() - measurement := b.getMeasurement(name, "Smallest", " Largest", " Average", units, precision, info...) 
- defer b.mu.Unlock() - measurement.Results = append(measurement.Results, value) -} - -func (b *benchmarker) getMeasurement(name string, smallestLabel string, largestLabel string, averageLabel string, units string, precision int, info ...interface{}) *types.SpecMeasurement { - measurement, ok := b.measurements[name] - if !ok { - var computedInfo interface{} - computedInfo = nil - if len(info) > 0 { - computedInfo = info[0] - } - measurement = &types.SpecMeasurement{ - Name: name, - Info: computedInfo, - Order: b.orderCounter, - SmallestLabel: smallestLabel, - LargestLabel: largestLabel, - AverageLabel: averageLabel, - Units: units, - Precision: precision, - Results: make([]float64, 0), - } - b.measurements[name] = measurement - b.orderCounter++ - } - - return measurement -} - -func (b *benchmarker) measurementsReport() map[string]*types.SpecMeasurement { - b.mu.Lock() - defer b.mu.Unlock() - for _, measurement := range b.measurements { - measurement.Smallest = math.MaxFloat64 - measurement.Largest = -math.MaxFloat64 - sum := float64(0) - sumOfSquares := float64(0) - - for _, result := range measurement.Results { - if result > measurement.Largest { - measurement.Largest = result - } - if result < measurement.Smallest { - measurement.Smallest = result - } - sum += result - sumOfSquares += result * result - } - - n := float64(len(measurement.Results)) - measurement.Average = sum / n - measurement.StdDeviation = math.Sqrt(sumOfSquares/n - (sum/n)*(sum/n)) - } - - return b.measurements -} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/interfaces.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/interfaces.go deleted file mode 100644 index 8c3902d601c..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/leafnodes/interfaces.go +++ /dev/null @@ -1,19 +0,0 @@ -package leafnodes - -import ( - "github.com/onsi/ginkgo/types" -) - -type BasicNode interface { - Type() types.SpecComponentType - Run() (types.SpecState, types.SpecFailure) - CodeLocation() types.CodeLocation -} - -type SubjectNode interface { - BasicNode - - Text() string - Flag() types.FlagType - Samples() int -} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/it_node.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/it_node.go deleted file mode 100644 index 6eded7b763e..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/leafnodes/it_node.go +++ /dev/null @@ -1,47 +0,0 @@ -package leafnodes - -import ( - "time" - - "github.com/onsi/ginkgo/internal/failer" - "github.com/onsi/ginkgo/types" -) - -type ItNode struct { - runner *runner - - flag types.FlagType - text string -} - -func NewItNode(text string, body interface{}, flag types.FlagType, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *ItNode { - return &ItNode{ - runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeIt, componentIndex), - flag: flag, - text: text, - } -} - -func (node *ItNode) Run() (outcome types.SpecState, failure types.SpecFailure) { - return node.runner.run() -} - -func (node *ItNode) Type() types.SpecComponentType { - return types.SpecComponentTypeIt -} - -func (node *ItNode) Text() string { - return node.text -} - -func (node *ItNode) Flag() types.FlagType { - return node.flag -} - -func (node *ItNode) CodeLocation() types.CodeLocation { - return node.runner.codeLocation -} - -func (node *ItNode) Samples() int { - return 1 -} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/measure_node.go 
b/vendor/github.com/onsi/ginkgo/internal/leafnodes/measure_node.go deleted file mode 100644 index 3ab9a6d5524..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/leafnodes/measure_node.go +++ /dev/null @@ -1,62 +0,0 @@ -package leafnodes - -import ( - "reflect" - - "github.com/onsi/ginkgo/internal/failer" - "github.com/onsi/ginkgo/types" -) - -type MeasureNode struct { - runner *runner - - text string - flag types.FlagType - samples int - benchmarker *benchmarker -} - -func NewMeasureNode(text string, body interface{}, flag types.FlagType, codeLocation types.CodeLocation, samples int, failer *failer.Failer, componentIndex int) *MeasureNode { - benchmarker := newBenchmarker() - - wrappedBody := func() { - reflect.ValueOf(body).Call([]reflect.Value{reflect.ValueOf(benchmarker)}) - } - - return &MeasureNode{ - runner: newRunner(wrappedBody, codeLocation, 0, failer, types.SpecComponentTypeMeasure, componentIndex), - - text: text, - flag: flag, - samples: samples, - benchmarker: benchmarker, - } -} - -func (node *MeasureNode) Run() (outcome types.SpecState, failure types.SpecFailure) { - return node.runner.run() -} - -func (node *MeasureNode) MeasurementsReport() map[string]*types.SpecMeasurement { - return node.benchmarker.measurementsReport() -} - -func (node *MeasureNode) Type() types.SpecComponentType { - return types.SpecComponentTypeMeasure -} - -func (node *MeasureNode) Text() string { - return node.text -} - -func (node *MeasureNode) Flag() types.FlagType { - return node.flag -} - -func (node *MeasureNode) CodeLocation() types.CodeLocation { - return node.runner.codeLocation -} - -func (node *MeasureNode) Samples() int { - return node.samples -} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/runner.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/runner.go deleted file mode 100644 index 16cb66c3e49..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/leafnodes/runner.go +++ /dev/null @@ -1,117 +0,0 @@ -package leafnodes - -import ( - "fmt" - "reflect" - "time" - - "github.com/onsi/ginkgo/internal/codelocation" - "github.com/onsi/ginkgo/internal/failer" - "github.com/onsi/ginkgo/types" -) - -type runner struct { - isAsync bool - asyncFunc func(chan<- interface{}) - syncFunc func() - codeLocation types.CodeLocation - timeoutThreshold time.Duration - nodeType types.SpecComponentType - componentIndex int - failer *failer.Failer -} - -func newRunner(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, nodeType types.SpecComponentType, componentIndex int) *runner { - bodyType := reflect.TypeOf(body) - if bodyType.Kind() != reflect.Func { - panic(fmt.Sprintf("Expected a function but got something else at %v", codeLocation)) - } - - runner := &runner{ - codeLocation: codeLocation, - timeoutThreshold: timeout, - failer: failer, - nodeType: nodeType, - componentIndex: componentIndex, - } - - switch bodyType.NumIn() { - case 0: - runner.syncFunc = body.(func()) - return runner - case 1: - if !(bodyType.In(0).Kind() == reflect.Chan && bodyType.In(0).Elem().Kind() == reflect.Interface) { - panic(fmt.Sprintf("Must pass a Done channel to function at %v", codeLocation)) - } - - wrappedBody := func(done chan<- interface{}) { - bodyValue := reflect.ValueOf(body) - bodyValue.Call([]reflect.Value{reflect.ValueOf(done)}) - } - - runner.isAsync = true - runner.asyncFunc = wrappedBody - return runner - } - - panic(fmt.Sprintf("Too many arguments to function at %v", codeLocation)) -} - -func (r *runner) run() (outcome types.SpecState, 
failure types.SpecFailure) { - if r.isAsync { - return r.runAsync() - } else { - return r.runSync() - } -} - -func (r *runner) runAsync() (outcome types.SpecState, failure types.SpecFailure) { - done := make(chan interface{}, 1) - - go func() { - finished := false - - defer func() { - if e := recover(); e != nil || !finished { - r.failer.Panic(codelocation.New(2), e) - select { - case <-done: - break - default: - close(done) - } - } - }() - - r.asyncFunc(done) - finished = true - }() - - // If this goroutine gets no CPU time before the select block, - // the <-done case may complete even if the test took longer than the timeoutThreshold. - // This can cause flaky behaviour, but we haven't seen it in the wild. - select { - case <-done: - case <-time.After(r.timeoutThreshold): - r.failer.Timeout(r.codeLocation) - } - - failure, outcome = r.failer.Drain(r.nodeType, r.componentIndex, r.codeLocation) - return -} -func (r *runner) runSync() (outcome types.SpecState, failure types.SpecFailure) { - finished := false - - defer func() { - if e := recover(); e != nil || !finished { - r.failer.Panic(codelocation.New(2), e) - } - - failure, outcome = r.failer.Drain(r.nodeType, r.componentIndex, r.codeLocation) - }() - - r.syncFunc() - finished = true - - return -} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/setup_nodes.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/setup_nodes.go deleted file mode 100644 index e3e9cb7c588..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/leafnodes/setup_nodes.go +++ /dev/null @@ -1,48 +0,0 @@ -package leafnodes - -import ( - "time" - - "github.com/onsi/ginkgo/internal/failer" - "github.com/onsi/ginkgo/types" -) - -type SetupNode struct { - runner *runner -} - -func (node *SetupNode) Run() (outcome types.SpecState, failure types.SpecFailure) { - return node.runner.run() -} - -func (node *SetupNode) Type() types.SpecComponentType { - return node.runner.nodeType -} - -func (node *SetupNode) CodeLocation() types.CodeLocation { - return node.runner.codeLocation -} - -func NewBeforeEachNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *SetupNode { - return &SetupNode{ - runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeBeforeEach, componentIndex), - } -} - -func NewAfterEachNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *SetupNode { - return &SetupNode{ - runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeAfterEach, componentIndex), - } -} - -func NewJustBeforeEachNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *SetupNode { - return &SetupNode{ - runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeJustBeforeEach, componentIndex), - } -} - -func NewJustAfterEachNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer, componentIndex int) *SetupNode { - return &SetupNode{ - runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeJustAfterEach, componentIndex), - } -} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/suite_nodes.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/suite_nodes.go deleted file mode 100644 index 80f16ed7861..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/leafnodes/suite_nodes.go +++ /dev/null @@ -1,55 +0,0 @@ -package leafnodes - 
-import ( - "time" - - "github.com/onsi/ginkgo/internal/failer" - "github.com/onsi/ginkgo/types" -) - -type SuiteNode interface { - Run(parallelNode int, parallelTotal int, syncHost string) bool - Passed() bool - Summary() *types.SetupSummary -} - -type simpleSuiteNode struct { - runner *runner - outcome types.SpecState - failure types.SpecFailure - runTime time.Duration -} - -func (node *simpleSuiteNode) Run(parallelNode int, parallelTotal int, syncHost string) bool { - t := time.Now() - node.outcome, node.failure = node.runner.run() - node.runTime = time.Since(t) - - return node.outcome == types.SpecStatePassed -} - -func (node *simpleSuiteNode) Passed() bool { - return node.outcome == types.SpecStatePassed -} - -func (node *simpleSuiteNode) Summary() *types.SetupSummary { - return &types.SetupSummary{ - ComponentType: node.runner.nodeType, - CodeLocation: node.runner.codeLocation, - State: node.outcome, - RunTime: node.runTime, - Failure: node.failure, - } -} - -func NewBeforeSuiteNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer) SuiteNode { - return &simpleSuiteNode{ - runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeBeforeSuite, 0), - } -} - -func NewAfterSuiteNode(body interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer) SuiteNode { - return &simpleSuiteNode{ - runner: newRunner(body, codeLocation, timeout, failer, types.SpecComponentTypeAfterSuite, 0), - } -} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_after_suite_node.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_after_suite_node.go deleted file mode 100644 index a721d0cf7f2..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_after_suite_node.go +++ /dev/null @@ -1,90 +0,0 @@ -package leafnodes - -import ( - "encoding/json" - "io/ioutil" - "net/http" - "time" - - "github.com/onsi/ginkgo/internal/failer" - "github.com/onsi/ginkgo/types" -) - -type synchronizedAfterSuiteNode struct { - runnerA *runner - runnerB *runner - - outcome types.SpecState - failure types.SpecFailure - runTime time.Duration -} - -func NewSynchronizedAfterSuiteNode(bodyA interface{}, bodyB interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer) SuiteNode { - return &synchronizedAfterSuiteNode{ - runnerA: newRunner(bodyA, codeLocation, timeout, failer, types.SpecComponentTypeAfterSuite, 0), - runnerB: newRunner(bodyB, codeLocation, timeout, failer, types.SpecComponentTypeAfterSuite, 0), - } -} - -func (node *synchronizedAfterSuiteNode) Run(parallelNode int, parallelTotal int, syncHost string) bool { - node.outcome, node.failure = node.runnerA.run() - - if parallelNode == 1 { - if parallelTotal > 1 { - node.waitUntilOtherNodesAreDone(syncHost) - } - - outcome, failure := node.runnerB.run() - - if node.outcome == types.SpecStatePassed { - node.outcome, node.failure = outcome, failure - } - } - - return node.outcome == types.SpecStatePassed -} - -func (node *synchronizedAfterSuiteNode) Passed() bool { - return node.outcome == types.SpecStatePassed -} - -func (node *synchronizedAfterSuiteNode) Summary() *types.SetupSummary { - return &types.SetupSummary{ - ComponentType: node.runnerA.nodeType, - CodeLocation: node.runnerA.codeLocation, - State: node.outcome, - RunTime: node.runTime, - Failure: node.failure, - } -} - -func (node *synchronizedAfterSuiteNode) waitUntilOtherNodesAreDone(syncHost string) { - for { - if 
node.canRun(syncHost) { - return - } - - time.Sleep(50 * time.Millisecond) - } -} - -func (node *synchronizedAfterSuiteNode) canRun(syncHost string) bool { - resp, err := http.Get(syncHost + "/RemoteAfterSuiteData") - if err != nil || resp.StatusCode != http.StatusOK { - return false - } - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return false - } - resp.Body.Close() - - afterSuiteData := types.RemoteAfterSuiteData{} - err = json.Unmarshal(body, &afterSuiteData) - if err != nil { - return false - } - - return afterSuiteData.CanRun -} diff --git a/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_before_suite_node.go b/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_before_suite_node.go deleted file mode 100644 index d5c88931940..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/leafnodes/synchronized_before_suite_node.go +++ /dev/null @@ -1,181 +0,0 @@ -package leafnodes - -import ( - "bytes" - "encoding/json" - "io/ioutil" - "net/http" - "reflect" - "time" - - "github.com/onsi/ginkgo/internal/failer" - "github.com/onsi/ginkgo/types" -) - -type synchronizedBeforeSuiteNode struct { - runnerA *runner - runnerB *runner - - data []byte - - outcome types.SpecState - failure types.SpecFailure - runTime time.Duration -} - -func NewSynchronizedBeforeSuiteNode(bodyA interface{}, bodyB interface{}, codeLocation types.CodeLocation, timeout time.Duration, failer *failer.Failer) SuiteNode { - node := &synchronizedBeforeSuiteNode{} - - node.runnerA = newRunner(node.wrapA(bodyA), codeLocation, timeout, failer, types.SpecComponentTypeBeforeSuite, 0) - node.runnerB = newRunner(node.wrapB(bodyB), codeLocation, timeout, failer, types.SpecComponentTypeBeforeSuite, 0) - - return node -} - -func (node *synchronizedBeforeSuiteNode) Run(parallelNode int, parallelTotal int, syncHost string) bool { - t := time.Now() - defer func() { - node.runTime = time.Since(t) - }() - - if parallelNode == 1 { - node.outcome, node.failure = node.runA(parallelTotal, syncHost) - } else { - node.outcome, node.failure = node.waitForA(syncHost) - } - - if node.outcome != types.SpecStatePassed { - return false - } - node.outcome, node.failure = node.runnerB.run() - - return node.outcome == types.SpecStatePassed -} - -func (node *synchronizedBeforeSuiteNode) runA(parallelTotal int, syncHost string) (types.SpecState, types.SpecFailure) { - outcome, failure := node.runnerA.run() - - if parallelTotal > 1 { - state := types.RemoteBeforeSuiteStatePassed - if outcome != types.SpecStatePassed { - state = types.RemoteBeforeSuiteStateFailed - } - json := (types.RemoteBeforeSuiteData{ - Data: node.data, - State: state, - }).ToJSON() - http.Post(syncHost+"/BeforeSuiteState", "application/json", bytes.NewBuffer(json)) - } - - return outcome, failure -} - -func (node *synchronizedBeforeSuiteNode) waitForA(syncHost string) (types.SpecState, types.SpecFailure) { - failure := func(message string) types.SpecFailure { - return types.SpecFailure{ - Message: message, - Location: node.runnerA.codeLocation, - ComponentType: node.runnerA.nodeType, - ComponentIndex: node.runnerA.componentIndex, - ComponentCodeLocation: node.runnerA.codeLocation, - } - } - for { - resp, err := http.Get(syncHost + "/BeforeSuiteState") - if err != nil || resp.StatusCode != http.StatusOK { - return types.SpecStateFailed, failure("Failed to fetch BeforeSuite state") - } - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return types.SpecStateFailed, failure("Failed to read BeforeSuite state") - } - resp.Body.Close() - 
- beforeSuiteData := types.RemoteBeforeSuiteData{} - err = json.Unmarshal(body, &beforeSuiteData) - if err != nil { - return types.SpecStateFailed, failure("Failed to decode BeforeSuite state") - } - - switch beforeSuiteData.State { - case types.RemoteBeforeSuiteStatePassed: - node.data = beforeSuiteData.Data - return types.SpecStatePassed, types.SpecFailure{} - case types.RemoteBeforeSuiteStateFailed: - return types.SpecStateFailed, failure("BeforeSuite on Node 1 failed") - case types.RemoteBeforeSuiteStateDisappeared: - return types.SpecStateFailed, failure("Node 1 disappeared before completing BeforeSuite") - } - - time.Sleep(50 * time.Millisecond) - } -} - -func (node *synchronizedBeforeSuiteNode) Passed() bool { - return node.outcome == types.SpecStatePassed -} - -func (node *synchronizedBeforeSuiteNode) Summary() *types.SetupSummary { - return &types.SetupSummary{ - ComponentType: node.runnerA.nodeType, - CodeLocation: node.runnerA.codeLocation, - State: node.outcome, - RunTime: node.runTime, - Failure: node.failure, - } -} - -func (node *synchronizedBeforeSuiteNode) wrapA(bodyA interface{}) interface{} { - typeA := reflect.TypeOf(bodyA) - if typeA.Kind() != reflect.Func { - panic("SynchronizedBeforeSuite expects a function as its first argument") - } - - takesNothing := typeA.NumIn() == 0 - takesADoneChannel := typeA.NumIn() == 1 && typeA.In(0).Kind() == reflect.Chan && typeA.In(0).Elem().Kind() == reflect.Interface - returnsBytes := typeA.NumOut() == 1 && typeA.Out(0).Kind() == reflect.Slice && typeA.Out(0).Elem().Kind() == reflect.Uint8 - - if !((takesNothing || takesADoneChannel) && returnsBytes) { - panic("SynchronizedBeforeSuite's first argument should be a function that returns []byte and either takes no arguments or takes a Done channel.") - } - - if takesADoneChannel { - return func(done chan<- interface{}) { - out := reflect.ValueOf(bodyA).Call([]reflect.Value{reflect.ValueOf(done)}) - node.data = out[0].Interface().([]byte) - } - } - - return func() { - out := reflect.ValueOf(bodyA).Call([]reflect.Value{}) - node.data = out[0].Interface().([]byte) - } -} - -func (node *synchronizedBeforeSuiteNode) wrapB(bodyB interface{}) interface{} { - typeB := reflect.TypeOf(bodyB) - if typeB.Kind() != reflect.Func { - panic("SynchronizedBeforeSuite expects a function as its second argument") - } - - returnsNothing := typeB.NumOut() == 0 - takesBytesOnly := typeB.NumIn() == 1 && typeB.In(0).Kind() == reflect.Slice && typeB.In(0).Elem().Kind() == reflect.Uint8 - takesBytesAndDone := typeB.NumIn() == 2 && - typeB.In(0).Kind() == reflect.Slice && typeB.In(0).Elem().Kind() == reflect.Uint8 && - typeB.In(1).Kind() == reflect.Chan && typeB.In(1).Elem().Kind() == reflect.Interface - - if !((takesBytesOnly || takesBytesAndDone) && returnsNothing) { - panic("SynchronizedBeforeSuite's second argument should be a function that returns nothing and either takes []byte or ([]byte, Done)") - } - - if takesBytesAndDone { - return func(done chan<- interface{}) { - reflect.ValueOf(bodyB).Call([]reflect.Value{reflect.ValueOf(node.data), reflect.ValueOf(done)}) - } - } - - return func() { - reflect.ValueOf(bodyB).Call([]reflect.Value{reflect.ValueOf(node.data)}) - } -} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/aggregator.go b/vendor/github.com/onsi/ginkgo/internal/remote/aggregator.go deleted file mode 100644 index 992437d9e0b..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/remote/aggregator.go +++ /dev/null @@ -1,249 +0,0 @@ -/* - -Aggregator is a reporter used by the Ginkgo CLI 
to aggregate and present parallel test output -coherently as tests complete. You shouldn't need to use this in your code. To run tests in parallel: - - ginkgo -nodes=N - -where N is the number of nodes you desire. -*/ -package remote - -import ( - "time" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/reporters/stenographer" - "github.com/onsi/ginkgo/types" -) - -type configAndSuite struct { - config config.GinkgoConfigType - summary *types.SuiteSummary -} - -type Aggregator struct { - nodeCount int - config config.DefaultReporterConfigType - stenographer stenographer.Stenographer - result chan bool - - suiteBeginnings chan configAndSuite - aggregatedSuiteBeginnings []configAndSuite - - beforeSuites chan *types.SetupSummary - aggregatedBeforeSuites []*types.SetupSummary - - afterSuites chan *types.SetupSummary - aggregatedAfterSuites []*types.SetupSummary - - specCompletions chan *types.SpecSummary - completedSpecs []*types.SpecSummary - - suiteEndings chan *types.SuiteSummary - aggregatedSuiteEndings []*types.SuiteSummary - specs []*types.SpecSummary - - startTime time.Time -} - -func NewAggregator(nodeCount int, result chan bool, config config.DefaultReporterConfigType, stenographer stenographer.Stenographer) *Aggregator { - aggregator := &Aggregator{ - nodeCount: nodeCount, - result: result, - config: config, - stenographer: stenographer, - - suiteBeginnings: make(chan configAndSuite), - beforeSuites: make(chan *types.SetupSummary), - afterSuites: make(chan *types.SetupSummary), - specCompletions: make(chan *types.SpecSummary), - suiteEndings: make(chan *types.SuiteSummary), - } - - go aggregator.mux() - - return aggregator -} - -func (aggregator *Aggregator) SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) { - aggregator.suiteBeginnings <- configAndSuite{config, summary} -} - -func (aggregator *Aggregator) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { - aggregator.beforeSuites <- setupSummary -} - -func (aggregator *Aggregator) AfterSuiteDidRun(setupSummary *types.SetupSummary) { - aggregator.afterSuites <- setupSummary -} - -func (aggregator *Aggregator) SpecWillRun(specSummary *types.SpecSummary) { - //noop -} - -func (aggregator *Aggregator) SpecDidComplete(specSummary *types.SpecSummary) { - aggregator.specCompletions <- specSummary -} - -func (aggregator *Aggregator) SpecSuiteDidEnd(summary *types.SuiteSummary) { - aggregator.suiteEndings <- summary -} - -func (aggregator *Aggregator) mux() { -loop: - for { - select { - case configAndSuite := <-aggregator.suiteBeginnings: - aggregator.registerSuiteBeginning(configAndSuite) - case setupSummary := <-aggregator.beforeSuites: - aggregator.registerBeforeSuite(setupSummary) - case setupSummary := <-aggregator.afterSuites: - aggregator.registerAfterSuite(setupSummary) - case specSummary := <-aggregator.specCompletions: - aggregator.registerSpecCompletion(specSummary) - case suite := <-aggregator.suiteEndings: - finished, passed := aggregator.registerSuiteEnding(suite) - if finished { - aggregator.result <- passed - break loop - } - } - } -} - -func (aggregator *Aggregator) registerSuiteBeginning(configAndSuite configAndSuite) { - aggregator.aggregatedSuiteBeginnings = append(aggregator.aggregatedSuiteBeginnings, configAndSuite) - - if len(aggregator.aggregatedSuiteBeginnings) == 1 { - aggregator.startTime = time.Now() - } - - if len(aggregator.aggregatedSuiteBeginnings) != aggregator.nodeCount { - return - } - - 
aggregator.stenographer.AnnounceSuite(configAndSuite.summary.SuiteDescription, configAndSuite.config.RandomSeed, configAndSuite.config.RandomizeAllSpecs, aggregator.config.Succinct) - - totalNumberOfSpecs := 0 - if len(aggregator.aggregatedSuiteBeginnings) > 0 { - totalNumberOfSpecs = configAndSuite.summary.NumberOfSpecsBeforeParallelization - } - - aggregator.stenographer.AnnounceTotalNumberOfSpecs(totalNumberOfSpecs, aggregator.config.Succinct) - aggregator.stenographer.AnnounceAggregatedParallelRun(aggregator.nodeCount, aggregator.config.Succinct) - aggregator.flushCompletedSpecs() -} - -func (aggregator *Aggregator) registerBeforeSuite(setupSummary *types.SetupSummary) { - aggregator.aggregatedBeforeSuites = append(aggregator.aggregatedBeforeSuites, setupSummary) - aggregator.flushCompletedSpecs() -} - -func (aggregator *Aggregator) registerAfterSuite(setupSummary *types.SetupSummary) { - aggregator.aggregatedAfterSuites = append(aggregator.aggregatedAfterSuites, setupSummary) - aggregator.flushCompletedSpecs() -} - -func (aggregator *Aggregator) registerSpecCompletion(specSummary *types.SpecSummary) { - aggregator.completedSpecs = append(aggregator.completedSpecs, specSummary) - aggregator.specs = append(aggregator.specs, specSummary) - aggregator.flushCompletedSpecs() -} - -func (aggregator *Aggregator) flushCompletedSpecs() { - if len(aggregator.aggregatedSuiteBeginnings) != aggregator.nodeCount { - return - } - - for _, setupSummary := range aggregator.aggregatedBeforeSuites { - aggregator.announceBeforeSuite(setupSummary) - } - - for _, specSummary := range aggregator.completedSpecs { - aggregator.announceSpec(specSummary) - } - - for _, setupSummary := range aggregator.aggregatedAfterSuites { - aggregator.announceAfterSuite(setupSummary) - } - - aggregator.aggregatedBeforeSuites = []*types.SetupSummary{} - aggregator.completedSpecs = []*types.SpecSummary{} - aggregator.aggregatedAfterSuites = []*types.SetupSummary{} -} - -func (aggregator *Aggregator) announceBeforeSuite(setupSummary *types.SetupSummary) { - aggregator.stenographer.AnnounceCapturedOutput(setupSummary.CapturedOutput) - if setupSummary.State != types.SpecStatePassed { - aggregator.stenographer.AnnounceBeforeSuiteFailure(setupSummary, aggregator.config.Succinct, aggregator.config.FullTrace) - } -} - -func (aggregator *Aggregator) announceAfterSuite(setupSummary *types.SetupSummary) { - aggregator.stenographer.AnnounceCapturedOutput(setupSummary.CapturedOutput) - if setupSummary.State != types.SpecStatePassed { - aggregator.stenographer.AnnounceAfterSuiteFailure(setupSummary, aggregator.config.Succinct, aggregator.config.FullTrace) - } -} - -func (aggregator *Aggregator) announceSpec(specSummary *types.SpecSummary) { - if aggregator.config.Verbose && specSummary.State != types.SpecStatePending && specSummary.State != types.SpecStateSkipped { - aggregator.stenographer.AnnounceSpecWillRun(specSummary) - } - - aggregator.stenographer.AnnounceCapturedOutput(specSummary.CapturedOutput) - - switch specSummary.State { - case types.SpecStatePassed: - if specSummary.IsMeasurement { - aggregator.stenographer.AnnounceSuccessfulMeasurement(specSummary, aggregator.config.Succinct) - } else if specSummary.RunTime.Seconds() >= aggregator.config.SlowSpecThreshold { - aggregator.stenographer.AnnounceSuccessfulSlowSpec(specSummary, aggregator.config.Succinct) - } else { - aggregator.stenographer.AnnounceSuccessfulSpec(specSummary) - } - - case types.SpecStatePending: - aggregator.stenographer.AnnouncePendingSpec(specSummary, 
aggregator.config.NoisyPendings && !aggregator.config.Succinct) - case types.SpecStateSkipped: - aggregator.stenographer.AnnounceSkippedSpec(specSummary, aggregator.config.Succinct || !aggregator.config.NoisySkippings, aggregator.config.FullTrace) - case types.SpecStateTimedOut: - aggregator.stenographer.AnnounceSpecTimedOut(specSummary, aggregator.config.Succinct, aggregator.config.FullTrace) - case types.SpecStatePanicked: - aggregator.stenographer.AnnounceSpecPanicked(specSummary, aggregator.config.Succinct, aggregator.config.FullTrace) - case types.SpecStateFailed: - aggregator.stenographer.AnnounceSpecFailed(specSummary, aggregator.config.Succinct, aggregator.config.FullTrace) - } -} - -func (aggregator *Aggregator) registerSuiteEnding(suite *types.SuiteSummary) (finished bool, passed bool) { - aggregator.aggregatedSuiteEndings = append(aggregator.aggregatedSuiteEndings, suite) - if len(aggregator.aggregatedSuiteEndings) < aggregator.nodeCount { - return false, false - } - - aggregatedSuiteSummary := &types.SuiteSummary{} - aggregatedSuiteSummary.SuiteSucceeded = true - - for _, suiteSummary := range aggregator.aggregatedSuiteEndings { - if !suiteSummary.SuiteSucceeded { - aggregatedSuiteSummary.SuiteSucceeded = false - } - - aggregatedSuiteSummary.NumberOfSpecsThatWillBeRun += suiteSummary.NumberOfSpecsThatWillBeRun - aggregatedSuiteSummary.NumberOfTotalSpecs += suiteSummary.NumberOfTotalSpecs - aggregatedSuiteSummary.NumberOfPassedSpecs += suiteSummary.NumberOfPassedSpecs - aggregatedSuiteSummary.NumberOfFailedSpecs += suiteSummary.NumberOfFailedSpecs - aggregatedSuiteSummary.NumberOfPendingSpecs += suiteSummary.NumberOfPendingSpecs - aggregatedSuiteSummary.NumberOfSkippedSpecs += suiteSummary.NumberOfSkippedSpecs - aggregatedSuiteSummary.NumberOfFlakedSpecs += suiteSummary.NumberOfFlakedSpecs - } - - aggregatedSuiteSummary.RunTime = time.Since(aggregator.startTime) - - aggregator.stenographer.SummarizeFailures(aggregator.specs) - aggregator.stenographer.AnnounceSpecRunCompletion(aggregatedSuiteSummary, aggregator.config.Succinct) - - return true, aggregatedSuiteSummary.SuiteSucceeded -} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/forwarding_reporter.go b/vendor/github.com/onsi/ginkgo/internal/remote/forwarding_reporter.go deleted file mode 100644 index 284bc62e5e8..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/remote/forwarding_reporter.go +++ /dev/null @@ -1,147 +0,0 @@ -package remote - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - - "github.com/onsi/ginkgo/internal/writer" - "github.com/onsi/ginkgo/reporters" - "github.com/onsi/ginkgo/reporters/stenographer" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/types" -) - -//An interface to net/http's client to allow the injection of fakes under test -type Poster interface { - Post(url string, bodyType string, body io.Reader) (resp *http.Response, err error) -} - -/* -The ForwardingReporter is a Ginkgo reporter that forwards information to -a Ginkgo remote server. - -When streaming parallel test output, this repoter is automatically installed by Ginkgo. - -This is accomplished by passing in the GINKGO_REMOTE_REPORTING_SERVER environment variable to `go test`, the Ginkgo test runner -detects this environment variable (which should contain the host of the server) and automatically installs a ForwardingReporter -in place of Ginkgo's DefaultReporter. 
-*/ - -type ForwardingReporter struct { - serverHost string - poster Poster - outputInterceptor OutputInterceptor - debugMode bool - debugFile *os.File - nestedReporter *reporters.DefaultReporter -} - -func NewForwardingReporter(config config.DefaultReporterConfigType, serverHost string, poster Poster, outputInterceptor OutputInterceptor, ginkgoWriter *writer.Writer, debugFile string) *ForwardingReporter { - reporter := &ForwardingReporter{ - serverHost: serverHost, - poster: poster, - outputInterceptor: outputInterceptor, - } - - if debugFile != "" { - var err error - reporter.debugMode = true - reporter.debugFile, err = os.Create(debugFile) - if err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - - if !config.Verbose { - //if verbose is true then the GinkgoWriter emits to stdout. Don't _also_ redirect GinkgoWriter output as that will result in duplication. - ginkgoWriter.AndRedirectTo(reporter.debugFile) - } - outputInterceptor.StreamTo(reporter.debugFile) //This is not working - - stenographer := stenographer.New(false, true, reporter.debugFile) - config.Succinct = false - config.Verbose = true - config.FullTrace = true - reporter.nestedReporter = reporters.NewDefaultReporter(config, stenographer) - } - - return reporter -} - -func (reporter *ForwardingReporter) post(path string, data interface{}) { - encoded, _ := json.Marshal(data) - buffer := bytes.NewBuffer(encoded) - reporter.poster.Post(reporter.serverHost+path, "application/json", buffer) -} - -func (reporter *ForwardingReporter) SpecSuiteWillBegin(conf config.GinkgoConfigType, summary *types.SuiteSummary) { - data := struct { - Config config.GinkgoConfigType `json:"config"` - Summary *types.SuiteSummary `json:"suite-summary"` - }{ - conf, - summary, - } - - reporter.outputInterceptor.StartInterceptingOutput() - if reporter.debugMode { - reporter.nestedReporter.SpecSuiteWillBegin(conf, summary) - reporter.debugFile.Sync() - } - reporter.post("/SpecSuiteWillBegin", data) -} - -func (reporter *ForwardingReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { - output, _ := reporter.outputInterceptor.StopInterceptingAndReturnOutput() - reporter.outputInterceptor.StartInterceptingOutput() - setupSummary.CapturedOutput = output - if reporter.debugMode { - reporter.nestedReporter.BeforeSuiteDidRun(setupSummary) - reporter.debugFile.Sync() - } - reporter.post("/BeforeSuiteDidRun", setupSummary) -} - -func (reporter *ForwardingReporter) SpecWillRun(specSummary *types.SpecSummary) { - if reporter.debugMode { - reporter.nestedReporter.SpecWillRun(specSummary) - reporter.debugFile.Sync() - } - reporter.post("/SpecWillRun", specSummary) -} - -func (reporter *ForwardingReporter) SpecDidComplete(specSummary *types.SpecSummary) { - output, _ := reporter.outputInterceptor.StopInterceptingAndReturnOutput() - reporter.outputInterceptor.StartInterceptingOutput() - specSummary.CapturedOutput = output - if reporter.debugMode { - reporter.nestedReporter.SpecDidComplete(specSummary) - reporter.debugFile.Sync() - } - reporter.post("/SpecDidComplete", specSummary) -} - -func (reporter *ForwardingReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { - output, _ := reporter.outputInterceptor.StopInterceptingAndReturnOutput() - reporter.outputInterceptor.StartInterceptingOutput() - setupSummary.CapturedOutput = output - if reporter.debugMode { - reporter.nestedReporter.AfterSuiteDidRun(setupSummary) - reporter.debugFile.Sync() - } - reporter.post("/AfterSuiteDidRun", setupSummary) -} - -func (reporter *ForwardingReporter) 
SpecSuiteDidEnd(summary *types.SuiteSummary) { - reporter.outputInterceptor.StopInterceptingAndReturnOutput() - if reporter.debugMode { - reporter.nestedReporter.SpecSuiteDidEnd(summary) - reporter.debugFile.Sync() - } - reporter.post("/SpecSuiteDidEnd", summary) -} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor.go b/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor.go deleted file mode 100644 index 5154abe87d5..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor.go +++ /dev/null @@ -1,13 +0,0 @@ -package remote - -import "os" - -/* -The OutputInterceptor is used by the ForwardingReporter to -intercept and capture all stdin and stderr output during a test run. -*/ -type OutputInterceptor interface { - StartInterceptingOutput() error - StopInterceptingAndReturnOutput() (string, error) - StreamTo(*os.File) -} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_unix.go b/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_unix.go deleted file mode 100644 index 774967db66b..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_unix.go +++ /dev/null @@ -1,82 +0,0 @@ -// +build freebsd openbsd netbsd dragonfly darwin linux solaris - -package remote - -import ( - "errors" - "io/ioutil" - "os" - - "github.com/nxadm/tail" - "golang.org/x/sys/unix" -) - -func NewOutputInterceptor() OutputInterceptor { - return &outputInterceptor{} -} - -type outputInterceptor struct { - redirectFile *os.File - streamTarget *os.File - intercepting bool - tailer *tail.Tail - doneTailing chan bool -} - -func (interceptor *outputInterceptor) StartInterceptingOutput() error { - if interceptor.intercepting { - return errors.New("Already intercepting output!") - } - interceptor.intercepting = true - - var err error - - interceptor.redirectFile, err = ioutil.TempFile("", "ginkgo-output") - if err != nil { - return err - } - - // This might call Dup3 if the dup2 syscall is not available, e.g. 
on - // linux/arm64 or linux/riscv64 - unix.Dup2(int(interceptor.redirectFile.Fd()), 1) - unix.Dup2(int(interceptor.redirectFile.Fd()), 2) - - if interceptor.streamTarget != nil { - interceptor.tailer, _ = tail.TailFile(interceptor.redirectFile.Name(), tail.Config{Follow: true}) - interceptor.doneTailing = make(chan bool) - - go func() { - for line := range interceptor.tailer.Lines { - interceptor.streamTarget.Write([]byte(line.Text + "\n")) - } - close(interceptor.doneTailing) - }() - } - - return nil -} - -func (interceptor *outputInterceptor) StopInterceptingAndReturnOutput() (string, error) { - if !interceptor.intercepting { - return "", errors.New("Not intercepting output!") - } - - interceptor.redirectFile.Close() - output, err := ioutil.ReadFile(interceptor.redirectFile.Name()) - os.Remove(interceptor.redirectFile.Name()) - - interceptor.intercepting = false - - if interceptor.streamTarget != nil { - interceptor.tailer.Stop() - interceptor.tailer.Cleanup() - <-interceptor.doneTailing - interceptor.streamTarget.Sync() - } - - return string(output), err -} - -func (interceptor *outputInterceptor) StreamTo(out *os.File) { - interceptor.streamTarget = out -} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_win.go b/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_win.go deleted file mode 100644 index 40c790336cc..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/remote/output_interceptor_win.go +++ /dev/null @@ -1,36 +0,0 @@ -// +build windows - -package remote - -import ( - "errors" - "os" -) - -func NewOutputInterceptor() OutputInterceptor { - return &outputInterceptor{} -} - -type outputInterceptor struct { - intercepting bool -} - -func (interceptor *outputInterceptor) StartInterceptingOutput() error { - if interceptor.intercepting { - return errors.New("Already intercepting output!") - } - interceptor.intercepting = true - - // not working on windows... - - return nil -} - -func (interceptor *outputInterceptor) StopInterceptingAndReturnOutput() (string, error) { - // not working on windows... - interceptor.intercepting = false - - return "", nil -} - -func (interceptor *outputInterceptor) StreamTo(*os.File) {} diff --git a/vendor/github.com/onsi/ginkgo/internal/remote/server.go b/vendor/github.com/onsi/ginkgo/internal/remote/server.go deleted file mode 100644 index 93e9dac057d..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/remote/server.go +++ /dev/null @@ -1,224 +0,0 @@ -/* - -The remote package provides the pieces to allow Ginkgo test suites to report to remote listeners. -This is used, primarily, to enable streaming parallel test output but has, in principal, broader applications (e.g. streaming test output to a browser). - -*/ - -package remote - -import ( - "encoding/json" - "io/ioutil" - "net" - "net/http" - "sync" - - "github.com/onsi/ginkgo/internal/spec_iterator" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/reporters" - "github.com/onsi/ginkgo/types" -) - -/* -Server spins up on an automatically selected port and listens for communication from the forwarding reporter. -It then forwards that communication to attached reporters. 
-*/ -type Server struct { - listener net.Listener - reporters []reporters.Reporter - alives []func() bool - lock *sync.Mutex - beforeSuiteData types.RemoteBeforeSuiteData - parallelTotal int - counter int -} - -//Create a new server, automatically selecting a port -func NewServer(parallelTotal int) (*Server, error) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return nil, err - } - return &Server{ - listener: listener, - lock: &sync.Mutex{}, - alives: make([]func() bool, parallelTotal), - beforeSuiteData: types.RemoteBeforeSuiteData{Data: nil, State: types.RemoteBeforeSuiteStatePending}, - parallelTotal: parallelTotal, - }, nil -} - -//Start the server. You don't need to `go s.Start()`, just `s.Start()` -func (server *Server) Start() { - httpServer := &http.Server{} - mux := http.NewServeMux() - httpServer.Handler = mux - - //streaming endpoints - mux.HandleFunc("/SpecSuiteWillBegin", server.specSuiteWillBegin) - mux.HandleFunc("/BeforeSuiteDidRun", server.beforeSuiteDidRun) - mux.HandleFunc("/AfterSuiteDidRun", server.afterSuiteDidRun) - mux.HandleFunc("/SpecWillRun", server.specWillRun) - mux.HandleFunc("/SpecDidComplete", server.specDidComplete) - mux.HandleFunc("/SpecSuiteDidEnd", server.specSuiteDidEnd) - - //synchronization endpoints - mux.HandleFunc("/BeforeSuiteState", server.handleBeforeSuiteState) - mux.HandleFunc("/RemoteAfterSuiteData", server.handleRemoteAfterSuiteData) - mux.HandleFunc("/counter", server.handleCounter) - mux.HandleFunc("/has-counter", server.handleHasCounter) //for backward compatibility - - go httpServer.Serve(server.listener) -} - -//Stop the server -func (server *Server) Close() { - server.listener.Close() -} - -//The address the server can be reached it. Pass this into the `ForwardingReporter`. 
-func (server *Server) Address() string { - return "http://" + server.listener.Addr().String() -} - -// -// Streaming Endpoints -// - -//The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters` -func (server *Server) readAll(request *http.Request) []byte { - defer request.Body.Close() - body, _ := ioutil.ReadAll(request.Body) - return body -} - -func (server *Server) RegisterReporters(reporters ...reporters.Reporter) { - server.reporters = reporters -} - -func (server *Server) specSuiteWillBegin(writer http.ResponseWriter, request *http.Request) { - body := server.readAll(request) - - var data struct { - Config config.GinkgoConfigType `json:"config"` - Summary *types.SuiteSummary `json:"suite-summary"` - } - - json.Unmarshal(body, &data) - - for _, reporter := range server.reporters { - reporter.SpecSuiteWillBegin(data.Config, data.Summary) - } -} - -func (server *Server) beforeSuiteDidRun(writer http.ResponseWriter, request *http.Request) { - body := server.readAll(request) - var setupSummary *types.SetupSummary - json.Unmarshal(body, &setupSummary) - - for _, reporter := range server.reporters { - reporter.BeforeSuiteDidRun(setupSummary) - } -} - -func (server *Server) afterSuiteDidRun(writer http.ResponseWriter, request *http.Request) { - body := server.readAll(request) - var setupSummary *types.SetupSummary - json.Unmarshal(body, &setupSummary) - - for _, reporter := range server.reporters { - reporter.AfterSuiteDidRun(setupSummary) - } -} - -func (server *Server) specWillRun(writer http.ResponseWriter, request *http.Request) { - body := server.readAll(request) - var specSummary *types.SpecSummary - json.Unmarshal(body, &specSummary) - - for _, reporter := range server.reporters { - reporter.SpecWillRun(specSummary) - } -} - -func (server *Server) specDidComplete(writer http.ResponseWriter, request *http.Request) { - body := server.readAll(request) - var specSummary *types.SpecSummary - json.Unmarshal(body, &specSummary) - - for _, reporter := range server.reporters { - reporter.SpecDidComplete(specSummary) - } -} - -func (server *Server) specSuiteDidEnd(writer http.ResponseWriter, request *http.Request) { - body := server.readAll(request) - var suiteSummary *types.SuiteSummary - json.Unmarshal(body, &suiteSummary) - - for _, reporter := range server.reporters { - reporter.SpecSuiteDidEnd(suiteSummary) - } -} - -// -// Synchronization Endpoints -// - -func (server *Server) RegisterAlive(node int, alive func() bool) { - server.lock.Lock() - defer server.lock.Unlock() - server.alives[node-1] = alive -} - -func (server *Server) nodeIsAlive(node int) bool { - server.lock.Lock() - defer server.lock.Unlock() - alive := server.alives[node-1] - if alive == nil { - return true - } - return alive() -} - -func (server *Server) handleBeforeSuiteState(writer http.ResponseWriter, request *http.Request) { - if request.Method == "POST" { - dec := json.NewDecoder(request.Body) - dec.Decode(&(server.beforeSuiteData)) - } else { - beforeSuiteData := server.beforeSuiteData - if beforeSuiteData.State == types.RemoteBeforeSuiteStatePending && !server.nodeIsAlive(1) { - beforeSuiteData.State = types.RemoteBeforeSuiteStateDisappeared - } - enc := json.NewEncoder(writer) - enc.Encode(beforeSuiteData) - } -} - -func (server *Server) handleRemoteAfterSuiteData(writer http.ResponseWriter, request *http.Request) { - afterSuiteData := types.RemoteAfterSuiteData{ - CanRun: true, - } - for i := 2; i <= server.parallelTotal; i++ { - afterSuiteData.CanRun = 
afterSuiteData.CanRun && !server.nodeIsAlive(i) - } - - enc := json.NewEncoder(writer) - enc.Encode(afterSuiteData) -} - -func (server *Server) handleCounter(writer http.ResponseWriter, request *http.Request) { - c := spec_iterator.Counter{} - server.lock.Lock() - c.Index = server.counter - server.counter++ - server.lock.Unlock() - - json.NewEncoder(writer).Encode(c) -} - -func (server *Server) handleHasCounter(writer http.ResponseWriter, request *http.Request) { - writer.Write([]byte("")) -} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec/spec.go b/vendor/github.com/onsi/ginkgo/internal/spec/spec.go deleted file mode 100644 index 6eef40a0e0c..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/spec/spec.go +++ /dev/null @@ -1,247 +0,0 @@ -package spec - -import ( - "fmt" - "io" - "time" - - "sync" - - "github.com/onsi/ginkgo/internal/containernode" - "github.com/onsi/ginkgo/internal/leafnodes" - "github.com/onsi/ginkgo/types" -) - -type Spec struct { - subject leafnodes.SubjectNode - focused bool - announceProgress bool - - containers []*containernode.ContainerNode - - state types.SpecState - runTime time.Duration - startTime time.Time - failure types.SpecFailure - previousFailures bool - - stateMutex *sync.Mutex -} - -func New(subject leafnodes.SubjectNode, containers []*containernode.ContainerNode, announceProgress bool) *Spec { - spec := &Spec{ - subject: subject, - containers: containers, - focused: subject.Flag() == types.FlagTypeFocused, - announceProgress: announceProgress, - stateMutex: &sync.Mutex{}, - } - - spec.processFlag(subject.Flag()) - for i := len(containers) - 1; i >= 0; i-- { - spec.processFlag(containers[i].Flag()) - } - - return spec -} - -func (spec *Spec) processFlag(flag types.FlagType) { - if flag == types.FlagTypeFocused { - spec.focused = true - } else if flag == types.FlagTypePending { - spec.setState(types.SpecStatePending) - } -} - -func (spec *Spec) Skip() { - spec.setState(types.SpecStateSkipped) -} - -func (spec *Spec) Failed() bool { - return spec.getState() == types.SpecStateFailed || spec.getState() == types.SpecStatePanicked || spec.getState() == types.SpecStateTimedOut -} - -func (spec *Spec) Passed() bool { - return spec.getState() == types.SpecStatePassed -} - -func (spec *Spec) Flaked() bool { - return spec.getState() == types.SpecStatePassed && spec.previousFailures -} - -func (spec *Spec) Pending() bool { - return spec.getState() == types.SpecStatePending -} - -func (spec *Spec) Skipped() bool { - return spec.getState() == types.SpecStateSkipped -} - -func (spec *Spec) Focused() bool { - return spec.focused -} - -func (spec *Spec) IsMeasurement() bool { - return spec.subject.Type() == types.SpecComponentTypeMeasure -} - -func (spec *Spec) Summary(suiteID string) *types.SpecSummary { - componentTexts := make([]string, len(spec.containers)+1) - componentCodeLocations := make([]types.CodeLocation, len(spec.containers)+1) - - for i, container := range spec.containers { - componentTexts[i] = container.Text() - componentCodeLocations[i] = container.CodeLocation() - } - - componentTexts[len(spec.containers)] = spec.subject.Text() - componentCodeLocations[len(spec.containers)] = spec.subject.CodeLocation() - - runTime := spec.runTime - if runTime == 0 && !spec.startTime.IsZero() { - runTime = time.Since(spec.startTime) - } - - return &types.SpecSummary{ - IsMeasurement: spec.IsMeasurement(), - NumberOfSamples: spec.subject.Samples(), - ComponentTexts: componentTexts, - ComponentCodeLocations: componentCodeLocations, - State: 
spec.getState(), - RunTime: runTime, - Failure: spec.failure, - Measurements: spec.measurementsReport(), - SuiteID: suiteID, - } -} - -func (spec *Spec) ConcatenatedString() string { - s := "" - for _, container := range spec.containers { - s += container.Text() + " " - } - - return s + spec.subject.Text() -} - -func (spec *Spec) Run(writer io.Writer) { - if spec.getState() == types.SpecStateFailed { - spec.previousFailures = true - } - - spec.startTime = time.Now() - defer func() { - spec.runTime = time.Since(spec.startTime) - }() - - for sample := 0; sample < spec.subject.Samples(); sample++ { - spec.runSample(sample, writer) - - if spec.getState() != types.SpecStatePassed { - return - } - } -} - -func (spec *Spec) getState() types.SpecState { - spec.stateMutex.Lock() - defer spec.stateMutex.Unlock() - return spec.state -} - -func (spec *Spec) setState(state types.SpecState) { - spec.stateMutex.Lock() - defer spec.stateMutex.Unlock() - spec.state = state -} - -func (spec *Spec) runSample(sample int, writer io.Writer) { - spec.setState(types.SpecStatePassed) - spec.failure = types.SpecFailure{} - innerMostContainerIndexToUnwind := -1 - - defer func() { - for i := innerMostContainerIndexToUnwind; i >= 0; i-- { - container := spec.containers[i] - for _, justAfterEach := range container.SetupNodesOfType(types.SpecComponentTypeJustAfterEach) { - spec.announceSetupNode(writer, "JustAfterEach", container, justAfterEach) - justAfterEachState, justAfterEachFailure := justAfterEach.Run() - if justAfterEachState != types.SpecStatePassed && spec.state == types.SpecStatePassed { - spec.state = justAfterEachState - spec.failure = justAfterEachFailure - } - } - } - - for i := innerMostContainerIndexToUnwind; i >= 0; i-- { - container := spec.containers[i] - for _, afterEach := range container.SetupNodesOfType(types.SpecComponentTypeAfterEach) { - spec.announceSetupNode(writer, "AfterEach", container, afterEach) - afterEachState, afterEachFailure := afterEach.Run() - if afterEachState != types.SpecStatePassed && spec.getState() == types.SpecStatePassed { - spec.setState(afterEachState) - spec.failure = afterEachFailure - } - } - } - }() - - for i, container := range spec.containers { - innerMostContainerIndexToUnwind = i - for _, beforeEach := range container.SetupNodesOfType(types.SpecComponentTypeBeforeEach) { - spec.announceSetupNode(writer, "BeforeEach", container, beforeEach) - s, f := beforeEach.Run() - spec.failure = f - spec.setState(s) - if spec.getState() != types.SpecStatePassed { - return - } - } - } - - for _, container := range spec.containers { - for _, justBeforeEach := range container.SetupNodesOfType(types.SpecComponentTypeJustBeforeEach) { - spec.announceSetupNode(writer, "JustBeforeEach", container, justBeforeEach) - s, f := justBeforeEach.Run() - spec.failure = f - spec.setState(s) - if spec.getState() != types.SpecStatePassed { - return - } - } - } - - spec.announceSubject(writer, spec.subject) - s, f := spec.subject.Run() - spec.failure = f - spec.setState(s) -} - -func (spec *Spec) announceSetupNode(writer io.Writer, nodeType string, container *containernode.ContainerNode, setupNode leafnodes.BasicNode) { - if spec.announceProgress { - s := fmt.Sprintf("[%s] %s\n %s\n", nodeType, container.Text(), setupNode.CodeLocation().String()) - writer.Write([]byte(s)) - } -} - -func (spec *Spec) announceSubject(writer io.Writer, subject leafnodes.SubjectNode) { - if spec.announceProgress { - nodeType := "" - switch subject.Type() { - case types.SpecComponentTypeIt: - nodeType = "It" - case 
types.SpecComponentTypeMeasure: - nodeType = "Measure" - } - s := fmt.Sprintf("[%s] %s\n %s\n", nodeType, subject.Text(), subject.CodeLocation().String()) - writer.Write([]byte(s)) - } -} - -func (spec *Spec) measurementsReport() map[string]*types.SpecMeasurement { - if !spec.IsMeasurement() || spec.Failed() { - return map[string]*types.SpecMeasurement{} - } - - return spec.subject.(*leafnodes.MeasureNode).MeasurementsReport() -} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec/specs.go b/vendor/github.com/onsi/ginkgo/internal/spec/specs.go deleted file mode 100644 index 0a24139fb1c..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/spec/specs.go +++ /dev/null @@ -1,144 +0,0 @@ -package spec - -import ( - "math/rand" - "regexp" - "sort" - "strings" -) - -type Specs struct { - specs []*Spec - names []string - - hasProgrammaticFocus bool - RegexScansFilePath bool -} - -func NewSpecs(specs []*Spec) *Specs { - names := make([]string, len(specs)) - for i, spec := range specs { - names[i] = spec.ConcatenatedString() - } - return &Specs{ - specs: specs, - names: names, - } -} - -func (e *Specs) Specs() []*Spec { - return e.specs -} - -func (e *Specs) HasProgrammaticFocus() bool { - return e.hasProgrammaticFocus -} - -func (e *Specs) Shuffle(r *rand.Rand) { - sort.Sort(e) - permutation := r.Perm(len(e.specs)) - shuffledSpecs := make([]*Spec, len(e.specs)) - names := make([]string, len(e.specs)) - for i, j := range permutation { - shuffledSpecs[i] = e.specs[j] - names[i] = e.names[j] - } - e.specs = shuffledSpecs - e.names = names -} - -func (e *Specs) ApplyFocus(description string, focus, skip []string) { - if len(focus)+len(skip) == 0 { - e.applyProgrammaticFocus() - } else { - e.applyRegExpFocusAndSkip(description, focus, skip) - } -} - -func (e *Specs) applyProgrammaticFocus() { - e.hasProgrammaticFocus = false - for _, spec := range e.specs { - if spec.Focused() && !spec.Pending() { - e.hasProgrammaticFocus = true - break - } - } - - if e.hasProgrammaticFocus { - for _, spec := range e.specs { - if !spec.Focused() { - spec.Skip() - } - } - } -} - -// toMatch returns a byte[] to be used by regex matchers. When adding new behaviours to the matching function, -// this is the place which we append to. 
-func (e *Specs) toMatch(description string, i int) []byte { - if i > len(e.names) { - return nil - } - if e.RegexScansFilePath { - return []byte( - description + " " + - e.names[i] + " " + - e.specs[i].subject.CodeLocation().FileName) - } else { - return []byte( - description + " " + - e.names[i]) - } -} - -func (e *Specs) applyRegExpFocusAndSkip(description string, focus, skip []string) { - var focusFilter, skipFilter *regexp.Regexp - if len(focus) > 0 { - focusFilter = regexp.MustCompile(strings.Join(focus, "|")) - } - if len(skip) > 0 { - skipFilter = regexp.MustCompile(strings.Join(skip, "|")) - } - - for i, spec := range e.specs { - matchesFocus := true - matchesSkip := false - - toMatch := e.toMatch(description, i) - - if focusFilter != nil { - matchesFocus = focusFilter.Match(toMatch) - } - - if skipFilter != nil { - matchesSkip = skipFilter.Match(toMatch) - } - - if !matchesFocus || matchesSkip { - spec.Skip() - } - } -} - -func (e *Specs) SkipMeasurements() { - for _, spec := range e.specs { - if spec.IsMeasurement() { - spec.Skip() - } - } -} - -//sort.Interface - -func (e *Specs) Len() int { - return len(e.specs) -} - -func (e *Specs) Less(i, j int) bool { - return e.names[i] < e.names[j] -} - -func (e *Specs) Swap(i, j int) { - e.names[i], e.names[j] = e.names[j], e.names[i] - e.specs[i], e.specs[j] = e.specs[j], e.specs[i] -} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/index_computer.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/index_computer.go deleted file mode 100644 index 82272554aa8..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/index_computer.go +++ /dev/null @@ -1,55 +0,0 @@ -package spec_iterator - -func ParallelizedIndexRange(length int, parallelTotal int, parallelNode int) (startIndex int, count int) { - if length == 0 { - return 0, 0 - } - - // We have more nodes than tests. Trivial case. 
- if parallelTotal >= length { - if parallelNode > length { - return 0, 0 - } else { - return parallelNode - 1, 1 - } - } - - // This is the minimum amount of tests that a node will be required to run - minTestsPerNode := length / parallelTotal - - // This is the maximum amount of tests that a node will be required to run - // The algorithm guarantees that this would be equal to at least the minimum amount - // and at most one more - maxTestsPerNode := minTestsPerNode - if length%parallelTotal != 0 { - maxTestsPerNode++ - } - - // Number of nodes that will have to run the maximum amount of tests per node - numMaxLoadNodes := length % parallelTotal - - // Number of nodes that precede the current node and will have to run the maximum amount of tests per node - var numPrecedingMaxLoadNodes int - if parallelNode > numMaxLoadNodes { - numPrecedingMaxLoadNodes = numMaxLoadNodes - } else { - numPrecedingMaxLoadNodes = parallelNode - 1 - } - - // Number of nodes that precede the current node and will have to run the minimum amount of tests per node - var numPrecedingMinLoadNodes int - if parallelNode <= numMaxLoadNodes { - numPrecedingMinLoadNodes = 0 - } else { - numPrecedingMinLoadNodes = parallelNode - numMaxLoadNodes - 1 - } - - // Evaluate the test start index and number of tests to run - startIndex = numPrecedingMaxLoadNodes*maxTestsPerNode + numPrecedingMinLoadNodes*minTestsPerNode - if parallelNode > numMaxLoadNodes { - count = minTestsPerNode - } else { - count = maxTestsPerNode - } - return -} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/parallel_spec_iterator.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/parallel_spec_iterator.go deleted file mode 100644 index 99f548bca43..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/parallel_spec_iterator.go +++ /dev/null @@ -1,59 +0,0 @@ -package spec_iterator - -import ( - "encoding/json" - "fmt" - "net/http" - - "github.com/onsi/ginkgo/internal/spec" -) - -type ParallelIterator struct { - specs []*spec.Spec - host string - client *http.Client -} - -func NewParallelIterator(specs []*spec.Spec, host string) *ParallelIterator { - return &ParallelIterator{ - specs: specs, - host: host, - client: &http.Client{}, - } -} - -func (s *ParallelIterator) Next() (*spec.Spec, error) { - resp, err := s.client.Get(s.host + "/counter") - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode) - } - - var counter Counter - err = json.NewDecoder(resp.Body).Decode(&counter) - if err != nil { - return nil, err - } - - if counter.Index >= len(s.specs) { - return nil, ErrClosed - } - - return s.specs[counter.Index], nil -} - -func (s *ParallelIterator) NumberOfSpecsPriorToIteration() int { - return len(s.specs) -} - -func (s *ParallelIterator) NumberOfSpecsToProcessIfKnown() (int, bool) { - return -1, false -} - -func (s *ParallelIterator) NumberOfSpecsThatWillBeRunIfKnown() (int, bool) { - return -1, false -} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/serial_spec_iterator.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/serial_spec_iterator.go deleted file mode 100644 index a51c93b8b61..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/serial_spec_iterator.go +++ /dev/null @@ -1,45 +0,0 @@ -package spec_iterator - -import ( - "github.com/onsi/ginkgo/internal/spec" -) - -type SerialIterator struct { - specs []*spec.Spec - index int -} - -func 
NewSerialIterator(specs []*spec.Spec) *SerialIterator { - return &SerialIterator{ - specs: specs, - index: 0, - } -} - -func (s *SerialIterator) Next() (*spec.Spec, error) { - if s.index >= len(s.specs) { - return nil, ErrClosed - } - - spec := s.specs[s.index] - s.index += 1 - return spec, nil -} - -func (s *SerialIterator) NumberOfSpecsPriorToIteration() int { - return len(s.specs) -} - -func (s *SerialIterator) NumberOfSpecsToProcessIfKnown() (int, bool) { - return len(s.specs), true -} - -func (s *SerialIterator) NumberOfSpecsThatWillBeRunIfKnown() (int, bool) { - count := 0 - for _, s := range s.specs { - if !s.Skipped() && !s.Pending() { - count += 1 - } - } - return count, true -} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/sharded_parallel_spec_iterator.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/sharded_parallel_spec_iterator.go deleted file mode 100644 index ad4a3ea3c64..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/sharded_parallel_spec_iterator.go +++ /dev/null @@ -1,47 +0,0 @@ -package spec_iterator - -import "github.com/onsi/ginkgo/internal/spec" - -type ShardedParallelIterator struct { - specs []*spec.Spec - index int - maxIndex int -} - -func NewShardedParallelIterator(specs []*spec.Spec, total int, node int) *ShardedParallelIterator { - startIndex, count := ParallelizedIndexRange(len(specs), total, node) - - return &ShardedParallelIterator{ - specs: specs, - index: startIndex, - maxIndex: startIndex + count, - } -} - -func (s *ShardedParallelIterator) Next() (*spec.Spec, error) { - if s.index >= s.maxIndex { - return nil, ErrClosed - } - - spec := s.specs[s.index] - s.index += 1 - return spec, nil -} - -func (s *ShardedParallelIterator) NumberOfSpecsPriorToIteration() int { - return len(s.specs) -} - -func (s *ShardedParallelIterator) NumberOfSpecsToProcessIfKnown() (int, bool) { - return s.maxIndex - s.index, true -} - -func (s *ShardedParallelIterator) NumberOfSpecsThatWillBeRunIfKnown() (int, bool) { - count := 0 - for i := s.index; i < s.maxIndex; i += 1 { - if !s.specs[i].Skipped() && !s.specs[i].Pending() { - count += 1 - } - } - return count, true -} diff --git a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/spec_iterator.go b/vendor/github.com/onsi/ginkgo/internal/spec_iterator/spec_iterator.go deleted file mode 100644 index 74bffad6464..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/spec_iterator/spec_iterator.go +++ /dev/null @@ -1,20 +0,0 @@ -package spec_iterator - -import ( - "errors" - - "github.com/onsi/ginkgo/internal/spec" -) - -var ErrClosed = errors.New("no more specs to run") - -type SpecIterator interface { - Next() (*spec.Spec, error) - NumberOfSpecsPriorToIteration() int - NumberOfSpecsToProcessIfKnown() (int, bool) - NumberOfSpecsThatWillBeRunIfKnown() (int, bool) -} - -type Counter struct { - Index int `json:"index"` -} diff --git a/vendor/github.com/onsi/ginkgo/internal/writer/fake_writer.go b/vendor/github.com/onsi/ginkgo/internal/writer/fake_writer.go deleted file mode 100644 index 6739c3f605e..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/writer/fake_writer.go +++ /dev/null @@ -1,36 +0,0 @@ -package writer - -type FakeGinkgoWriter struct { - EventStream []string -} - -func NewFake() *FakeGinkgoWriter { - return &FakeGinkgoWriter{ - EventStream: []string{}, - } -} - -func (writer *FakeGinkgoWriter) AddEvent(event string) { - writer.EventStream = append(writer.EventStream, event) -} - -func (writer *FakeGinkgoWriter) Truncate() { - writer.EventStream = 
append(writer.EventStream, "TRUNCATE") -} - -func (writer *FakeGinkgoWriter) DumpOut() { - writer.EventStream = append(writer.EventStream, "DUMP") -} - -func (writer *FakeGinkgoWriter) DumpOutWithHeader(header string) { - writer.EventStream = append(writer.EventStream, "DUMP_WITH_HEADER: "+header) -} - -func (writer *FakeGinkgoWriter) Bytes() []byte { - writer.EventStream = append(writer.EventStream, "BYTES") - return nil -} - -func (writer *FakeGinkgoWriter) Write(data []byte) (n int, err error) { - return 0, nil -} diff --git a/vendor/github.com/onsi/ginkgo/internal/writer/writer.go b/vendor/github.com/onsi/ginkgo/internal/writer/writer.go deleted file mode 100644 index 98eca3bdd25..00000000000 --- a/vendor/github.com/onsi/ginkgo/internal/writer/writer.go +++ /dev/null @@ -1,89 +0,0 @@ -package writer - -import ( - "bytes" - "io" - "sync" -) - -type WriterInterface interface { - io.Writer - - Truncate() - DumpOut() - DumpOutWithHeader(header string) - Bytes() []byte -} - -type Writer struct { - buffer *bytes.Buffer - outWriter io.Writer - lock *sync.Mutex - stream bool - redirector io.Writer -} - -func New(outWriter io.Writer) *Writer { - return &Writer{ - buffer: &bytes.Buffer{}, - lock: &sync.Mutex{}, - outWriter: outWriter, - stream: true, - } -} - -func (w *Writer) AndRedirectTo(writer io.Writer) { - w.redirector = writer -} - -func (w *Writer) SetStream(stream bool) { - w.lock.Lock() - defer w.lock.Unlock() - w.stream = stream -} - -func (w *Writer) Write(b []byte) (n int, err error) { - w.lock.Lock() - defer w.lock.Unlock() - - n, err = w.buffer.Write(b) - if w.redirector != nil { - w.redirector.Write(b) - } - if w.stream { - return w.outWriter.Write(b) - } - return n, err -} - -func (w *Writer) Truncate() { - w.lock.Lock() - defer w.lock.Unlock() - w.buffer.Reset() -} - -func (w *Writer) DumpOut() { - w.lock.Lock() - defer w.lock.Unlock() - if !w.stream { - w.buffer.WriteTo(w.outWriter) - } -} - -func (w *Writer) Bytes() []byte { - w.lock.Lock() - defer w.lock.Unlock() - b := w.buffer.Bytes() - copied := make([]byte, len(b)) - copy(copied, b) - return copied -} - -func (w *Writer) DumpOutWithHeader(header string) { - w.lock.Lock() - defer w.lock.Unlock() - if !w.stream && w.buffer.Len() > 0 { - w.outWriter.Write([]byte(header)) - w.buffer.WriteTo(w.outWriter) - } -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/default_reporter.go b/vendor/github.com/onsi/ginkgo/reporters/default_reporter.go deleted file mode 100644 index f0c9f61410b..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/default_reporter.go +++ /dev/null @@ -1,87 +0,0 @@ -/* -Ginkgo's Default Reporter - -A number of command line flags are available to tweak Ginkgo's default output. 
- -These are documented [here](http://onsi.github.io/ginkgo/#running_tests) -*/ -package reporters - -import ( - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/reporters/stenographer" - "github.com/onsi/ginkgo/types" -) - -type DefaultReporter struct { - config config.DefaultReporterConfigType - stenographer stenographer.Stenographer - specSummaries []*types.SpecSummary -} - -func NewDefaultReporter(config config.DefaultReporterConfigType, stenographer stenographer.Stenographer) *DefaultReporter { - return &DefaultReporter{ - config: config, - stenographer: stenographer, - } -} - -func (reporter *DefaultReporter) SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) { - reporter.stenographer.AnnounceSuite(summary.SuiteDescription, config.RandomSeed, config.RandomizeAllSpecs, reporter.config.Succinct) - if config.ParallelTotal > 1 { - reporter.stenographer.AnnounceParallelRun(config.ParallelNode, config.ParallelTotal, reporter.config.Succinct) - } else { - reporter.stenographer.AnnounceNumberOfSpecs(summary.NumberOfSpecsThatWillBeRun, summary.NumberOfTotalSpecs, reporter.config.Succinct) - } -} - -func (reporter *DefaultReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { - if setupSummary.State != types.SpecStatePassed { - reporter.stenographer.AnnounceBeforeSuiteFailure(setupSummary, reporter.config.Succinct, reporter.config.FullTrace) - } -} - -func (reporter *DefaultReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { - if setupSummary.State != types.SpecStatePassed { - reporter.stenographer.AnnounceAfterSuiteFailure(setupSummary, reporter.config.Succinct, reporter.config.FullTrace) - } -} - -func (reporter *DefaultReporter) SpecWillRun(specSummary *types.SpecSummary) { - if reporter.config.Verbose && !reporter.config.Succinct && specSummary.State != types.SpecStatePending && specSummary.State != types.SpecStateSkipped { - reporter.stenographer.AnnounceSpecWillRun(specSummary) - } -} - -func (reporter *DefaultReporter) SpecDidComplete(specSummary *types.SpecSummary) { - switch specSummary.State { - case types.SpecStatePassed: - if specSummary.IsMeasurement { - reporter.stenographer.AnnounceSuccessfulMeasurement(specSummary, reporter.config.Succinct) - } else if specSummary.RunTime.Seconds() >= reporter.config.SlowSpecThreshold { - reporter.stenographer.AnnounceSuccessfulSlowSpec(specSummary, reporter.config.Succinct) - } else { - reporter.stenographer.AnnounceSuccessfulSpec(specSummary) - if reporter.config.ReportPassed { - reporter.stenographer.AnnounceCapturedOutput(specSummary.CapturedOutput) - } - } - case types.SpecStatePending: - reporter.stenographer.AnnouncePendingSpec(specSummary, reporter.config.NoisyPendings && !reporter.config.Succinct) - case types.SpecStateSkipped: - reporter.stenographer.AnnounceSkippedSpec(specSummary, reporter.config.Succinct || !reporter.config.NoisySkippings, reporter.config.FullTrace) - case types.SpecStateTimedOut: - reporter.stenographer.AnnounceSpecTimedOut(specSummary, reporter.config.Succinct, reporter.config.FullTrace) - case types.SpecStatePanicked: - reporter.stenographer.AnnounceSpecPanicked(specSummary, reporter.config.Succinct, reporter.config.FullTrace) - case types.SpecStateFailed: - reporter.stenographer.AnnounceSpecFailed(specSummary, reporter.config.Succinct, reporter.config.FullTrace) - } - - reporter.specSummaries = append(reporter.specSummaries, specSummary) -} - -func (reporter *DefaultReporter) SpecSuiteDidEnd(summary *types.SuiteSummary) { - 
reporter.stenographer.SummarizeFailures(reporter.specSummaries) - reporter.stenographer.AnnounceSpecRunCompletion(summary, reporter.config.Succinct) -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/fake_reporter.go b/vendor/github.com/onsi/ginkgo/reporters/fake_reporter.go deleted file mode 100644 index 27db4794908..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/fake_reporter.go +++ /dev/null @@ -1,59 +0,0 @@ -package reporters - -import ( - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/types" -) - -//FakeReporter is useful for testing purposes -type FakeReporter struct { - Config config.GinkgoConfigType - - BeginSummary *types.SuiteSummary - BeforeSuiteSummary *types.SetupSummary - SpecWillRunSummaries []*types.SpecSummary - SpecSummaries []*types.SpecSummary - AfterSuiteSummary *types.SetupSummary - EndSummary *types.SuiteSummary - - SpecWillRunStub func(specSummary *types.SpecSummary) - SpecDidCompleteStub func(specSummary *types.SpecSummary) -} - -func NewFakeReporter() *FakeReporter { - return &FakeReporter{ - SpecWillRunSummaries: make([]*types.SpecSummary, 0), - SpecSummaries: make([]*types.SpecSummary, 0), - } -} - -func (fakeR *FakeReporter) SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) { - fakeR.Config = config - fakeR.BeginSummary = summary -} - -func (fakeR *FakeReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { - fakeR.BeforeSuiteSummary = setupSummary -} - -func (fakeR *FakeReporter) SpecWillRun(specSummary *types.SpecSummary) { - if fakeR.SpecWillRunStub != nil { - fakeR.SpecWillRunStub(specSummary) - } - fakeR.SpecWillRunSummaries = append(fakeR.SpecWillRunSummaries, specSummary) -} - -func (fakeR *FakeReporter) SpecDidComplete(specSummary *types.SpecSummary) { - if fakeR.SpecDidCompleteStub != nil { - fakeR.SpecDidCompleteStub(specSummary) - } - fakeR.SpecSummaries = append(fakeR.SpecSummaries, specSummary) -} - -func (fakeR *FakeReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { - fakeR.AfterSuiteSummary = setupSummary -} - -func (fakeR *FakeReporter) SpecSuiteDidEnd(summary *types.SuiteSummary) { - fakeR.EndSummary = summary -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/junit_reporter.go b/vendor/github.com/onsi/ginkgo/reporters/junit_reporter.go deleted file mode 100644 index 01ddca6e1df..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/junit_reporter.go +++ /dev/null @@ -1,178 +0,0 @@ -/* - -JUnit XML Reporter for Ginkgo - -For usage instructions: http://onsi.github.io/ginkgo/#generating_junit_xml_output - -*/ - -package reporters - -import ( - "encoding/xml" - "fmt" - "math" - "os" - "path/filepath" - "strings" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/types" -) - -type JUnitTestSuite struct { - XMLName xml.Name `xml:"testsuite"` - TestCases []JUnitTestCase `xml:"testcase"` - Name string `xml:"name,attr"` - Tests int `xml:"tests,attr"` - Failures int `xml:"failures,attr"` - Errors int `xml:"errors,attr"` - Time float64 `xml:"time,attr"` -} - -type JUnitTestCase struct { - Name string `xml:"name,attr"` - ClassName string `xml:"classname,attr"` - FailureMessage *JUnitFailureMessage `xml:"failure,omitempty"` - Skipped *JUnitSkipped `xml:"skipped,omitempty"` - Time float64 `xml:"time,attr"` - SystemOut string `xml:"system-out,omitempty"` -} - -type JUnitFailureMessage struct { - Type string `xml:"type,attr"` - Message string `xml:",chardata"` -} - -type JUnitSkipped struct { - Message string `xml:",chardata"` -} - -type JUnitReporter struct { - 
suite JUnitTestSuite - filename string - testSuiteName string - ReporterConfig config.DefaultReporterConfigType -} - -//NewJUnitReporter creates a new JUnit XML reporter. The XML will be stored in the passed in filename. -func NewJUnitReporter(filename string) *JUnitReporter { - return &JUnitReporter{ - filename: filename, - } -} - -func (reporter *JUnitReporter) SpecSuiteWillBegin(ginkgoConfig config.GinkgoConfigType, summary *types.SuiteSummary) { - reporter.suite = JUnitTestSuite{ - Name: summary.SuiteDescription, - TestCases: []JUnitTestCase{}, - } - reporter.testSuiteName = summary.SuiteDescription - reporter.ReporterConfig = config.DefaultReporterConfig -} - -func (reporter *JUnitReporter) SpecWillRun(specSummary *types.SpecSummary) { -} - -func (reporter *JUnitReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { - reporter.handleSetupSummary("BeforeSuite", setupSummary) -} - -func (reporter *JUnitReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { - reporter.handleSetupSummary("AfterSuite", setupSummary) -} - -func failureMessage(failure types.SpecFailure) string { - return fmt.Sprintf("%s\n%s\n%s", failure.ComponentCodeLocation.String(), failure.Message, failure.Location.String()) -} - -func (reporter *JUnitReporter) handleSetupSummary(name string, setupSummary *types.SetupSummary) { - if setupSummary.State != types.SpecStatePassed { - testCase := JUnitTestCase{ - Name: name, - ClassName: reporter.testSuiteName, - } - - testCase.FailureMessage = &JUnitFailureMessage{ - Type: reporter.failureTypeForState(setupSummary.State), - Message: failureMessage(setupSummary.Failure), - } - testCase.SystemOut = setupSummary.CapturedOutput - testCase.Time = setupSummary.RunTime.Seconds() - reporter.suite.TestCases = append(reporter.suite.TestCases, testCase) - } -} - -func (reporter *JUnitReporter) SpecDidComplete(specSummary *types.SpecSummary) { - testCase := JUnitTestCase{ - Name: strings.Join(specSummary.ComponentTexts[1:], " "), - ClassName: reporter.testSuiteName, - } - if reporter.ReporterConfig.ReportPassed && specSummary.State == types.SpecStatePassed { - testCase.SystemOut = specSummary.CapturedOutput - } - if specSummary.State == types.SpecStateFailed || specSummary.State == types.SpecStateTimedOut || specSummary.State == types.SpecStatePanicked { - testCase.FailureMessage = &JUnitFailureMessage{ - Type: reporter.failureTypeForState(specSummary.State), - Message: failureMessage(specSummary.Failure), - } - if specSummary.State == types.SpecStatePanicked { - testCase.FailureMessage.Message += fmt.Sprintf("\n\nPanic: %s\n\nFull stack:\n%s", - specSummary.Failure.ForwardedPanic, - specSummary.Failure.Location.FullStackTrace) - } - testCase.SystemOut = specSummary.CapturedOutput - } - if specSummary.State == types.SpecStateSkipped || specSummary.State == types.SpecStatePending { - testCase.Skipped = &JUnitSkipped{} - if specSummary.Failure.Message != "" { - testCase.Skipped.Message = failureMessage(specSummary.Failure) - } - } - testCase.Time = specSummary.RunTime.Seconds() - reporter.suite.TestCases = append(reporter.suite.TestCases, testCase) -} - -func (reporter *JUnitReporter) SpecSuiteDidEnd(summary *types.SuiteSummary) { - reporter.suite.Tests = summary.NumberOfSpecsThatWillBeRun - reporter.suite.Time = math.Trunc(summary.RunTime.Seconds()*1000) / 1000 - reporter.suite.Failures = summary.NumberOfFailedSpecs - reporter.suite.Errors = 0 - if reporter.ReporterConfig.ReportFile != "" { - reporter.filename = reporter.ReporterConfig.ReportFile - fmt.Printf("\nJUnit 
path was configured: %s\n", reporter.filename) - } - filePath, _ := filepath.Abs(reporter.filename) - dirPath := filepath.Dir(filePath) - err := os.MkdirAll(dirPath, os.ModePerm) - if err != nil { - fmt.Printf("\nFailed to create JUnit directory: %s\n\t%s", filePath, err.Error()) - } - file, err := os.Create(filePath) - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to create JUnit report file: %s\n\t%s", filePath, err.Error()) - } - defer file.Close() - file.WriteString(xml.Header) - encoder := xml.NewEncoder(file) - encoder.Indent(" ", " ") - err = encoder.Encode(reporter.suite) - if err == nil { - fmt.Fprintf(os.Stdout, "\nJUnit report was created: %s\n", filePath) - } else { - fmt.Fprintf(os.Stderr,"\nFailed to generate JUnit report data:\n\t%s", err.Error()) - } -} - -func (reporter *JUnitReporter) failureTypeForState(state types.SpecState) string { - switch state { - case types.SpecStateFailed: - return "Failure" - case types.SpecStateTimedOut: - return "Timeout" - case types.SpecStatePanicked: - return "Panic" - default: - return "" - } -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/reporter.go b/vendor/github.com/onsi/ginkgo/reporters/reporter.go deleted file mode 100644 index 348b9dfce1f..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/reporter.go +++ /dev/null @@ -1,15 +0,0 @@ -package reporters - -import ( - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/types" -) - -type Reporter interface { - SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) - BeforeSuiteDidRun(setupSummary *types.SetupSummary) - SpecWillRun(specSummary *types.SpecSummary) - SpecDidComplete(specSummary *types.SpecSummary) - AfterSuiteDidRun(setupSummary *types.SetupSummary) - SpecSuiteDidEnd(summary *types.SuiteSummary) -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/console_logging.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/console_logging.go deleted file mode 100644 index 45b8f88690e..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/console_logging.go +++ /dev/null @@ -1,64 +0,0 @@ -package stenographer - -import ( - "fmt" - "strings" -) - -func (s *consoleStenographer) colorize(colorCode string, format string, args ...interface{}) string { - var out string - - if len(args) > 0 { - out = fmt.Sprintf(format, args...) - } else { - out = format - } - - if s.color { - return fmt.Sprintf("%s%s%s", colorCode, out, defaultStyle) - } else { - return out - } -} - -func (s *consoleStenographer) printBanner(text string, bannerCharacter string) { - fmt.Fprintln(s.w, text) - fmt.Fprintln(s.w, strings.Repeat(bannerCharacter, len(text))) -} - -func (s *consoleStenographer) printNewLine() { - fmt.Fprintln(s.w, "") -} - -func (s *consoleStenographer) printDelimiter() { - fmt.Fprintln(s.w, s.colorize(grayColor, "%s", strings.Repeat("-", 30))) -} - -func (s *consoleStenographer) print(indentation int, format string, args ...interface{}) { - fmt.Fprint(s.w, s.indent(indentation, format, args...)) -} - -func (s *consoleStenographer) println(indentation int, format string, args ...interface{}) { - fmt.Fprintln(s.w, s.indent(indentation, format, args...)) -} - -func (s *consoleStenographer) indent(indentation int, format string, args ...interface{}) string { - var text string - - if len(args) > 0 { - text = fmt.Sprintf(format, args...) 
- } else { - text = format - } - - stringArray := strings.Split(text, "\n") - padding := "" - if indentation >= 0 { - padding = strings.Repeat(" ", indentation) - } - for i, s := range stringArray { - stringArray[i] = fmt.Sprintf("%s%s", padding, s) - } - - return strings.Join(stringArray, "\n") -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/fake_stenographer.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/fake_stenographer.go deleted file mode 100644 index 1aa5b9db0fe..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/fake_stenographer.go +++ /dev/null @@ -1,142 +0,0 @@ -package stenographer - -import ( - "sync" - - "github.com/onsi/ginkgo/types" -) - -func NewFakeStenographerCall(method string, args ...interface{}) FakeStenographerCall { - return FakeStenographerCall{ - Method: method, - Args: args, - } -} - -type FakeStenographer struct { - calls []FakeStenographerCall - lock *sync.Mutex -} - -type FakeStenographerCall struct { - Method string - Args []interface{} -} - -func NewFakeStenographer() *FakeStenographer { - stenographer := &FakeStenographer{ - lock: &sync.Mutex{}, - } - stenographer.Reset() - return stenographer -} - -func (stenographer *FakeStenographer) Calls() []FakeStenographerCall { - stenographer.lock.Lock() - defer stenographer.lock.Unlock() - - return stenographer.calls -} - -func (stenographer *FakeStenographer) Reset() { - stenographer.lock.Lock() - defer stenographer.lock.Unlock() - - stenographer.calls = make([]FakeStenographerCall, 0) -} - -func (stenographer *FakeStenographer) CallsTo(method string) []FakeStenographerCall { - stenographer.lock.Lock() - defer stenographer.lock.Unlock() - - results := make([]FakeStenographerCall, 0) - for _, call := range stenographer.calls { - if call.Method == method { - results = append(results, call) - } - } - - return results -} - -func (stenographer *FakeStenographer) registerCall(method string, args ...interface{}) { - stenographer.lock.Lock() - defer stenographer.lock.Unlock() - - stenographer.calls = append(stenographer.calls, NewFakeStenographerCall(method, args...)) -} - -func (stenographer *FakeStenographer) AnnounceSuite(description string, randomSeed int64, randomizingAll bool, succinct bool) { - stenographer.registerCall("AnnounceSuite", description, randomSeed, randomizingAll, succinct) -} - -func (stenographer *FakeStenographer) AnnounceAggregatedParallelRun(nodes int, succinct bool) { - stenographer.registerCall("AnnounceAggregatedParallelRun", nodes, succinct) -} - -func (stenographer *FakeStenographer) AnnounceParallelRun(node int, nodes int, succinct bool) { - stenographer.registerCall("AnnounceParallelRun", node, nodes, succinct) -} - -func (stenographer *FakeStenographer) AnnounceNumberOfSpecs(specsToRun int, total int, succinct bool) { - stenographer.registerCall("AnnounceNumberOfSpecs", specsToRun, total, succinct) -} - -func (stenographer *FakeStenographer) AnnounceTotalNumberOfSpecs(total int, succinct bool) { - stenographer.registerCall("AnnounceTotalNumberOfSpecs", total, succinct) -} - -func (stenographer *FakeStenographer) AnnounceSpecRunCompletion(summary *types.SuiteSummary, succinct bool) { - stenographer.registerCall("AnnounceSpecRunCompletion", summary, succinct) -} - -func (stenographer *FakeStenographer) AnnounceSpecWillRun(spec *types.SpecSummary) { - stenographer.registerCall("AnnounceSpecWillRun", spec) -} - -func (stenographer *FakeStenographer) AnnounceBeforeSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) { - 
stenographer.registerCall("AnnounceBeforeSuiteFailure", summary, succinct, fullTrace) -} - -func (stenographer *FakeStenographer) AnnounceAfterSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) { - stenographer.registerCall("AnnounceAfterSuiteFailure", summary, succinct, fullTrace) -} -func (stenographer *FakeStenographer) AnnounceCapturedOutput(output string) { - stenographer.registerCall("AnnounceCapturedOutput", output) -} - -func (stenographer *FakeStenographer) AnnounceSuccessfulSpec(spec *types.SpecSummary) { - stenographer.registerCall("AnnounceSuccessfulSpec", spec) -} - -func (stenographer *FakeStenographer) AnnounceSuccessfulSlowSpec(spec *types.SpecSummary, succinct bool) { - stenographer.registerCall("AnnounceSuccessfulSlowSpec", spec, succinct) -} - -func (stenographer *FakeStenographer) AnnounceSuccessfulMeasurement(spec *types.SpecSummary, succinct bool) { - stenographer.registerCall("AnnounceSuccessfulMeasurement", spec, succinct) -} - -func (stenographer *FakeStenographer) AnnouncePendingSpec(spec *types.SpecSummary, noisy bool) { - stenographer.registerCall("AnnouncePendingSpec", spec, noisy) -} - -func (stenographer *FakeStenographer) AnnounceSkippedSpec(spec *types.SpecSummary, succinct bool, fullTrace bool) { - stenographer.registerCall("AnnounceSkippedSpec", spec, succinct, fullTrace) -} - -func (stenographer *FakeStenographer) AnnounceSpecTimedOut(spec *types.SpecSummary, succinct bool, fullTrace bool) { - stenographer.registerCall("AnnounceSpecTimedOut", spec, succinct, fullTrace) -} - -func (stenographer *FakeStenographer) AnnounceSpecPanicked(spec *types.SpecSummary, succinct bool, fullTrace bool) { - stenographer.registerCall("AnnounceSpecPanicked", spec, succinct, fullTrace) -} - -func (stenographer *FakeStenographer) AnnounceSpecFailed(spec *types.SpecSummary, succinct bool, fullTrace bool) { - stenographer.registerCall("AnnounceSpecFailed", spec, succinct, fullTrace) -} - -func (stenographer *FakeStenographer) SummarizeFailures(summaries []*types.SpecSummary) { - stenographer.registerCall("SummarizeFailures", summaries) -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/stenographer.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/stenographer.go deleted file mode 100644 index 638d6fbb1a9..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/stenographer.go +++ /dev/null @@ -1,572 +0,0 @@ -/* -The stenographer is used by Ginkgo's reporters to generate output. - -Move along, nothing to see here. 
-*/ - -package stenographer - -import ( - "fmt" - "io" - "runtime" - "strings" - - "github.com/onsi/ginkgo/types" -) - -const defaultStyle = "\x1b[0m" -const boldStyle = "\x1b[1m" -const redColor = "\x1b[91m" -const greenColor = "\x1b[32m" -const yellowColor = "\x1b[33m" -const cyanColor = "\x1b[36m" -const grayColor = "\x1b[90m" -const lightGrayColor = "\x1b[37m" - -type cursorStateType int - -const ( - cursorStateTop cursorStateType = iota - cursorStateStreaming - cursorStateMidBlock - cursorStateEndBlock -) - -type Stenographer interface { - AnnounceSuite(description string, randomSeed int64, randomizingAll bool, succinct bool) - AnnounceAggregatedParallelRun(nodes int, succinct bool) - AnnounceParallelRun(node int, nodes int, succinct bool) - AnnounceTotalNumberOfSpecs(total int, succinct bool) - AnnounceNumberOfSpecs(specsToRun int, total int, succinct bool) - AnnounceSpecRunCompletion(summary *types.SuiteSummary, succinct bool) - - AnnounceSpecWillRun(spec *types.SpecSummary) - AnnounceBeforeSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) - AnnounceAfterSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) - - AnnounceCapturedOutput(output string) - - AnnounceSuccessfulSpec(spec *types.SpecSummary) - AnnounceSuccessfulSlowSpec(spec *types.SpecSummary, succinct bool) - AnnounceSuccessfulMeasurement(spec *types.SpecSummary, succinct bool) - - AnnouncePendingSpec(spec *types.SpecSummary, noisy bool) - AnnounceSkippedSpec(spec *types.SpecSummary, succinct bool, fullTrace bool) - - AnnounceSpecTimedOut(spec *types.SpecSummary, succinct bool, fullTrace bool) - AnnounceSpecPanicked(spec *types.SpecSummary, succinct bool, fullTrace bool) - AnnounceSpecFailed(spec *types.SpecSummary, succinct bool, fullTrace bool) - - SummarizeFailures(summaries []*types.SpecSummary) -} - -func New(color bool, enableFlakes bool, writer io.Writer) Stenographer { - denoter := "•" - if runtime.GOOS == "windows" { - denoter = "+" - } - return &consoleStenographer{ - color: color, - denoter: denoter, - cursorState: cursorStateTop, - enableFlakes: enableFlakes, - w: writer, - } -} - -type consoleStenographer struct { - color bool - denoter string - cursorState cursorStateType - enableFlakes bool - w io.Writer -} - -var alternatingColors = []string{defaultStyle, grayColor} - -func (s *consoleStenographer) AnnounceSuite(description string, randomSeed int64, randomizingAll bool, succinct bool) { - if succinct { - s.print(0, "[%d] %s ", randomSeed, s.colorize(boldStyle, description)) - return - } - s.printBanner(fmt.Sprintf("Running Suite: %s", description), "=") - s.print(0, "Random Seed: %s", s.colorize(boldStyle, "%d", randomSeed)) - if randomizingAll { - s.print(0, " - Will randomize all specs") - } - s.printNewLine() -} - -func (s *consoleStenographer) AnnounceParallelRun(node int, nodes int, succinct bool) { - if succinct { - s.print(0, "- node #%d ", node) - return - } - s.println(0, - "Parallel test node %s/%s.", - s.colorize(boldStyle, "%d", node), - s.colorize(boldStyle, "%d", nodes), - ) - s.printNewLine() -} - -func (s *consoleStenographer) AnnounceAggregatedParallelRun(nodes int, succinct bool) { - if succinct { - s.print(0, "- %d nodes ", nodes) - return - } - s.println(0, - "Running in parallel across %s nodes", - s.colorize(boldStyle, "%d", nodes), - ) - s.printNewLine() -} - -func (s *consoleStenographer) AnnounceNumberOfSpecs(specsToRun int, total int, succinct bool) { - if succinct { - s.print(0, "- %d/%d specs ", specsToRun, total) - s.stream() - return - 
} - s.println(0, - "Will run %s of %s specs", - s.colorize(boldStyle, "%d", specsToRun), - s.colorize(boldStyle, "%d", total), - ) - - s.printNewLine() -} - -func (s *consoleStenographer) AnnounceTotalNumberOfSpecs(total int, succinct bool) { - if succinct { - s.print(0, "- %d specs ", total) - s.stream() - return - } - s.println(0, - "Will run %s specs", - s.colorize(boldStyle, "%d", total), - ) - - s.printNewLine() -} - -func (s *consoleStenographer) AnnounceSpecRunCompletion(summary *types.SuiteSummary, succinct bool) { - if succinct && summary.SuiteSucceeded { - s.print(0, " %s %s ", s.colorize(greenColor, "SUCCESS!"), summary.RunTime) - return - } - s.printNewLine() - color := greenColor - if !summary.SuiteSucceeded { - color = redColor - } - s.println(0, s.colorize(boldStyle+color, "Ran %d of %d Specs in %.3f seconds", summary.NumberOfSpecsThatWillBeRun, summary.NumberOfTotalSpecs, summary.RunTime.Seconds())) - - status := "" - if summary.SuiteSucceeded { - status = s.colorize(boldStyle+greenColor, "SUCCESS!") - } else { - status = s.colorize(boldStyle+redColor, "FAIL!") - } - - flakes := "" - if s.enableFlakes { - flakes = " | " + s.colorize(yellowColor+boldStyle, "%d Flaked", summary.NumberOfFlakedSpecs) - } - - s.print(0, - "%s -- %s | %s | %s | %s\n", - status, - s.colorize(greenColor+boldStyle, "%d Passed", summary.NumberOfPassedSpecs), - s.colorize(redColor+boldStyle, "%d Failed", summary.NumberOfFailedSpecs)+flakes, - s.colorize(yellowColor+boldStyle, "%d Pending", summary.NumberOfPendingSpecs), - s.colorize(cyanColor+boldStyle, "%d Skipped", summary.NumberOfSkippedSpecs), - ) -} - -func (s *consoleStenographer) AnnounceSpecWillRun(spec *types.SpecSummary) { - s.startBlock() - for i, text := range spec.ComponentTexts[1 : len(spec.ComponentTexts)-1] { - s.print(0, s.colorize(alternatingColors[i%2], text)+" ") - } - - indentation := 0 - if len(spec.ComponentTexts) > 2 { - indentation = 1 - s.printNewLine() - } - index := len(spec.ComponentTexts) - 1 - s.print(indentation, s.colorize(boldStyle, spec.ComponentTexts[index])) - s.printNewLine() - s.print(indentation, s.colorize(lightGrayColor, spec.ComponentCodeLocations[index].String())) - s.printNewLine() - s.midBlock() -} - -func (s *consoleStenographer) AnnounceBeforeSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) { - s.announceSetupFailure("BeforeSuite", summary, succinct, fullTrace) -} - -func (s *consoleStenographer) AnnounceAfterSuiteFailure(summary *types.SetupSummary, succinct bool, fullTrace bool) { - s.announceSetupFailure("AfterSuite", summary, succinct, fullTrace) -} - -func (s *consoleStenographer) announceSetupFailure(name string, summary *types.SetupSummary, succinct bool, fullTrace bool) { - s.startBlock() - var message string - switch summary.State { - case types.SpecStateFailed: - message = "Failure" - case types.SpecStatePanicked: - message = "Panic" - case types.SpecStateTimedOut: - message = "Timeout" - } - - s.println(0, s.colorize(redColor+boldStyle, "%s [%.3f seconds]", message, summary.RunTime.Seconds())) - - indentation := s.printCodeLocationBlock([]string{name}, []types.CodeLocation{summary.CodeLocation}, summary.ComponentType, 0, summary.State, true) - - s.printNewLine() - s.printFailure(indentation, summary.State, summary.Failure, fullTrace) - - s.endBlock() -} - -func (s *consoleStenographer) AnnounceCapturedOutput(output string) { - if output == "" { - return - } - - s.startBlock() - s.println(0, output) - s.midBlock() -} - -func (s *consoleStenographer) 
AnnounceSuccessfulSpec(spec *types.SpecSummary) { - s.print(0, s.colorize(greenColor, s.denoter)) - s.stream() -} - -func (s *consoleStenographer) AnnounceSuccessfulSlowSpec(spec *types.SpecSummary, succinct bool) { - s.printBlockWithMessage( - s.colorize(greenColor, "%s [SLOW TEST:%.3f seconds]", s.denoter, spec.RunTime.Seconds()), - "", - spec, - succinct, - ) -} - -func (s *consoleStenographer) AnnounceSuccessfulMeasurement(spec *types.SpecSummary, succinct bool) { - s.printBlockWithMessage( - s.colorize(greenColor, "%s [MEASUREMENT]", s.denoter), - s.measurementReport(spec, succinct), - spec, - succinct, - ) -} - -func (s *consoleStenographer) AnnouncePendingSpec(spec *types.SpecSummary, noisy bool) { - if noisy { - s.printBlockWithMessage( - s.colorize(yellowColor, "P [PENDING]"), - "", - spec, - false, - ) - } else { - s.print(0, s.colorize(yellowColor, "P")) - s.stream() - } -} - -func (s *consoleStenographer) AnnounceSkippedSpec(spec *types.SpecSummary, succinct bool, fullTrace bool) { - // Skips at runtime will have a non-empty spec.Failure. All others should be succinct. - if succinct || spec.Failure == (types.SpecFailure{}) { - s.print(0, s.colorize(cyanColor, "S")) - s.stream() - } else { - s.startBlock() - s.println(0, s.colorize(cyanColor+boldStyle, "S [SKIPPING]%s [%.3f seconds]", s.failureContext(spec.Failure.ComponentType), spec.RunTime.Seconds())) - - indentation := s.printCodeLocationBlock(spec.ComponentTexts, spec.ComponentCodeLocations, spec.Failure.ComponentType, spec.Failure.ComponentIndex, spec.State, succinct) - - s.printNewLine() - s.printSkip(indentation, spec.Failure) - s.endBlock() - } -} - -func (s *consoleStenographer) AnnounceSpecTimedOut(spec *types.SpecSummary, succinct bool, fullTrace bool) { - s.printSpecFailure(fmt.Sprintf("%s... Timeout", s.denoter), spec, succinct, fullTrace) -} - -func (s *consoleStenographer) AnnounceSpecPanicked(spec *types.SpecSummary, succinct bool, fullTrace bool) { - s.printSpecFailure(fmt.Sprintf("%s! Panic", s.denoter), spec, succinct, fullTrace) -} - -func (s *consoleStenographer) AnnounceSpecFailed(spec *types.SpecSummary, succinct bool, fullTrace bool) { - s.printSpecFailure(fmt.Sprintf("%s Failure", s.denoter), spec, succinct, fullTrace) -} - -func (s *consoleStenographer) SummarizeFailures(summaries []*types.SpecSummary) { - failingSpecs := []*types.SpecSummary{} - - for _, summary := range summaries { - if summary.HasFailureState() { - failingSpecs = append(failingSpecs, summary) - } - } - - if len(failingSpecs) == 0 { - return - } - - s.printNewLine() - s.printNewLine() - plural := "s" - if len(failingSpecs) == 1 { - plural = "" - } - s.println(0, s.colorize(redColor+boldStyle, "Summarizing %d Failure%s:", len(failingSpecs), plural)) - for _, summary := range failingSpecs { - s.printNewLine() - if summary.HasFailureState() { - if summary.TimedOut() { - s.print(0, s.colorize(redColor+boldStyle, "[Timeout...] ")) - } else if summary.Panicked() { - s.print(0, s.colorize(redColor+boldStyle, "[Panic!] 
")) - } else if summary.Failed() { - s.print(0, s.colorize(redColor+boldStyle, "[Fail] ")) - } - s.printSpecContext(summary.ComponentTexts, summary.ComponentCodeLocations, summary.Failure.ComponentType, summary.Failure.ComponentIndex, summary.State, true) - s.printNewLine() - s.println(0, s.colorize(lightGrayColor, summary.Failure.Location.String())) - } - } -} - -func (s *consoleStenographer) startBlock() { - if s.cursorState == cursorStateStreaming { - s.printNewLine() - s.printDelimiter() - } else if s.cursorState == cursorStateMidBlock { - s.printNewLine() - } -} - -func (s *consoleStenographer) midBlock() { - s.cursorState = cursorStateMidBlock -} - -func (s *consoleStenographer) endBlock() { - s.printDelimiter() - s.cursorState = cursorStateEndBlock -} - -func (s *consoleStenographer) stream() { - s.cursorState = cursorStateStreaming -} - -func (s *consoleStenographer) printBlockWithMessage(header string, message string, spec *types.SpecSummary, succinct bool) { - s.startBlock() - s.println(0, header) - - indentation := s.printCodeLocationBlock(spec.ComponentTexts, spec.ComponentCodeLocations, types.SpecComponentTypeInvalid, 0, spec.State, succinct) - - if message != "" { - s.printNewLine() - s.println(indentation, message) - } - - s.endBlock() -} - -func (s *consoleStenographer) printSpecFailure(message string, spec *types.SpecSummary, succinct bool, fullTrace bool) { - s.startBlock() - s.println(0, s.colorize(redColor+boldStyle, "%s%s [%.3f seconds]", message, s.failureContext(spec.Failure.ComponentType), spec.RunTime.Seconds())) - - indentation := s.printCodeLocationBlock(spec.ComponentTexts, spec.ComponentCodeLocations, spec.Failure.ComponentType, spec.Failure.ComponentIndex, spec.State, succinct) - - s.printNewLine() - s.printFailure(indentation, spec.State, spec.Failure, fullTrace) - s.endBlock() -} - -func (s *consoleStenographer) failureContext(failedComponentType types.SpecComponentType) string { - switch failedComponentType { - case types.SpecComponentTypeBeforeSuite: - return " in Suite Setup (BeforeSuite)" - case types.SpecComponentTypeAfterSuite: - return " in Suite Teardown (AfterSuite)" - case types.SpecComponentTypeBeforeEach: - return " in Spec Setup (BeforeEach)" - case types.SpecComponentTypeJustBeforeEach: - return " in Spec Setup (JustBeforeEach)" - case types.SpecComponentTypeAfterEach: - return " in Spec Teardown (AfterEach)" - } - - return "" -} - -func (s *consoleStenographer) printSkip(indentation int, spec types.SpecFailure) { - s.println(indentation, s.colorize(cyanColor, spec.Message)) - s.printNewLine() - s.println(indentation, spec.Location.String()) -} - -func (s *consoleStenographer) printFailure(indentation int, state types.SpecState, failure types.SpecFailure, fullTrace bool) { - if state == types.SpecStatePanicked { - s.println(indentation, s.colorize(redColor+boldStyle, failure.Message)) - s.println(indentation, s.colorize(redColor, failure.ForwardedPanic)) - s.println(indentation, failure.Location.String()) - s.printNewLine() - s.println(indentation, s.colorize(redColor, "Full Stack Trace")) - s.println(indentation, failure.Location.FullStackTrace) - } else { - s.println(indentation, s.colorize(redColor, failure.Message)) - s.printNewLine() - s.println(indentation, failure.Location.String()) - if fullTrace { - s.printNewLine() - s.println(indentation, s.colorize(redColor, "Full Stack Trace")) - s.println(indentation, failure.Location.FullStackTrace) - } - } -} - -func (s *consoleStenographer) printSpecContext(componentTexts []string, 
componentCodeLocations []types.CodeLocation, failedComponentType types.SpecComponentType, failedComponentIndex int, state types.SpecState, succinct bool) int { - startIndex := 1 - indentation := 0 - - if len(componentTexts) == 1 { - startIndex = 0 - } - - for i := startIndex; i < len(componentTexts); i++ { - if (state.IsFailure() || state == types.SpecStateSkipped) && i == failedComponentIndex { - color := redColor - if state == types.SpecStateSkipped { - color = cyanColor - } - blockType := "" - switch failedComponentType { - case types.SpecComponentTypeBeforeSuite: - blockType = "BeforeSuite" - case types.SpecComponentTypeAfterSuite: - blockType = "AfterSuite" - case types.SpecComponentTypeBeforeEach: - blockType = "BeforeEach" - case types.SpecComponentTypeJustBeforeEach: - blockType = "JustBeforeEach" - case types.SpecComponentTypeAfterEach: - blockType = "AfterEach" - case types.SpecComponentTypeIt: - blockType = "It" - case types.SpecComponentTypeMeasure: - blockType = "Measurement" - } - if succinct { - s.print(0, s.colorize(color+boldStyle, "[%s] %s ", blockType, componentTexts[i])) - } else { - s.println(indentation, s.colorize(color+boldStyle, "%s [%s]", componentTexts[i], blockType)) - s.println(indentation, s.colorize(grayColor, "%s", componentCodeLocations[i])) - } - } else { - if succinct { - s.print(0, s.colorize(alternatingColors[i%2], "%s ", componentTexts[i])) - } else { - s.println(indentation, componentTexts[i]) - s.println(indentation, s.colorize(grayColor, "%s", componentCodeLocations[i])) - } - } - indentation++ - } - - return indentation -} - -func (s *consoleStenographer) printCodeLocationBlock(componentTexts []string, componentCodeLocations []types.CodeLocation, failedComponentType types.SpecComponentType, failedComponentIndex int, state types.SpecState, succinct bool) int { - indentation := s.printSpecContext(componentTexts, componentCodeLocations, failedComponentType, failedComponentIndex, state, succinct) - - if succinct { - if len(componentTexts) > 0 { - s.printNewLine() - s.print(0, s.colorize(lightGrayColor, "%s", componentCodeLocations[len(componentCodeLocations)-1])) - } - s.printNewLine() - indentation = 1 - } else { - indentation-- - } - - return indentation -} - -func (s *consoleStenographer) orderedMeasurementKeys(measurements map[string]*types.SpecMeasurement) []string { - orderedKeys := make([]string, len(measurements)) - for key, measurement := range measurements { - orderedKeys[measurement.Order] = key - } - return orderedKeys -} - -func (s *consoleStenographer) measurementReport(spec *types.SpecSummary, succinct bool) string { - if len(spec.Measurements) == 0 { - return "Found no measurements" - } - - message := []string{} - orderedKeys := s.orderedMeasurementKeys(spec.Measurements) - - if succinct { - message = append(message, fmt.Sprintf("%s samples:", s.colorize(boldStyle, "%d", spec.NumberOfSamples))) - for _, key := range orderedKeys { - measurement := spec.Measurements[key] - message = append(message, fmt.Sprintf(" %s - %s: %s%s, %s: %s%s ± %s%s, %s: %s%s", - s.colorize(boldStyle, "%s", measurement.Name), - measurement.SmallestLabel, - s.colorize(greenColor, measurement.PrecisionFmt(), measurement.Smallest), - measurement.Units, - measurement.AverageLabel, - s.colorize(cyanColor, measurement.PrecisionFmt(), measurement.Average), - measurement.Units, - s.colorize(cyanColor, measurement.PrecisionFmt(), measurement.StdDeviation), - measurement.Units, - measurement.LargestLabel, - s.colorize(redColor, measurement.PrecisionFmt(), 
measurement.Largest), - measurement.Units, - )) - } - } else { - message = append(message, fmt.Sprintf("Ran %s samples:", s.colorize(boldStyle, "%d", spec.NumberOfSamples))) - for _, key := range orderedKeys { - measurement := spec.Measurements[key] - info := "" - if measurement.Info != nil { - message = append(message, fmt.Sprintf("%v", measurement.Info)) - } - - message = append(message, fmt.Sprintf("%s:\n%s %s: %s%s\n %s: %s%s\n %s: %s%s ± %s%s", - s.colorize(boldStyle, "%s", measurement.Name), - info, - measurement.SmallestLabel, - s.colorize(greenColor, measurement.PrecisionFmt(), measurement.Smallest), - measurement.Units, - measurement.LargestLabel, - s.colorize(redColor, measurement.PrecisionFmt(), measurement.Largest), - measurement.Units, - measurement.AverageLabel, - s.colorize(cyanColor, measurement.PrecisionFmt(), measurement.Average), - measurement.Units, - s.colorize(cyanColor, measurement.PrecisionFmt(), measurement.StdDeviation), - measurement.Units, - )) - } - } - - return strings.Join(message, "\n") -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/README.md b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/README.md deleted file mode 100644 index e84226a7358..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/README.md +++ /dev/null @@ -1,43 +0,0 @@ -# go-colorable - -Colorable writer for windows. - -For example, most of logger packages doesn't show colors on windows. (I know we can do it with ansicon. But I don't want.) -This package is possible to handle escape sequence for ansi color on windows. - -## Too Bad! - -![](https://raw.githubusercontent.com/mattn/go-colorable/gh-pages/bad.png) - - -## So Good! - -![](https://raw.githubusercontent.com/mattn/go-colorable/gh-pages/good.png) - -## Usage - -```go -logrus.SetFormatter(&logrus.TextFormatter{ForceColors: true}) -logrus.SetOutput(colorable.NewColorableStdout()) - -logrus.Info("succeeded") -logrus.Warn("not correct") -logrus.Error("something error") -logrus.Fatal("panic") -``` - -You can compile above code on non-windows OSs. 
- -## Installation - -``` -$ go get github.com/mattn/go-colorable -``` - -# License - -MIT - -# Author - -Yasuhiro Matsumoto (a.k.a mattn) diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_others.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_others.go deleted file mode 100644 index 52d6653b34b..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_others.go +++ /dev/null @@ -1,24 +0,0 @@ -// +build !windows - -package colorable - -import ( - "io" - "os" -) - -func NewColorable(file *os.File) io.Writer { - if file == nil { - panic("nil passed instead of *os.File to NewColorable()") - } - - return file -} - -func NewColorableStdout() io.Writer { - return os.Stdout -} - -func NewColorableStderr() io.Writer { - return os.Stderr -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/noncolorable.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/noncolorable.go deleted file mode 100644 index fb976dbd8b7..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/noncolorable.go +++ /dev/null @@ -1,57 +0,0 @@ -package colorable - -import ( - "bytes" - "fmt" - "io" -) - -type NonColorable struct { - out io.Writer - lastbuf bytes.Buffer -} - -func NewNonColorable(w io.Writer) io.Writer { - return &NonColorable{out: w} -} - -func (w *NonColorable) Write(data []byte) (n int, err error) { - er := bytes.NewBuffer(data) -loop: - for { - c1, _, err := er.ReadRune() - if err != nil { - break loop - } - if c1 != 0x1b { - fmt.Fprint(w.out, string(c1)) - continue - } - c2, _, err := er.ReadRune() - if err != nil { - w.lastbuf.WriteRune(c1) - break loop - } - if c2 != 0x5b { - w.lastbuf.WriteRune(c1) - w.lastbuf.WriteRune(c2) - continue - } - - var buf bytes.Buffer - for { - c, _, err := er.ReadRune() - if err != nil { - w.lastbuf.WriteRune(c1) - w.lastbuf.WriteRune(c2) - w.lastbuf.Write(buf.Bytes()) - break loop - } - if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '@' { - break - } - buf.Write([]byte(string(c))) - } - } - return len(data) - w.lastbuf.Len(), nil -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/LICENSE b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/LICENSE deleted file mode 100644 index 65dc692b6b1..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/LICENSE +++ /dev/null @@ -1,9 +0,0 @@ -Copyright (c) Yasuhiro MATSUMOTO - -MIT License (Expat) - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/README.md b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/README.md deleted file mode 100644 index 74845de4a24..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# go-isatty - -isatty for golang - -## Usage - -```go -package main - -import ( - "fmt" - "github.com/mattn/go-isatty" - "os" -) - -func main() { - if isatty.IsTerminal(os.Stdout.Fd()) { - fmt.Println("Is Terminal") - } else { - fmt.Println("Is Not Terminal") - } -} -``` - -## Installation - -``` -$ go get github.com/mattn/go-isatty -``` - -# License - -MIT - -# Author - -Yasuhiro Matsumoto (a.k.a mattn) diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/doc.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/doc.go deleted file mode 100644 index 17d4f90ebcc..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/doc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package isatty implements interface to isatty -package isatty diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_appengine.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_appengine.go deleted file mode 100644 index 83c588773cf..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_appengine.go +++ /dev/null @@ -1,9 +0,0 @@ -// +build appengine - -package isatty - -// IsTerminal returns true if the file descriptor is terminal which -// is always false on on appengine classic which is a sandboxed PaaS. -func IsTerminal(fd uintptr) bool { - return false -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_bsd.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_bsd.go deleted file mode 100644 index 98ffe86a4ba..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_bsd.go +++ /dev/null @@ -1,18 +0,0 @@ -// +build darwin freebsd openbsd netbsd -// +build !appengine - -package isatty - -import ( - "syscall" - "unsafe" -) - -const ioctlReadTermios = syscall.TIOCGETA - -// IsTerminal return true if the file descriptor is terminal. -func IsTerminal(fd uintptr) bool { - var termios syscall.Termios - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) - return err == 0 -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_linux.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_linux.go deleted file mode 100644 index 9d24bac1db3..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_linux.go +++ /dev/null @@ -1,18 +0,0 @@ -// +build linux -// +build !appengine - -package isatty - -import ( - "syscall" - "unsafe" -) - -const ioctlReadTermios = syscall.TCGETS - -// IsTerminal return true if the file descriptor is terminal. 
-func IsTerminal(fd uintptr) bool { - var termios syscall.Termios - _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) - return err == 0 -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_solaris.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_solaris.go deleted file mode 100644 index 1f0c6bf53dc..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_solaris.go +++ /dev/null @@ -1,16 +0,0 @@ -// +build solaris -// +build !appengine - -package isatty - -import ( - "golang.org/x/sys/unix" -) - -// IsTerminal returns true if the given file descriptor is a terminal. -// see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c -func IsTerminal(fd uintptr) bool { - var termio unix.Termio - err := unix.IoctlSetTermio(int(fd), unix.TCGETA, &termio) - return err == nil -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_windows.go b/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_windows.go deleted file mode 100644 index 83c398b16db..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty/isatty_windows.go +++ /dev/null @@ -1,19 +0,0 @@ -// +build windows -// +build !appengine - -package isatty - -import ( - "syscall" - "unsafe" -) - -var kernel32 = syscall.NewLazyDLL("kernel32.dll") -var procGetConsoleMode = kernel32.NewProc("GetConsoleMode") - -// IsTerminal return true if the file descriptor is terminal. -func IsTerminal(fd uintptr) bool { - var st uint32 - r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, fd, uintptr(unsafe.Pointer(&st)), 0) - return r != 0 && e == 0 -} diff --git a/vendor/github.com/onsi/ginkgo/reporters/teamcity_reporter.go b/vendor/github.com/onsi/ginkgo/reporters/teamcity_reporter.go deleted file mode 100644 index 84fd8aff878..00000000000 --- a/vendor/github.com/onsi/ginkgo/reporters/teamcity_reporter.go +++ /dev/null @@ -1,106 +0,0 @@ -/* - -TeamCity Reporter for Ginkgo - -Makes use of TeamCity's support for Service Messages -http://confluence.jetbrains.com/display/TCD7/Build+Script+Interaction+with+TeamCity#BuildScriptInteractionwithTeamCity-ReportingTests -*/ - -package reporters - -import ( - "fmt" - "io" - "strings" - - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/types" -) - -const ( - messageId = "##teamcity" -) - -type TeamCityReporter struct { - writer io.Writer - testSuiteName string - ReporterConfig config.DefaultReporterConfigType -} - -func NewTeamCityReporter(writer io.Writer) *TeamCityReporter { - return &TeamCityReporter{ - writer: writer, - } -} - -func (reporter *TeamCityReporter) SpecSuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) { - reporter.testSuiteName = escape(summary.SuiteDescription) - fmt.Fprintf(reporter.writer, "%s[testSuiteStarted name='%s']\n", messageId, reporter.testSuiteName) -} - -func (reporter *TeamCityReporter) BeforeSuiteDidRun(setupSummary *types.SetupSummary) { - reporter.handleSetupSummary("BeforeSuite", setupSummary) -} - -func (reporter *TeamCityReporter) AfterSuiteDidRun(setupSummary *types.SetupSummary) { - reporter.handleSetupSummary("AfterSuite", setupSummary) -} - -func (reporter *TeamCityReporter) handleSetupSummary(name string, setupSummary *types.SetupSummary) { - if setupSummary.State != types.SpecStatePassed { - testName := escape(name) - 
fmt.Fprintf(reporter.writer, "%s[testStarted name='%s']\n", messageId, testName) - message := reporter.failureMessage(setupSummary.Failure) - details := reporter.failureDetails(setupSummary.Failure) - fmt.Fprintf(reporter.writer, "%s[testFailed name='%s' message='%s' details='%s']\n", messageId, testName, message, details) - durationInMilliseconds := setupSummary.RunTime.Seconds() * 1000 - fmt.Fprintf(reporter.writer, "%s[testFinished name='%s' duration='%v']\n", messageId, testName, durationInMilliseconds) - } -} - -func (reporter *TeamCityReporter) SpecWillRun(specSummary *types.SpecSummary) { - testName := escape(strings.Join(specSummary.ComponentTexts[1:], " ")) - fmt.Fprintf(reporter.writer, "%s[testStarted name='%s']\n", messageId, testName) -} - -func (reporter *TeamCityReporter) SpecDidComplete(specSummary *types.SpecSummary) { - testName := escape(strings.Join(specSummary.ComponentTexts[1:], " ")) - - if reporter.ReporterConfig.ReportPassed && specSummary.State == types.SpecStatePassed { - details := escape(specSummary.CapturedOutput) - fmt.Fprintf(reporter.writer, "%s[testPassed name='%s' details='%s']\n", messageId, testName, details) - } - if specSummary.State == types.SpecStateFailed || specSummary.State == types.SpecStateTimedOut || specSummary.State == types.SpecStatePanicked { - message := reporter.failureMessage(specSummary.Failure) - details := reporter.failureDetails(specSummary.Failure) - fmt.Fprintf(reporter.writer, "%s[testFailed name='%s' message='%s' details='%s']\n", messageId, testName, message, details) - } - if specSummary.State == types.SpecStateSkipped || specSummary.State == types.SpecStatePending { - fmt.Fprintf(reporter.writer, "%s[testIgnored name='%s']\n", messageId, testName) - } - - durationInMilliseconds := specSummary.RunTime.Seconds() * 1000 - fmt.Fprintf(reporter.writer, "%s[testFinished name='%s' duration='%v']\n", messageId, testName, durationInMilliseconds) -} - -func (reporter *TeamCityReporter) SpecSuiteDidEnd(summary *types.SuiteSummary) { - fmt.Fprintf(reporter.writer, "%s[testSuiteFinished name='%s']\n", messageId, reporter.testSuiteName) -} - -func (reporter *TeamCityReporter) failureMessage(failure types.SpecFailure) string { - return escape(failure.ComponentCodeLocation.String()) -} - -func (reporter *TeamCityReporter) failureDetails(failure types.SpecFailure) string { - return escape(fmt.Sprintf("%s\n%s", failure.Message, failure.Location.String())) -} - -func escape(output string) string { - output = strings.Replace(output, "|", "||", -1) - output = strings.Replace(output, "'", "|'", -1) - output = strings.Replace(output, "\n", "|n", -1) - output = strings.Replace(output, "\r", "|r", -1) - output = strings.Replace(output, "[", "|[", -1) - output = strings.Replace(output, "]", "|]", -1) - return output -} diff --git a/vendor/github.com/onsi/ginkgo/types/code_location.go b/vendor/github.com/onsi/ginkgo/types/code_location.go deleted file mode 100644 index 935a89e136a..00000000000 --- a/vendor/github.com/onsi/ginkgo/types/code_location.go +++ /dev/null @@ -1,15 +0,0 @@ -package types - -import ( - "fmt" -) - -type CodeLocation struct { - FileName string - LineNumber int - FullStackTrace string -} - -func (codeLocation CodeLocation) String() string { - return fmt.Sprintf("%s:%d", codeLocation.FileName, codeLocation.LineNumber) -} diff --git a/vendor/github.com/onsi/ginkgo/types/synchronization.go b/vendor/github.com/onsi/ginkgo/types/synchronization.go deleted file mode 100644 index fdd6ed5bdf8..00000000000 --- 
a/vendor/github.com/onsi/ginkgo/types/synchronization.go +++ /dev/null @@ -1,30 +0,0 @@ -package types - -import ( - "encoding/json" -) - -type RemoteBeforeSuiteState int - -const ( - RemoteBeforeSuiteStateInvalid RemoteBeforeSuiteState = iota - - RemoteBeforeSuiteStatePending - RemoteBeforeSuiteStatePassed - RemoteBeforeSuiteStateFailed - RemoteBeforeSuiteStateDisappeared -) - -type RemoteBeforeSuiteData struct { - Data []byte - State RemoteBeforeSuiteState -} - -func (r RemoteBeforeSuiteData) ToJSON() []byte { - data, _ := json.Marshal(r) - return data -} - -type RemoteAfterSuiteData struct { - CanRun bool -} diff --git a/vendor/github.com/onsi/ginkgo/types/types.go b/vendor/github.com/onsi/ginkgo/types/types.go deleted file mode 100644 index c143e02d845..00000000000 --- a/vendor/github.com/onsi/ginkgo/types/types.go +++ /dev/null @@ -1,174 +0,0 @@ -package types - -import ( - "strconv" - "time" -) - -const GINKGO_FOCUS_EXIT_CODE = 197 - -/* -SuiteSummary represents the a summary of the test suite and is passed to both -Reporter.SpecSuiteWillBegin -Reporter.SpecSuiteDidEnd - -this is unfortunate as these two methods should receive different objects. When running in parallel -each node does not deterministically know how many specs it will end up running. - -Unfortunately making such a change would break backward compatibility. - -Until Ginkgo 2.0 comes out we will continue to reuse this struct but populate unknown fields -with -1. -*/ -type SuiteSummary struct { - SuiteDescription string - SuiteSucceeded bool - SuiteID string - - NumberOfSpecsBeforeParallelization int - NumberOfTotalSpecs int - NumberOfSpecsThatWillBeRun int - NumberOfPendingSpecs int - NumberOfSkippedSpecs int - NumberOfPassedSpecs int - NumberOfFailedSpecs int - // Flaked specs are those that failed initially, but then passed on a - // subsequent try. 
- NumberOfFlakedSpecs int - RunTime time.Duration -} - -type SpecSummary struct { - ComponentTexts []string - ComponentCodeLocations []CodeLocation - - State SpecState - RunTime time.Duration - Failure SpecFailure - IsMeasurement bool - NumberOfSamples int - Measurements map[string]*SpecMeasurement - - CapturedOutput string - SuiteID string -} - -func (s SpecSummary) HasFailureState() bool { - return s.State.IsFailure() -} - -func (s SpecSummary) TimedOut() bool { - return s.State == SpecStateTimedOut -} - -func (s SpecSummary) Panicked() bool { - return s.State == SpecStatePanicked -} - -func (s SpecSummary) Failed() bool { - return s.State == SpecStateFailed -} - -func (s SpecSummary) Passed() bool { - return s.State == SpecStatePassed -} - -func (s SpecSummary) Skipped() bool { - return s.State == SpecStateSkipped -} - -func (s SpecSummary) Pending() bool { - return s.State == SpecStatePending -} - -type SetupSummary struct { - ComponentType SpecComponentType - CodeLocation CodeLocation - - State SpecState - RunTime time.Duration - Failure SpecFailure - - CapturedOutput string - SuiteID string -} - -type SpecFailure struct { - Message string - Location CodeLocation - ForwardedPanic string - - ComponentIndex int - ComponentType SpecComponentType - ComponentCodeLocation CodeLocation -} - -type SpecMeasurement struct { - Name string - Info interface{} - Order int - - Results []float64 - - Smallest float64 - Largest float64 - Average float64 - StdDeviation float64 - - SmallestLabel string - LargestLabel string - AverageLabel string - Units string - Precision int -} - -func (s SpecMeasurement) PrecisionFmt() string { - if s.Precision == 0 { - return "%f" - } - - str := strconv.Itoa(s.Precision) - - return "%." + str + "f" -} - -type SpecState uint - -const ( - SpecStateInvalid SpecState = iota - - SpecStatePending - SpecStateSkipped - SpecStatePassed - SpecStateFailed - SpecStatePanicked - SpecStateTimedOut -) - -func (state SpecState) IsFailure() bool { - return state == SpecStateTimedOut || state == SpecStatePanicked || state == SpecStateFailed -} - -type SpecComponentType uint - -const ( - SpecComponentTypeInvalid SpecComponentType = iota - - SpecComponentTypeContainer - SpecComponentTypeBeforeSuite - SpecComponentTypeAfterSuite - SpecComponentTypeBeforeEach - SpecComponentTypeJustBeforeEach - SpecComponentTypeJustAfterEach - SpecComponentTypeAfterEach - SpecComponentTypeIt - SpecComponentTypeMeasure -) - -type FlagType uint - -const ( - FlagTypeNone FlagType = iota - FlagTypeFocused - FlagTypePending -) diff --git a/vendor/github.com/onsi/ginkgo/LICENSE b/vendor/github.com/onsi/ginkgo/v2/LICENSE similarity index 100% rename from vendor/github.com/onsi/ginkgo/LICENSE rename to vendor/github.com/onsi/ginkgo/v2/LICENSE diff --git a/vendor/github.com/onsi/ginkgo/v2/config/deprecated.go b/vendor/github.com/onsi/ginkgo/v2/config/deprecated.go new file mode 100644 index 00000000000..a61021d0889 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/config/deprecated.go @@ -0,0 +1,69 @@ +package config + +// GinkgoConfigType has been deprecated and its equivalent now lives in +// the types package. You can no longer access Ginkgo configuration from the config +// package. 
Instead use the DSL's GinkgoConfiguration() function to get copies of the +// current configuration +// +// GinkgoConfigType is still here so custom V1 reporters do not result in a compilation error +// It will be removed in a future minor release of Ginkgo +type GinkgoConfigType = DeprecatedGinkgoConfigType +type DeprecatedGinkgoConfigType struct { + RandomSeed int64 + RandomizeAllSpecs bool + RegexScansFilePath bool + FocusStrings []string + SkipStrings []string + SkipMeasurements bool + FailOnPending bool + FailFast bool + FlakeAttempts int + EmitSpecProgress bool + DryRun bool + DebugParallel bool + + ParallelNode int + ParallelTotal int + SyncHost string + StreamHost string +} + +// DefaultReporterConfigType has been deprecated and its equivalent now lives in +// the types package. You can no longer access Ginkgo configuration from the config +// package. Instead use the DSL's GinkgoConfiguration() function to get copies of the +// current configuration +// +// DefaultReporterConfigType is still here so custom V1 reporters do not result in a compilation error +// It will be removed in a future minor release of Ginkgo +type DefaultReporterConfigType = DeprecatedDefaultReporterConfigType +type DeprecatedDefaultReporterConfigType struct { + NoColor bool + SlowSpecThreshold float64 + NoisyPendings bool + NoisySkippings bool + Succinct bool + Verbose bool + FullTrace bool + ReportPassed bool + ReportFile string +} + +// Sadly there is no way to gracefully deprecate access to these global config variables. +// Users who need access to Ginkgo's configuration should use the DSL's GinkgoConfiguration() method +// These new unwieldy type names exist to give users a hint when they try to compile and the compilation fails +type GinkgoConfigIsNoLongerAccessibleFromTheConfigPackageUseTheDSLsGinkgoConfigurationFunctionInstead struct{} + +// Sadly there is no way to gracefully deprecate access to these global config variables. +// Users who need access to Ginkgo's configuration should use the DSL's GinkgoConfiguration() method +// These new unwieldy type names exist to give users a hint when they try to compile and the compilation fails +var GinkgoConfig = GinkgoConfigIsNoLongerAccessibleFromTheConfigPackageUseTheDSLsGinkgoConfigurationFunctionInstead{} + +// Sadly there is no way to gracefully deprecate access to these global config variables. +// Users who need access to Ginkgo's configuration should use the DSL's GinkgoConfiguration() method +// These new unwieldy type names exist to give users a hint when they try to compile and the compilation fails +type DefaultReporterConfigIsNoLongerAccessibleFromTheConfigPackageUseTheDSLsGinkgoConfigurationFunctionInstead struct{} + +// Sadly there is no way to gracefully deprecate access to these global config variables. 
+// Users who need access to Ginkgo's configuration should use the DSL's GinkgoConfiguration() method +// These new unwieldy type names exist to give users a hint when they try to compile and the compilation fails +var DefaultReporterConfig = DefaultReporterConfigIsNoLongerAccessibleFromTheConfigPackageUseTheDSLsGinkgoConfigurationFunctionInstead{} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/LICENSE b/vendor/github.com/onsi/ginkgo/v2/formatter/colorable_others.go similarity index 76% rename from vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/LICENSE rename to vendor/github.com/onsi/ginkgo/v2/formatter/colorable_others.go index 91b5cef30eb..778bfd7c7ca 100644 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/LICENSE +++ b/vendor/github.com/onsi/ginkgo/v2/formatter/colorable_others.go @@ -1,3 +1,11 @@ +// +build !windows + +/* +These packages are used for colorize on Windows and contributed by mattn.jp@gmail.com + + * go-colorable: + * go-isatty: + The MIT License (MIT) Copyright (c) 2016 Yasuhiro Matsumoto @@ -19,3 +27,15 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ + +package formatter + +import ( + "io" + "os" +) + +func newColorable(file *os.File) io.Writer { + return file +} diff --git a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_windows.go b/vendor/github.com/onsi/ginkgo/v2/formatter/colorable_windows.go similarity index 90% rename from vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_windows.go rename to vendor/github.com/onsi/ginkgo/v2/formatter/colorable_windows.go index 1088009230b..dd1d143cc20 100644 --- a/vendor/github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable/colorable_windows.go +++ b/vendor/github.com/onsi/ginkgo/v2/formatter/colorable_windows.go @@ -1,4 +1,33 @@ -package colorable +/* +These packages are used for colorize on Windows and contributed by mattn.jp@gmail.com + + * go-colorable: + * go-isatty: + +The MIT License (MIT) + +Copyright (c) 2016 Yasuhiro Matsumoto + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+*/ + +package formatter import ( "bytes" @@ -10,10 +39,24 @@ import ( "strings" "syscall" "unsafe" +) - "github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty" +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") + procSetConsoleTextAttribute = kernel32.NewProc("SetConsoleTextAttribute") + procSetConsoleCursorPosition = kernel32.NewProc("SetConsoleCursorPosition") + procFillConsoleOutputCharacter = kernel32.NewProc("FillConsoleOutputCharacterW") + procFillConsoleOutputAttribute = kernel32.NewProc("FillConsoleOutputAttribute") + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") ) +func isTerminal(fd uintptr) bool { + var st uint32 + r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, fd, uintptr(unsafe.Pointer(&st)), 0) + return r != 0 && e == 0 +} + const ( foregroundBlue = 0x1 foregroundGreen = 0x2 @@ -52,45 +95,28 @@ type consoleScreenBufferInfo struct { maximumWindowSize coord } -var ( - kernel32 = syscall.NewLazyDLL("kernel32.dll") - procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") - procSetConsoleTextAttribute = kernel32.NewProc("SetConsoleTextAttribute") - procSetConsoleCursorPosition = kernel32.NewProc("SetConsoleCursorPosition") - procFillConsoleOutputCharacter = kernel32.NewProc("FillConsoleOutputCharacterW") - procFillConsoleOutputAttribute = kernel32.NewProc("FillConsoleOutputAttribute") -) - -type Writer struct { +type writer struct { out io.Writer handle syscall.Handle lastbuf bytes.Buffer oldattr word } -func NewColorable(file *os.File) io.Writer { +func newColorable(file *os.File) io.Writer { if file == nil { panic("nil passed instead of *os.File to NewColorable()") } - if isatty.IsTerminal(file.Fd()) { + if isTerminal(file.Fd()) { var csbi consoleScreenBufferInfo handle := syscall.Handle(file.Fd()) procGetConsoleScreenBufferInfo.Call(uintptr(handle), uintptr(unsafe.Pointer(&csbi))) - return &Writer{out: file, handle: handle, oldattr: csbi.attributes} + return &writer{out: file, handle: handle, oldattr: csbi.attributes} } else { return file } } -func NewColorableStdout() io.Writer { - return NewColorable(os.Stdout) -} - -func NewColorableStderr() io.Writer { - return NewColorable(os.Stderr) -} - var color256 = map[int]int{ 0: 0x000000, 1: 0x800000, @@ -350,7 +376,7 @@ var color256 = map[int]int{ 255: 0xeeeeee, } -func (w *Writer) Write(data []byte) (n int, err error) { +func (w *writer) Write(data []byte) (n int, err error) { var csbi consoleScreenBufferInfo procGetConsoleScreenBufferInfo.Call(uintptr(w.handle), uintptr(unsafe.Pointer(&csbi))) diff --git a/vendor/github.com/onsi/ginkgo/formatter/formatter.go b/vendor/github.com/onsi/ginkgo/v2/formatter/formatter.go similarity index 91% rename from vendor/github.com/onsi/ginkgo/formatter/formatter.go rename to vendor/github.com/onsi/ginkgo/v2/formatter/formatter.go index 30d7cbe1297..43b16211d8d 100644 --- a/vendor/github.com/onsi/ginkgo/formatter/formatter.go +++ b/vendor/github.com/onsi/ginkgo/v2/formatter/formatter.go @@ -2,10 +2,15 @@ package formatter import ( "fmt" + "os" "regexp" "strings" ) +// ColorableStdOut and ColorableStdErr enable color output support on Windows +var ColorableStdOut = newColorable(os.Stdout) +var ColorableStdErr = newColorable(os.Stderr) + const COLS = 80 type ColorMode uint8 @@ -100,13 +105,13 @@ func (f Formatter) Fiw(indentation uint, maxWidth uint, format string, args ...i outLines = append(outLines, line) continue } - outWords := []string{} - length := 
uint(0) words := strings.Split(line, " ") - for _, word := range words { + outWords := []string{words[0]} + length := uint(f.length(words[0])) + for _, word := range words[1:] { wordLength := f.length(word) - if length+wordLength <= maxWidth { - length += wordLength + if length+wordLength+1 <= maxWidth { + length += wordLength + 1 outWords = append(outWords, word) continue } diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go new file mode 100644 index 00000000000..5db5d1a7bfd --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go @@ -0,0 +1,63 @@ +package build + +import ( + "fmt" + + "github.com/onsi/ginkgo/v2/ginkgo/command" + "github.com/onsi/ginkgo/v2/ginkgo/internal" + "github.com/onsi/ginkgo/v2/types" +) + +func BuildBuildCommand() command.Command { + var cliConfig = types.NewDefaultCLIConfig() + var goFlagsConfig = types.NewDefaultGoFlagsConfig() + + flags, err := types.BuildBuildCommandFlagSet(&cliConfig, &goFlagsConfig) + if err != nil { + panic(err) + } + + return command.Command{ + Name: "build", + Flags: flags, + Usage: "ginkgo build ", + ShortDoc: "Build the passed in (or the package in the current directory if left blank).", + DocLink: "precompiling-suites", + Command: func(args []string, _ []string) { + var errors []error + cliConfig, goFlagsConfig, errors = types.VetAndInitializeCLIAndGoConfig(cliConfig, goFlagsConfig) + command.AbortIfErrors("Ginkgo detected configuration issues:", errors) + + buildSpecs(args, cliConfig, goFlagsConfig) + }, + } +} + +func buildSpecs(args []string, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig) { + suites := internal.FindSuites(args, cliConfig, false).WithoutState(internal.TestSuiteStateSkippedByFilter) + if len(suites) == 0 { + command.AbortWith("Found no test suites") + } + + internal.VerifyCLIAndFrameworkVersion(suites) + + opc := internal.NewOrderedParallelCompiler(cliConfig.ComputedNumCompilers()) + opc.StartCompiling(suites, goFlagsConfig) + + for { + suiteIdx, suite := opc.Next() + if suiteIdx >= len(suites) { + break + } + suites[suiteIdx] = suite + if suite.State.Is(internal.TestSuiteStateFailedToCompile) { + fmt.Println(suite.CompilationError.Error()) + } else { + fmt.Printf("Compiled %s.test\n", suite.PackageName) + } + } + + if suites.CountWithState(internal.TestSuiteStateFailedToCompile) > 0 { + command.AbortWith("Failed to compile all tests") + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/abort.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/abort.go new file mode 100644 index 00000000000..2efd2860886 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/abort.go @@ -0,0 +1,61 @@ +package command + +import "fmt" + +type AbortDetails struct { + ExitCode int + Error error + EmitUsage bool +} + +func Abort(details AbortDetails) { + panic(details) +} + +func AbortGracefullyWith(format string, args ...interface{}) { + Abort(AbortDetails{ + ExitCode: 0, + Error: fmt.Errorf(format, args...), + EmitUsage: false, + }) +} + +func AbortWith(format string, args ...interface{}) { + Abort(AbortDetails{ + ExitCode: 1, + Error: fmt.Errorf(format, args...), + EmitUsage: false, + }) +} + +func AbortWithUsage(format string, args ...interface{}) { + Abort(AbortDetails{ + ExitCode: 1, + Error: fmt.Errorf(format, args...), + EmitUsage: true, + }) +} + +func AbortIfError(preamble string, err error) { + if err != nil { + Abort(AbortDetails{ + ExitCode: 1, + Error: 
fmt.Errorf("%s\n%s", preamble, err.Error()), + EmitUsage: false, + }) + } +} + +func AbortIfErrors(preamble string, errors []error) { + if len(errors) > 0 { + out := "" + for _, err := range errors { + out += err.Error() + } + Abort(AbortDetails{ + ExitCode: 1, + Error: fmt.Errorf("%s\n%s", preamble, out), + EmitUsage: false, + }) + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/command.go new file mode 100644 index 00000000000..12e0e565914 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/command.go @@ -0,0 +1,50 @@ +package command + +import ( + "fmt" + "io" + "strings" + + "github.com/onsi/ginkgo/v2/formatter" + "github.com/onsi/ginkgo/v2/types" +) + +type Command struct { + Name string + Flags types.GinkgoFlagSet + Usage string + ShortDoc string + Documentation string + DocLink string + Command func(args []string, additionalArgs []string) +} + +func (c Command) Run(args []string, additionalArgs []string) { + args, err := c.Flags.Parse(args) + if err != nil { + AbortWithUsage(err.Error()) + } + + c.Command(args, additionalArgs) +} + +func (c Command) EmitUsage(writer io.Writer) { + fmt.Fprintln(writer, formatter.F("{{bold}}"+c.Usage+"{{/}}")) + fmt.Fprintln(writer, formatter.F("{{gray}}%s{{/}}", strings.Repeat("-", len(c.Usage)))) + if c.ShortDoc != "" { + fmt.Fprintln(writer, formatter.Fiw(0, formatter.COLS, c.ShortDoc)) + fmt.Fprintln(writer, "") + } + if c.Documentation != "" { + fmt.Fprintln(writer, formatter.Fiw(0, formatter.COLS, c.Documentation)) + fmt.Fprintln(writer, "") + } + if c.DocLink != "" { + fmt.Fprintln(writer, formatter.Fi(0, "{{bold}}Learn more at:{{/}} {{cyan}}{{underline}}http://onsi.github.io/ginkgo/#%s{{/}}", c.DocLink)) + fmt.Fprintln(writer, "") + } + flagUsage := c.Flags.Usage() + if flagUsage != "" { + fmt.Fprintf(writer, formatter.F(flagUsage)) + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/program.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/program.go new file mode 100644 index 00000000000..88dd8d6b07e --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/command/program.go @@ -0,0 +1,182 @@ +package command + +import ( + "fmt" + "io" + "os" + "strings" + + "github.com/onsi/ginkgo/v2/formatter" + "github.com/onsi/ginkgo/v2/types" +) + +type Program struct { + Name string + Heading string + Commands []Command + DefaultCommand Command + DeprecatedCommands []DeprecatedCommand + + //For testing - leave as nil in production + OutWriter io.Writer + ErrWriter io.Writer + Exiter func(code int) +} + +type DeprecatedCommand struct { + Name string + Deprecation types.Deprecation +} + +func (p Program) RunAndExit(osArgs []string) { + var command Command + deprecationTracker := types.NewDeprecationTracker() + if p.Exiter == nil { + p.Exiter = os.Exit + } + if p.OutWriter == nil { + p.OutWriter = formatter.ColorableStdOut + } + if p.ErrWriter == nil { + p.ErrWriter = formatter.ColorableStdErr + } + + defer func() { + exitCode := 0 + + if r := recover(); r != nil { + details, ok := r.(AbortDetails) + if !ok { + panic(r) + } + + if details.Error != nil { + fmt.Fprintln(p.ErrWriter, formatter.F("{{red}}{{bold}}%s %s{{/}} {{red}}failed{{/}}", p.Name, command.Name)) + fmt.Fprintln(p.ErrWriter, formatter.Fi(1, details.Error.Error())) + } + if details.EmitUsage { + if details.Error != nil { + fmt.Fprintln(p.ErrWriter, "") + } + command.EmitUsage(p.ErrWriter) + } + exitCode = details.ExitCode + } + + 
command.Flags.ValidateDeprecations(deprecationTracker) + if deprecationTracker.DidTrackDeprecations() { + fmt.Fprintln(p.ErrWriter, deprecationTracker.DeprecationsReport()) + } + p.Exiter(exitCode) + return + }() + + args, additionalArgs := []string{}, []string{} + + foundDelimiter := false + for _, arg := range osArgs[1:] { + if !foundDelimiter { + if arg == "--" { + foundDelimiter = true + continue + } + } + + if foundDelimiter { + additionalArgs = append(additionalArgs, arg) + } else { + args = append(args, arg) + } + } + + command = p.DefaultCommand + if len(args) > 0 { + p.handleHelpRequestsAndExit(p.OutWriter, args) + if command.Name == args[0] { + args = args[1:] + } else { + for _, deprecatedCommand := range p.DeprecatedCommands { + if deprecatedCommand.Name == args[0] { + deprecationTracker.TrackDeprecation(deprecatedCommand.Deprecation) + return + } + } + for _, tryCommand := range p.Commands { + if tryCommand.Name == args[0] { + command, args = tryCommand, args[1:] + break + } + } + } + } + + command.Run(args, additionalArgs) +} + +func (p Program) handleHelpRequestsAndExit(writer io.Writer, args []string) { + if len(args) == 0 { + return + } + + matchesHelpFlag := func(args ...string) bool { + for _, arg := range args { + if arg == "--help" || arg == "-help" || arg == "-h" || arg == "--h" { + return true + } + } + return false + } + if len(args) == 1 { + if args[0] == "help" || matchesHelpFlag(args[0]) { + p.EmitUsage(writer) + Abort(AbortDetails{}) + } + } else { + var name string + if args[0] == "help" || matchesHelpFlag(args[0]) { + name = args[1] + } else if matchesHelpFlag(args[1:]...) { + name = args[0] + } else { + return + } + + if p.DefaultCommand.Name == name || p.Name == name { + p.DefaultCommand.EmitUsage(writer) + Abort(AbortDetails{}) + } + for _, command := range p.Commands { + if command.Name == name { + command.EmitUsage(writer) + Abort(AbortDetails{}) + } + } + + fmt.Fprintln(writer, formatter.F("{{red}}Unknown Command: {{bold}}%s{{/}}", name)) + fmt.Fprintln(writer, "") + p.EmitUsage(writer) + Abort(AbortDetails{ExitCode: 1}) + } + return +} + +func (p Program) EmitUsage(writer io.Writer) { + fmt.Fprintln(writer, formatter.F(p.Heading)) + fmt.Fprintln(writer, formatter.F("{{gray}}%s{{/}}", strings.Repeat("-", len(p.Heading)))) + fmt.Fprintln(writer, formatter.F("For usage information for a command, run {{bold}}%s help COMMAND{{/}}.", p.Name)) + fmt.Fprintln(writer, formatter.F("For usage information for the default command, run {{bold}}%s help %s{{/}} or {{bold}}%s help %s{{/}}.", p.Name, p.Name, p.Name, p.DefaultCommand.Name)) + fmt.Fprintln(writer, "") + fmt.Fprintln(writer, formatter.F("The following commands are available:")) + + fmt.Fprintln(writer, formatter.Fi(1, "{{bold}}%s{{/}} or %s {{bold}}%s{{/}} - {{gray}}%s{{/}}", p.Name, p.Name, p.DefaultCommand.Name, p.DefaultCommand.Usage)) + if p.DefaultCommand.ShortDoc != "" { + fmt.Fprintln(writer, formatter.Fi(2, p.DefaultCommand.ShortDoc)) + } + + for _, command := range p.Commands { + fmt.Fprintln(writer, formatter.Fi(1, "{{bold}}%s{{/}} - {{gray}}%s{{/}}", command.Name, command.Usage)) + if command.ShortDoc != "" { + fmt.Fprintln(writer, formatter.Fi(2, command.ShortDoc)) + } + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/boostrap_templates.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/boostrap_templates.go new file mode 100644 index 00000000000..a367a1fc977 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/boostrap_templates.go @@ -0,0 +1,48 @@ 
+package generators + +var bootstrapText = `package {{.Package}} + +import ( + "testing" + + {{.GinkgoImport}} + {{.GomegaImport}} +) + +func Test{{.FormattedName}}(t *testing.T) { + {{.GomegaPackage}}RegisterFailHandler({{.GinkgoPackage}}Fail) + {{.GinkgoPackage}}RunSpecs(t, "{{.FormattedName}} Suite") +} +` + +var agoutiBootstrapText = `package {{.Package}} + +import ( + "testing" + + {{.GinkgoImport}} + {{.GomegaImport}} + "github.com/sclevine/agouti" +) + +func Test{{.FormattedName}}(t *testing.T) { + {{.GomegaPackage}}RegisterFailHandler({{.GinkgoPackage}}Fail) + {{.GinkgoPackage}}RunSpecs(t, "{{.FormattedName}} Suite") +} + +var agoutiDriver *agouti.WebDriver + +var _ = {{.GinkgoPackage}}BeforeSuite(func() { + // Choose a WebDriver: + + agoutiDriver = agouti.PhantomJS() + // agoutiDriver = agouti.Selenium() + // agoutiDriver = agouti.ChromeDriver() + + {{.GomegaPackage}}Expect(agoutiDriver.Start()).To({{.GomegaPackage}}Succeed()) +}) + +var _ = {{.GinkgoPackage}}AfterSuite(func() { + {{.GomegaPackage}}Expect(agoutiDriver.Stop()).To({{.GomegaPackage}}Succeed()) +}) +` diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/bootstrap_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/bootstrap_command.go new file mode 100644 index 00000000000..0273abe9c6e --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/bootstrap_command.go @@ -0,0 +1,113 @@ +package generators + +import ( + "bytes" + "fmt" + "os" + "text/template" + + sprig "github.com/go-task/slim-sprig" + "github.com/onsi/ginkgo/v2/ginkgo/command" + "github.com/onsi/ginkgo/v2/ginkgo/internal" + "github.com/onsi/ginkgo/v2/types" +) + +func BuildBootstrapCommand() command.Command { + conf := GeneratorsConfig{} + flags, err := types.NewGinkgoFlagSet( + types.GinkgoFlags{ + {Name: "agouti", KeyPath: "Agouti", + Usage: "If set, bootstrap will generate a bootstrap file for writing Agouti tests"}, + {Name: "nodot", KeyPath: "NoDot", + Usage: "If set, bootstrap will generate a bootstrap test file that does not dot-import ginkgo and gomega"}, + {Name: "internal", KeyPath: "Internal", + Usage: "If set, bootstrap will generate a bootstrap test file that uses the regular package name (i.e. `package X`, not `package X_test`)"}, + {Name: "template", KeyPath: "CustomTemplate", + UsageArgument: "template-file", + Usage: "If specified, generate will use the contents of the file passed as the bootstrap template"}, + }, + &conf, + types.GinkgoFlagSections{}, + ) + + if err != nil { + panic(err) + } + + return command.Command{ + Name: "bootstrap", + Usage: "ginkgo bootstrap", + ShortDoc: "Bootstrap a test suite for the current package", + Documentation: `Tests written in Ginkgo and Gomega require a small amount of boilerplate to hook into Go's testing infrastructure. + +{{bold}}ginkgo bootstrap{{/}} generates this boilerplate for you in a file named X_suite_test.go where X is the name of the package under test.`, + DocLink: "generators", + Flags: flags, + Command: func(_ []string, _ []string) { + generateBootstrap(conf) + }, + } +} + +type bootstrapData struct { + Package string + FormattedName string + + GinkgoImport string + GomegaImport string + GinkgoPackage string + GomegaPackage string +} + +func generateBootstrap(conf GeneratorsConfig) { + packageName, bootstrapFilePrefix, formattedName := getPackageAndFormattedName() + + data := bootstrapData{ + Package: determinePackageName(packageName, conf.Internal), + FormattedName: formattedName, + + GinkgoImport: `. 
"github.com/onsi/ginkgo/v2"`, + GomegaImport: `. "github.com/onsi/gomega"`, + GinkgoPackage: "", + GomegaPackage: "", + } + + if conf.NoDot { + data.GinkgoImport = `"github.com/onsi/ginkgo/v2"` + data.GomegaImport = `"github.com/onsi/gomega"` + data.GinkgoPackage = `ginkgo.` + data.GomegaPackage = `gomega.` + } + + targetFile := fmt.Sprintf("%s_suite_test.go", bootstrapFilePrefix) + if internal.FileExists(targetFile) { + command.AbortWith("{{bold}}%s{{/}} already exists", targetFile) + } else { + fmt.Printf("Generating ginkgo test suite bootstrap for %s in:\n\t%s\n", packageName, targetFile) + } + + f, err := os.Create(targetFile) + command.AbortIfError("Failed to create file:", err) + defer f.Close() + + var templateText string + if conf.CustomTemplate != "" { + tpl, err := os.ReadFile(conf.CustomTemplate) + command.AbortIfError("Failed to read custom bootstrap file:", err) + templateText = string(tpl) + } else if conf.Agouti { + templateText = agoutiBootstrapText + } else { + templateText = bootstrapText + } + + bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Parse(templateText) + command.AbortIfError("Failed to parse bootstrap template:", err) + + buf := &bytes.Buffer{} + bootstrapTemplate.Execute(buf, data) + + buf.WriteTo(f) + + internal.GoFmt(targetFile) +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/generate_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_command.go similarity index 50% rename from vendor/github.com/onsi/ginkgo/ginkgo/generate_command.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_command.go index f792716764e..93b0b4b25b7 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/generate_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_command.go @@ -1,10 +1,8 @@ -package main +package generators import ( "bytes" - "flag" "fmt" - "io/ioutil" "os" "path/filepath" "strconv" @@ -12,167 +10,134 @@ import ( "text/template" sprig "github.com/go-task/slim-sprig" + "github.com/onsi/ginkgo/v2/ginkgo/command" + "github.com/onsi/ginkgo/v2/ginkgo/internal" + "github.com/onsi/ginkgo/v2/types" ) -func BuildGenerateCommand() *Command { - var ( - agouti, noDot, internal bool - customTestFile string - ) - flagSet := flag.NewFlagSet("generate", flag.ExitOnError) - flagSet.BoolVar(&agouti, "agouti", false, "If set, generate will generate a test file for writing Agouti tests") - flagSet.BoolVar(&noDot, "nodot", false, "If set, generate will generate a test file that does not . 
import ginkgo and gomega") - flagSet.BoolVar(&internal, "internal", false, "If set, generate will generate a test file that uses the regular package name") - flagSet.StringVar(&customTestFile, "template", "", "If specified, generate will use the contents of the file passed as the test file template") - - return &Command{ - Name: "generate", - FlagSet: flagSet, - UsageCommand: "ginkgo generate ", - Usage: []string{ - "Generate a test file named filename_test.go", - "If the optional argument is omitted, a file named after the package in the current directory will be created.", - "Accepts the following flags:", - }, - Command: func(args []string, additionalArgs []string) { - generateSpec(args, agouti, noDot, internal, customTestFile) - emitRCAdvertisement() +func BuildGenerateCommand() command.Command { + conf := GeneratorsConfig{} + flags, err := types.NewGinkgoFlagSet( + types.GinkgoFlags{ + {Name: "agouti", KeyPath: "Agouti", + Usage: "If set, generate will create a test file for writing Agouti tests"}, + {Name: "nodot", KeyPath: "NoDot", + Usage: "If set, generate will create a test file that does not dot-import ginkgo and gomega"}, + {Name: "internal", KeyPath: "Internal", + Usage: "If set, generate will create a test file that uses the regular package name (i.e. `package X`, not `package X_test`)"}, + {Name: "template", KeyPath: "CustomTemplate", + UsageArgument: "template-file", + Usage: "If specified, generate will use the contents of the file passed as the test file template"}, }, - } -} - -var specText = `package {{.Package}} - -import ( - {{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}} - {{if .IncludeImports}}. "github.com/onsi/gomega"{{end}} - - {{if .ImportPackage}}"{{.PackageImportPath}}"{{end}} -) - -var _ = Describe("{{.Subject}}", func() { - -}) -` - -var agoutiSpecText = `package {{.Package}} - -import ( - {{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}} - {{if .IncludeImports}}. "github.com/onsi/gomega"{{end}} - "github.com/sclevine/agouti" - . "github.com/sclevine/agouti/matchers" + &conf, + types.GinkgoFlagSections{}, + ) - {{if .ImportPackage}}"{{.PackageImportPath}}"{{end}} -) + if err != nil { + panic(err) + } -var _ = Describe("{{.Subject}}", func() { - var page *agouti.Page + return command.Command{ + Name: "generate", + Usage: "ginkgo generate ", + ShortDoc: "Generate a test file named _test.go", + Documentation: `If the optional argument is omitted, a file named after the package in the current directory will be created. - BeforeEach(func() { - var err error - page, err = agoutiDriver.NewPage() - Expect(err).NotTo(HaveOccurred()) - }) +You can pass multiple to generate multiple files simultaneously. The resulting files are named _test.go. 
- AfterEach(func() { - Expect(page.Destroy()).To(Succeed()) - }) -}) -` +You can also pass a of the form "file.go" and generate will emit "file_test.go".`, + DocLink: "generators", + Flags: flags, + Command: func(args []string, _ []string) { + generateTestFiles(conf, args) + }, + } +} type specData struct { Package string Subject string PackageImportPath string - IncludeImports bool ImportPackage bool -} -func generateSpec(args []string, agouti, noDot, internal bool, customTestFile string) { - if len(args) == 0 { - err := generateSpecForSubject("", agouti, noDot, internal, customTestFile) - if err != nil { - fmt.Println(err.Error()) - fmt.Println("") - os.Exit(1) - } - fmt.Println("") - return - } + GinkgoImport string + GomegaImport string + GinkgoPackage string + GomegaPackage string +} - var failed bool - for _, arg := range args { - err := generateSpecForSubject(arg, agouti, noDot, internal, customTestFile) - if err != nil { - failed = true - fmt.Println(err.Error()) - } +func generateTestFiles(conf GeneratorsConfig, args []string) { + subjects := args + if len(subjects) == 0 { + subjects = []string{""} } - fmt.Println("") - if failed { - os.Exit(1) + for _, subject := range subjects { + generateTestFileForSubject(subject, conf) } } -func generateSpecForSubject(subject string, agouti, noDot, internal bool, customTestFile string) error { +func generateTestFileForSubject(subject string, conf GeneratorsConfig) { packageName, specFilePrefix, formattedName := getPackageAndFormattedName() if subject != "" { specFilePrefix = formatSubject(subject) - formattedName = prettifyPackageName(specFilePrefix) + formattedName = prettifyName(specFilePrefix) } - if internal { + if conf.Internal { specFilePrefix = specFilePrefix + "_internal" } data := specData{ - Package: determinePackageName(packageName, internal), + Package: determinePackageName(packageName, conf.Internal), Subject: formattedName, PackageImportPath: getPackageImportPath(), - IncludeImports: !noDot, - ImportPackage: !internal, + ImportPackage: !conf.Internal, + + GinkgoImport: `. "github.com/onsi/ginkgo/v2"`, + GomegaImport: `. 
"github.com/onsi/gomega"`, + GinkgoPackage: "", + GomegaPackage: "", + } + + if conf.NoDot { + data.GinkgoImport = `"github.com/onsi/ginkgo/v2"` + data.GomegaImport = `"github.com/onsi/gomega"` + data.GinkgoPackage = `ginkgo.` + data.GomegaPackage = `gomega.` } targetFile := fmt.Sprintf("%s_test.go", specFilePrefix) - if fileExists(targetFile) { - return fmt.Errorf("%s already exists.", targetFile) + if internal.FileExists(targetFile) { + command.AbortWith("{{bold}}%s{{/}} already exists", targetFile) } else { fmt.Printf("Generating ginkgo test for %s in:\n %s\n", data.Subject, targetFile) } f, err := os.Create(targetFile) - if err != nil { - return err - } + command.AbortIfError("Failed to create test file:", err) defer f.Close() var templateText string - if customTestFile != "" { - tpl, err := ioutil.ReadFile(customTestFile) - if err != nil { - panic(err.Error()) - } + if conf.CustomTemplate != "" { + tpl, err := os.ReadFile(conf.CustomTemplate) + command.AbortIfError("Failed to read custom template file:", err) templateText = string(tpl) - } else if agouti { + } else if conf.Agouti { templateText = agoutiSpecText } else { templateText = specText } specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Parse(templateText) - if err != nil { - return err - } + command.AbortIfError("Failed to read parse test template:", err) specTemplate.Execute(f, data) - goFmt(targetFile) - return nil + internal.GoFmt(targetFile) } func formatSubject(name string) string { - name = strings.Replace(name, "-", "_", -1) - name = strings.Replace(name, " ", "_", -1) + name = strings.ReplaceAll(name, "-", "_") + name = strings.ReplaceAll(name, " ", "_") name = strings.Split(name, ".go")[0] name = strings.Split(name, "_test")[0] return name @@ -258,7 +223,7 @@ func getPackageImportPath() string { if modRoot != "" { modName := moduleName(modRoot) if modName != "" { - cd := strings.Replace(workingDir, modRoot, "", -1) + cd := strings.ReplaceAll(workingDir, modRoot, "") cd = strings.ReplaceAll(cd, sep, "/") return modName + cd } diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_templates.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_templates.go new file mode 100644 index 00000000000..c3470adbfd1 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_templates.go @@ -0,0 +1,41 @@ +package generators + +var specText = `package {{.Package}} + +import ( + {{.GinkgoImport}} + {{.GomegaImport}} + + {{if .ImportPackage}}"{{.PackageImportPath}}"{{end}} +) + +var _ = {{.GinkgoPackage}}Describe("{{.Subject}}", func() { + +}) +` + +var agoutiSpecText = `package {{.Package}} + +import ( + {{.GinkgoImport}} + {{.GomegaImport}} + "github.com/sclevine/agouti" + . 
"github.com/sclevine/agouti/matchers" + + {{if .ImportPackage}}"{{.PackageImportPath}}"{{end}} +) + +var _ = {{.GinkgoPackage}}Describe("{{.Subject}}", func() { + var page *agouti.Page + + {{.GinkgoPackage}}BeforeEach(func() { + var err error + page, err = agoutiDriver.NewPage() + {{.GomegaPackage}}Expect(err).NotTo({{.GomegaPackage}}HaveOccurred()) + }) + + {{.GinkgoPackage}}AfterEach(func() { + {{.GomegaPackage}}Expect(page.Destroy()).To({{.GomegaPackage}}Succeed()) + }) +}) +` diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generators_common.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generators_common.go new file mode 100644 index 00000000000..3086e6056a9 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generators_common.go @@ -0,0 +1,63 @@ +package generators + +import ( + "go/build" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/onsi/ginkgo/v2/ginkgo/command" +) + +type GeneratorsConfig struct { + Agouti, NoDot, Internal bool + CustomTemplate string +} + +func getPackageAndFormattedName() (string, string, string) { + path, err := os.Getwd() + command.AbortIfError("Could not get current working directory:", err) + + dirName := strings.ReplaceAll(filepath.Base(path), "-", "_") + dirName = strings.ReplaceAll(dirName, " ", "_") + + pkg, err := build.ImportDir(path, 0) + packageName := pkg.Name + if err != nil { + packageName = ensureLegalPackageName(dirName) + } + + formattedName := prettifyName(filepath.Base(path)) + return packageName, dirName, formattedName +} + +func ensureLegalPackageName(name string) string { + if name == "_" { + return "underscore" + } + if len(name) == 0 { + return "empty" + } + n, isDigitErr := strconv.Atoi(string(name[0])) + if isDigitErr == nil { + return []string{"zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"}[n] + name[1:] + } + return name +} + +func prettifyName(name string) string { + name = strings.ReplaceAll(name, "-", " ") + name = strings.ReplaceAll(name, "_", " ") + name = strings.Title(name) + name = strings.ReplaceAll(name, " ", "") + return name +} + +func determinePackageName(name string, internal bool) string { + if internal { + return name + } + + return name + "_test" +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go new file mode 100644 index 00000000000..496ec4a28a0 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go @@ -0,0 +1,152 @@ +package internal + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + + "github.com/onsi/ginkgo/v2/types" +) + +func CompileSuite(suite TestSuite, goFlagsConfig types.GoFlagsConfig) TestSuite { + if suite.PathToCompiledTest != "" { + return suite + } + + suite.CompilationError = nil + + path, err := filepath.Abs(filepath.Join(suite.Path, suite.PackageName+".test")) + if err != nil { + suite.State = TestSuiteStateFailedToCompile + suite.CompilationError = fmt.Errorf("Failed to compute compilation target path:\n%s", err.Error()) + return suite + } + + args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./") + if err != nil { + suite.State = TestSuiteStateFailedToCompile + suite.CompilationError = fmt.Errorf("Failed to generate go test compile flags:\n%s", err.Error()) + return suite + } + + cmd := exec.Command("go", args...) 
+ cmd.Dir = suite.Path + output, err := cmd.CombinedOutput() + if err != nil { + if len(output) > 0 { + suite.State = TestSuiteStateFailedToCompile + suite.CompilationError = fmt.Errorf("Failed to compile %s:\n\n%s", suite.PackageName, output) + } else { + suite.State = TestSuiteStateFailedToCompile + suite.CompilationError = fmt.Errorf("Failed to compile %s\n%s", suite.PackageName, err.Error()) + } + return suite + } + + if strings.Contains(string(output), "[no test files]") { + suite.State = TestSuiteStateSkippedDueToEmptyCompilation + return suite + } + + if len(output) > 0 { + fmt.Println(string(output)) + } + + if !FileExists(path) { + suite.State = TestSuiteStateFailedToCompile + suite.CompilationError = fmt.Errorf("Failed to compile %s:\nOutput file %s could not be found", suite.PackageName, path) + return suite + } + + suite.State = TestSuiteStateCompiled + suite.PathToCompiledTest = path + return suite +} + +func Cleanup(goFlagsConfig types.GoFlagsConfig, suites ...TestSuite) { + if goFlagsConfig.BinaryMustBePreserved() { + return + } + for _, suite := range suites { + if !suite.Precompiled { + os.Remove(suite.PathToCompiledTest) + } + } +} + +type parallelSuiteBundle struct { + suite TestSuite + compiled chan TestSuite +} + +type OrderedParallelCompiler struct { + mutex *sync.Mutex + stopped bool + numCompilers int + + idx int + numSuites int + completionChannels []chan TestSuite +} + +func NewOrderedParallelCompiler(numCompilers int) *OrderedParallelCompiler { + return &OrderedParallelCompiler{ + mutex: &sync.Mutex{}, + numCompilers: numCompilers, + } +} + +func (opc *OrderedParallelCompiler) StartCompiling(suites TestSuites, goFlagsConfig types.GoFlagsConfig) { + opc.stopped = false + opc.idx = 0 + opc.numSuites = len(suites) + opc.completionChannels = make([]chan TestSuite, opc.numSuites) + + toCompile := make(chan parallelSuiteBundle, opc.numCompilers) + for compiler := 0; compiler < opc.numCompilers; compiler++ { + go func() { + for bundle := range toCompile { + c, suite := bundle.compiled, bundle.suite + opc.mutex.Lock() + stopped := opc.stopped + opc.mutex.Unlock() + if !stopped { + suite = CompileSuite(suite, goFlagsConfig) + } + c <- suite + } + }() + } + + for idx, suite := range suites { + opc.completionChannels[idx] = make(chan TestSuite, 1) + toCompile <- parallelSuiteBundle{suite, opc.completionChannels[idx]} + if idx == 0 { //compile first suite serially + suite = <-opc.completionChannels[0] + opc.completionChannels[0] <- suite + } + } + + close(toCompile) +} + +func (opc *OrderedParallelCompiler) Next() (int, TestSuite) { + if opc.idx >= opc.numSuites { + return opc.numSuites, TestSuite{} + } + + idx := opc.idx + suite := <-opc.completionChannels[idx] + opc.idx = opc.idx + 1 + + return idx, suite +} + +func (opc *OrderedParallelCompiler) StopAndDrain() { + opc.mutex.Lock() + opc.stopped = true + opc.mutex.Unlock() +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/profiles_and_reports.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/profiles_and_reports.go new file mode 100644 index 00000000000..bd3c6d0287a --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/profiles_and_reports.go @@ -0,0 +1,237 @@ +package internal + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + + "github.com/google/pprof/profile" + "github.com/onsi/ginkgo/v2/reporters" + "github.com/onsi/ginkgo/v2/types" +) + +func AbsPathForGeneratedAsset(assetName string, suite TestSuite, cliConfig types.CLIConfig, process int) 
string { + suffix := "" + if process != 0 { + suffix = fmt.Sprintf(".%d", process) + } + if cliConfig.OutputDir == "" { + return filepath.Join(suite.AbsPath(), assetName+suffix) + } + outputDir, _ := filepath.Abs(cliConfig.OutputDir) + return filepath.Join(outputDir, suite.NamespacedName()+"_"+assetName+suffix) +} + +func FinalizeProfilesAndReportsForSuites(suites TestSuites, cliConfig types.CLIConfig, suiteConfig types.SuiteConfig, reporterConfig types.ReporterConfig, goFlagsConfig types.GoFlagsConfig) ([]string, error) { + messages := []string{} + suitesWithProfiles := suites.WithState(TestSuiteStatePassed, TestSuiteStateFailed) //anything else won't have actually run and generated a profile + + // merge cover profiles if need be + if goFlagsConfig.Cover && !cliConfig.KeepSeparateCoverprofiles { + coverProfiles := []string{} + for _, suite := range suitesWithProfiles { + if !suite.HasProgrammaticFocus { + coverProfiles = append(coverProfiles, AbsPathForGeneratedAsset(goFlagsConfig.CoverProfile, suite, cliConfig, 0)) + } + } + + if len(coverProfiles) > 0 { + dst := goFlagsConfig.CoverProfile + if cliConfig.OutputDir != "" { + dst = filepath.Join(cliConfig.OutputDir, goFlagsConfig.CoverProfile) + } + err := MergeAndCleanupCoverProfiles(coverProfiles, dst) + if err != nil { + return messages, err + } + coverage, err := GetCoverageFromCoverProfile(dst) + if err != nil { + return messages, err + } + if coverage == 0 { + messages = append(messages, "composite coverage: [no statements]") + } else if suitesWithProfiles.AnyHaveProgrammaticFocus() { + messages = append(messages, fmt.Sprintf("composite coverage: %.1f%% of statements however some suites did not contribute because they included programatically focused specs", coverage)) + } else { + messages = append(messages, fmt.Sprintf("composite coverage: %.1f%% of statements", coverage)) + } + } else { + messages = append(messages, "no composite coverage computed: all suites included programatically focused specs") + } + } + + // copy binaries if need be + for _, suite := range suitesWithProfiles { + if goFlagsConfig.BinaryMustBePreserved() && cliConfig.OutputDir != "" { + src := suite.PathToCompiledTest + dst := filepath.Join(cliConfig.OutputDir, suite.NamespacedName()+".test") + if suite.Precompiled { + if err := CopyFile(src, dst); err != nil { + return messages, err + } + } else { + if err := os.Rename(src, dst); err != nil { + return messages, err + } + } + } + } + + type reportFormat struct { + ReportName string + GenerateFunc func(types.Report, string) error + MergeFunc func([]string, string) ([]string, error) + } + reportFormats := []reportFormat{} + if reporterConfig.JSONReport != "" { + reportFormats = append(reportFormats, reportFormat{ReportName: reporterConfig.JSONReport, GenerateFunc: reporters.GenerateJSONReport, MergeFunc: reporters.MergeAndCleanupJSONReports}) + } + if reporterConfig.JUnitReport != "" { + reportFormats = append(reportFormats, reportFormat{ReportName: reporterConfig.JUnitReport, GenerateFunc: reporters.GenerateJUnitReport, MergeFunc: reporters.MergeAndCleanupJUnitReports}) + } + if reporterConfig.TeamcityReport != "" { + reportFormats = append(reportFormats, reportFormat{ReportName: reporterConfig.TeamcityReport, GenerateFunc: reporters.GenerateTeamcityReport, MergeFunc: reporters.MergeAndCleanupTeamcityReports}) + } + + // Generate reports for suites that failed to run + reportableSuites := suites.ThatAreGinkgoSuites() + for _, suite := range reportableSuites.WithState(TestSuiteStateFailedToCompile, 
TestSuiteStateFailedDueToTimeout, TestSuiteStateSkippedDueToPriorFailures, TestSuiteStateSkippedDueToEmptyCompilation) { + report := types.Report{ + SuitePath: suite.AbsPath(), + SuiteConfig: suiteConfig, + SuiteSucceeded: false, + } + switch suite.State { + case TestSuiteStateFailedToCompile: + report.SpecialSuiteFailureReasons = append(report.SpecialSuiteFailureReasons, suite.CompilationError.Error()) + case TestSuiteStateFailedDueToTimeout: + report.SpecialSuiteFailureReasons = append(report.SpecialSuiteFailureReasons, TIMEOUT_ELAPSED_FAILURE_REASON) + case TestSuiteStateSkippedDueToPriorFailures: + report.SpecialSuiteFailureReasons = append(report.SpecialSuiteFailureReasons, PRIOR_FAILURES_FAILURE_REASON) + case TestSuiteStateSkippedDueToEmptyCompilation: + report.SpecialSuiteFailureReasons = append(report.SpecialSuiteFailureReasons, EMPTY_SKIP_FAILURE_REASON) + report.SuiteSucceeded = true + } + + for _, format := range reportFormats { + format.GenerateFunc(report, AbsPathForGeneratedAsset(format.ReportName, suite, cliConfig, 0)) + } + } + + // Merge reports unless we've been asked to keep them separate + if !cliConfig.KeepSeparateReports { + for _, format := range reportFormats { + reports := []string{} + for _, suite := range reportableSuites { + reports = append(reports, AbsPathForGeneratedAsset(format.ReportName, suite, cliConfig, 0)) + } + dst := format.ReportName + if cliConfig.OutputDir != "" { + dst = filepath.Join(cliConfig.OutputDir, format.ReportName) + } + mergeMessages, err := format.MergeFunc(reports, dst) + messages = append(messages, mergeMessages...) + if err != nil { + return messages, err + } + } + } + + return messages, nil +} + +//loads each profile, combines them, deletes them, stores them in destination +func MergeAndCleanupCoverProfiles(profiles []string, destination string) error { + combined := &bytes.Buffer{} + modeRegex := regexp.MustCompile(`^mode: .*\n`) + for i, profile := range profiles { + contents, err := os.ReadFile(profile) + if err != nil { + return fmt.Errorf("Unable to read coverage file %s:\n%s", profile, err.Error()) + } + os.Remove(profile) + + // remove the cover mode line from every file + // except the first one + if i > 0 { + contents = modeRegex.ReplaceAll(contents, []byte{}) + } + + _, err = combined.Write(contents) + + // Add a newline to the end of every file if missing. 
+ if err == nil && len(contents) > 0 && contents[len(contents)-1] != '\n' { + _, err = combined.Write([]byte("\n")) + } + + if err != nil { + return fmt.Errorf("Unable to append to coverprofile:\n%s", err.Error()) + } + } + + err := os.WriteFile(destination, combined.Bytes(), 0666) + if err != nil { + return fmt.Errorf("Unable to create combined cover profile:\n%s", err.Error()) + } + return nil +} + +func GetCoverageFromCoverProfile(profile string) (float64, error) { + cmd := exec.Command("go", "tool", "cover", "-func", profile) + output, err := cmd.CombinedOutput() + if err != nil { + return 0, fmt.Errorf("Could not process Coverprofile %s: %s", profile, err.Error()) + } + re := regexp.MustCompile(`total:\s*\(statements\)\s*(\d*\.\d*)\%`) + matches := re.FindStringSubmatch(string(output)) + if matches == nil { + return 0, fmt.Errorf("Could not parse Coverprofile to compute coverage percentage") + } + coverageString := matches[1] + coverage, err := strconv.ParseFloat(coverageString, 64) + if err != nil { + return 0, fmt.Errorf("Could not parse Coverprofile to compute coverage percentage: %s", err.Error()) + } + + return coverage, nil +} + +func MergeProfiles(profilePaths []string, destination string) error { + profiles := []*profile.Profile{} + for _, profilePath := range profilePaths { + proFile, err := os.Open(profilePath) + if err != nil { + return fmt.Errorf("Could not open profile: %s\n%s", profilePath, err.Error()) + } + prof, err := profile.Parse(proFile) + if err != nil { + return fmt.Errorf("Could not parse profile: %s\n%s", profilePath, err.Error()) + } + profiles = append(profiles, prof) + os.Remove(profilePath) + } + + mergedProfile, err := profile.Merge(profiles) + if err != nil { + return fmt.Errorf("Could not merge profiles:\n%s", err.Error()) + } + + outFile, err := os.Create(destination) + if err != nil { + return fmt.Errorf("Could not create merged profile %s:\n%s", destination, err.Error()) + } + err = mergedProfile.Write(outFile) + if err != nil { + return fmt.Errorf("Could not write merged profile %s:\n%s", destination, err.Error()) + } + err = outFile.Close() + if err != nil { + return fmt.Errorf("Could not close merged profile %s:\n%s", destination, err.Error()) + } + + return nil +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/run.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/run.go new file mode 100644 index 00000000000..cad23867178 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/run.go @@ -0,0 +1,348 @@ +package internal + +import ( + "bytes" + "fmt" + "io" + "os" + "os/exec" + "regexp" + "strings" + "syscall" + "time" + + "github.com/onsi/ginkgo/v2/formatter" + "github.com/onsi/ginkgo/v2/ginkgo/command" + "github.com/onsi/ginkgo/v2/internal/parallel_support" + "github.com/onsi/ginkgo/v2/reporters" + "github.com/onsi/ginkgo/v2/types" +) + +func RunCompiledSuite(suite TestSuite, ginkgoConfig types.SuiteConfig, reporterConfig types.ReporterConfig, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig, additionalArgs []string) TestSuite { + suite.State = TestSuiteStateFailed + suite.HasProgrammaticFocus = false + + if suite.PathToCompiledTest == "" { + return suite + } + + if suite.IsGinkgo && cliConfig.ComputedProcs() > 1 { + suite = runParallel(suite, ginkgoConfig, reporterConfig, cliConfig, goFlagsConfig, additionalArgs) + } else if suite.IsGinkgo { + suite = runSerial(suite, ginkgoConfig, reporterConfig, cliConfig, goFlagsConfig, additionalArgs) + } else { + suite = runGoTest(suite, cliConfig, goFlagsConfig) + 
} + runAfterRunHook(cliConfig.AfterRunHook, reporterConfig.NoColor, suite) + return suite +} + +func buildAndStartCommand(suite TestSuite, args []string, pipeToStdout bool) (*exec.Cmd, *bytes.Buffer) { + buf := &bytes.Buffer{} + cmd := exec.Command(suite.PathToCompiledTest, args...) + cmd.Dir = suite.Path + if pipeToStdout { + cmd.Stderr = io.MultiWriter(os.Stdout, buf) + cmd.Stdout = os.Stdout + } else { + cmd.Stderr = buf + cmd.Stdout = buf + } + err := cmd.Start() + command.AbortIfError("Failed to start test suite", err) + + return cmd, buf +} + +func checkForNoTestsWarning(buf *bytes.Buffer) bool { + if strings.Contains(buf.String(), "warning: no tests to run") { + fmt.Fprintf(os.Stderr, `Found no test suites, did you forget to run "ginkgo bootstrap"?`) + return true + } + return false +} + +func runGoTest(suite TestSuite, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig) TestSuite { + args, err := types.GenerateGoTestRunArgs(goFlagsConfig) + command.AbortIfError("Failed to generate test run arguments", err) + cmd, buf := buildAndStartCommand(suite, args, true) + + cmd.Wait() + + exitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus() + passed := (exitStatus == 0) || (exitStatus == types.GINKGO_FOCUS_EXIT_CODE) + passed = !(checkForNoTestsWarning(buf) && cliConfig.RequireSuite) && passed + if passed { + suite.State = TestSuiteStatePassed + } else { + suite.State = TestSuiteStateFailed + } + + return suite +} + +func runSerial(suite TestSuite, ginkgoConfig types.SuiteConfig, reporterConfig types.ReporterConfig, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig, additionalArgs []string) TestSuite { + if goFlagsConfig.Cover { + goFlagsConfig.CoverProfile = AbsPathForGeneratedAsset(goFlagsConfig.CoverProfile, suite, cliConfig, 0) + } + if goFlagsConfig.BlockProfile != "" { + goFlagsConfig.BlockProfile = AbsPathForGeneratedAsset(goFlagsConfig.BlockProfile, suite, cliConfig, 0) + } + if goFlagsConfig.CPUProfile != "" { + goFlagsConfig.CPUProfile = AbsPathForGeneratedAsset(goFlagsConfig.CPUProfile, suite, cliConfig, 0) + } + if goFlagsConfig.MemProfile != "" { + goFlagsConfig.MemProfile = AbsPathForGeneratedAsset(goFlagsConfig.MemProfile, suite, cliConfig, 0) + } + if goFlagsConfig.MutexProfile != "" { + goFlagsConfig.MutexProfile = AbsPathForGeneratedAsset(goFlagsConfig.MutexProfile, suite, cliConfig, 0) + } + if reporterConfig.JSONReport != "" { + reporterConfig.JSONReport = AbsPathForGeneratedAsset(reporterConfig.JSONReport, suite, cliConfig, 0) + } + if reporterConfig.JUnitReport != "" { + reporterConfig.JUnitReport = AbsPathForGeneratedAsset(reporterConfig.JUnitReport, suite, cliConfig, 0) + } + if reporterConfig.TeamcityReport != "" { + reporterConfig.TeamcityReport = AbsPathForGeneratedAsset(reporterConfig.TeamcityReport, suite, cliConfig, 0) + } + + args, err := types.GenerateGinkgoTestRunArgs(ginkgoConfig, reporterConfig, goFlagsConfig) + command.AbortIfError("Failed to generate test run arguments", err) + args = append([]string{"--test.timeout=0"}, args...) + args = append(args, additionalArgs...) 
+ + cmd, buf := buildAndStartCommand(suite, args, true) + + cmd.Wait() + + exitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus() + suite.HasProgrammaticFocus = (exitStatus == types.GINKGO_FOCUS_EXIT_CODE) + passed := (exitStatus == 0) || (exitStatus == types.GINKGO_FOCUS_EXIT_CODE) + passed = !(checkForNoTestsWarning(buf) && cliConfig.RequireSuite) && passed + if passed { + suite.State = TestSuiteStatePassed + } else { + suite.State = TestSuiteStateFailed + } + + if suite.HasProgrammaticFocus { + if goFlagsConfig.Cover { + fmt.Fprintln(os.Stdout, "coverage: no coverfile was generated because specs are programmatically focused") + } + if goFlagsConfig.BlockProfile != "" { + fmt.Fprintln(os.Stdout, "no block profile was generated because specs are programmatically focused") + } + if goFlagsConfig.CPUProfile != "" { + fmt.Fprintln(os.Stdout, "no cpu profile was generated because specs are programmatically focused") + } + if goFlagsConfig.MemProfile != "" { + fmt.Fprintln(os.Stdout, "no mem profile was generated because specs are programmatically focused") + } + if goFlagsConfig.MutexProfile != "" { + fmt.Fprintln(os.Stdout, "no mutex profile was generated because specs are programmatically focused") + } + } + + return suite +} + +func runParallel(suite TestSuite, ginkgoConfig types.SuiteConfig, reporterConfig types.ReporterConfig, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig, additionalArgs []string) TestSuite { + type procResult struct { + passed bool + hasProgrammaticFocus bool + } + + numProcs := cliConfig.ComputedProcs() + procOutput := make([]*bytes.Buffer, numProcs) + coverProfiles := []string{} + + blockProfiles := []string{} + cpuProfiles := []string{} + memProfiles := []string{} + mutexProfiles := []string{} + + procResults := make(chan procResult) + + server, err := parallel_support.NewServer(numProcs, reporters.NewDefaultReporter(reporterConfig, formatter.ColorableStdOut)) + command.AbortIfError("Failed to start parallel spec server", err) + server.Start() + defer server.Close() + + if reporterConfig.JSONReport != "" { + reporterConfig.JSONReport = AbsPathForGeneratedAsset(reporterConfig.JSONReport, suite, cliConfig, 0) + } + if reporterConfig.JUnitReport != "" { + reporterConfig.JUnitReport = AbsPathForGeneratedAsset(reporterConfig.JUnitReport, suite, cliConfig, 0) + } + if reporterConfig.TeamcityReport != "" { + reporterConfig.TeamcityReport = AbsPathForGeneratedAsset(reporterConfig.TeamcityReport, suite, cliConfig, 0) + } + + for proc := 1; proc <= numProcs; proc++ { + procGinkgoConfig := ginkgoConfig + procGinkgoConfig.ParallelProcess, procGinkgoConfig.ParallelTotal, procGinkgoConfig.ParallelHost = proc, numProcs, server.Address() + + procGoFlagsConfig := goFlagsConfig + if goFlagsConfig.Cover { + procGoFlagsConfig.CoverProfile = AbsPathForGeneratedAsset(goFlagsConfig.CoverProfile, suite, cliConfig, proc) + coverProfiles = append(coverProfiles, procGoFlagsConfig.CoverProfile) + } + if goFlagsConfig.BlockProfile != "" { + procGoFlagsConfig.BlockProfile = AbsPathForGeneratedAsset(goFlagsConfig.BlockProfile, suite, cliConfig, proc) + blockProfiles = append(blockProfiles, procGoFlagsConfig.BlockProfile) + } + if goFlagsConfig.CPUProfile != "" { + procGoFlagsConfig.CPUProfile = AbsPathForGeneratedAsset(goFlagsConfig.CPUProfile, suite, cliConfig, proc) + cpuProfiles = append(cpuProfiles, procGoFlagsConfig.CPUProfile) + } + if goFlagsConfig.MemProfile != "" { + procGoFlagsConfig.MemProfile = AbsPathForGeneratedAsset(goFlagsConfig.MemProfile, suite, 
cliConfig, proc) + memProfiles = append(memProfiles, procGoFlagsConfig.MemProfile) + } + if goFlagsConfig.MutexProfile != "" { + procGoFlagsConfig.MutexProfile = AbsPathForGeneratedAsset(goFlagsConfig.MutexProfile, suite, cliConfig, proc) + mutexProfiles = append(mutexProfiles, procGoFlagsConfig.MutexProfile) + } + + args, err := types.GenerateGinkgoTestRunArgs(procGinkgoConfig, reporterConfig, procGoFlagsConfig) + command.AbortIfError("Failed to generate test run arguments", err) + args = append([]string{"--test.timeout=0"}, args...) + args = append(args, additionalArgs...) + + cmd, buf := buildAndStartCommand(suite, args, false) + procOutput[proc-1] = buf + server.RegisterAlive(proc, func() bool { return cmd.ProcessState == nil || !cmd.ProcessState.Exited() }) + + go func() { + cmd.Wait() + exitStatus := cmd.ProcessState.Sys().(syscall.WaitStatus).ExitStatus() + procResults <- procResult{ + passed: (exitStatus == 0) || (exitStatus == types.GINKGO_FOCUS_EXIT_CODE), + hasProgrammaticFocus: exitStatus == types.GINKGO_FOCUS_EXIT_CODE, + } + }() + } + + passed := true + for proc := 1; proc <= cliConfig.ComputedProcs(); proc++ { + result := <-procResults + passed = passed && result.passed + suite.HasProgrammaticFocus = suite.HasProgrammaticFocus || result.hasProgrammaticFocus + } + if passed { + suite.State = TestSuiteStatePassed + } else { + suite.State = TestSuiteStateFailed + } + + select { + case <-server.GetSuiteDone(): + fmt.Println("") + case <-time.After(time.Second): + //one of the nodes never finished reporting to the server. Something must have gone wrong. + fmt.Fprint(formatter.ColorableStdErr, formatter.F("\n{{bold}}{{red}}Ginkgo timed out waiting for all parallel procs to report back{{/}}\n")) + fmt.Fprint(formatter.ColorableStdErr, formatter.F("{{gray}}Test suite:{{/}} %s (%s)\n\n", suite.PackageName, suite.Path)) + fmt.Fprint(formatter.ColorableStdErr, formatter.Fiw(0, formatter.COLS, "This occurs if a parallel process exits before it reports its results to the Ginkgo CLI. The CLI will now print out all the stdout/stderr output it's collected from the running processes. 
However you may not see anything useful in these logs because the individual test processes usually intercept output to stdout/stderr in order to capture it in the spec reports.\n\nYou may want to try rerunning your test suite with {{light-gray}}--output-interceptor-mode=none{{/}} to see additional output here and debug your suite.\n")) + fmt.Fprintln(formatter.ColorableStdErr, " ") + for proc := 1; proc <= cliConfig.ComputedProcs(); proc++ { + fmt.Fprintf(formatter.ColorableStdErr, formatter.F("{{bold}}Output from proc %d:{{/}}\n", proc)) + fmt.Fprintln(os.Stderr, formatter.Fi(1, "%s", procOutput[proc-1].String())) + } + fmt.Fprintf(os.Stderr, "** End **") + } + + for proc := 1; proc <= cliConfig.ComputedProcs(); proc++ { + output := procOutput[proc-1].String() + if proc == 1 && checkForNoTestsWarning(procOutput[0]) && cliConfig.RequireSuite { + suite.State = TestSuiteStateFailed + } + if strings.Contains(output, "deprecated Ginkgo functionality") { + fmt.Fprintln(os.Stderr, output) + } + } + + if len(coverProfiles) > 0 { + if suite.HasProgrammaticFocus { + fmt.Fprintln(os.Stdout, "coverage: no coverfile was generated because specs are programmatically focused") + } else { + coverProfile := AbsPathForGeneratedAsset(goFlagsConfig.CoverProfile, suite, cliConfig, 0) + err := MergeAndCleanupCoverProfiles(coverProfiles, coverProfile) + command.AbortIfError("Failed to combine cover profiles", err) + + coverage, err := GetCoverageFromCoverProfile(coverProfile) + command.AbortIfError("Failed to compute coverage", err) + if coverage == 0 { + fmt.Fprintln(os.Stdout, "coverage: [no statements]") + } else { + fmt.Fprintf(os.Stdout, "coverage: %.1f%% of statements\n", coverage) + } + } + } + if len(blockProfiles) > 0 { + if suite.HasProgrammaticFocus { + fmt.Fprintln(os.Stdout, "no block profile was generated because specs are programmatically focused") + } else { + blockProfile := AbsPathForGeneratedAsset(goFlagsConfig.BlockProfile, suite, cliConfig, 0) + err := MergeProfiles(blockProfiles, blockProfile) + command.AbortIfError("Failed to combine blockprofiles", err) + } + } + if len(cpuProfiles) > 0 { + if suite.HasProgrammaticFocus { + fmt.Fprintln(os.Stdout, "no cpu profile was generated because specs are programmatically focused") + } else { + cpuProfile := AbsPathForGeneratedAsset(goFlagsConfig.CPUProfile, suite, cliConfig, 0) + err := MergeProfiles(cpuProfiles, cpuProfile) + command.AbortIfError("Failed to combine cpuprofiles", err) + } + } + if len(memProfiles) > 0 { + if suite.HasProgrammaticFocus { + fmt.Fprintln(os.Stdout, "no mem profile was generated because specs are programmatically focused") + } else { + memProfile := AbsPathForGeneratedAsset(goFlagsConfig.MemProfile, suite, cliConfig, 0) + err := MergeProfiles(memProfiles, memProfile) + command.AbortIfError("Failed to combine memprofiles", err) + } + } + if len(mutexProfiles) > 0 { + if suite.HasProgrammaticFocus { + fmt.Fprintln(os.Stdout, "no mutex profile was generated because specs are programmatically focused") + } else { + mutexProfile := AbsPathForGeneratedAsset(goFlagsConfig.MutexProfile, suite, cliConfig, 0) + err := MergeProfiles(mutexProfiles, mutexProfile) + command.AbortIfError("Failed to combine mutexprofiles", err) + } + } + + return suite +} + +func runAfterRunHook(command string, noColor bool, suite TestSuite) { + if command == "" { + return + } + f := formatter.NewWithNoColorBool(noColor) + + // Allow for string replacement to pass input to the command + passed := "[FAIL]" + if suite.State.Is(TestSuiteStatePassed) { + 
passed = "[PASS]" + } + command = strings.ReplaceAll(command, "(ginkgo-suite-passed)", passed) + command = strings.ReplaceAll(command, "(ginkgo-suite-name)", suite.PackageName) + + // Must break command into parts + splitArgs := regexp.MustCompile(`'.+'|".+"|\S+`) + parts := splitArgs.FindAllString(command, -1) + + output, err := exec.Command(parts[0], parts[1:]...).CombinedOutput() + if err != nil { + fmt.Fprintln(formatter.ColorableStdOut, f.Fi(0, "{{red}}{{bold}}After-run-hook failed:{{/}}")) + fmt.Fprintln(formatter.ColorableStdOut, f.Fi(1, "{{red}}%s{{/}}", output)) + } else { + fmt.Fprintln(formatter.ColorableStdOut, f.Fi(0, "{{green}}{{bold}}After-run-hook succeeded:{{/}}")) + fmt.Fprintln(formatter.ColorableStdOut, f.Fi(1, "{{green}}%s{{/}}", output)) + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/test_suite.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/test_suite.go new file mode 100644 index 00000000000..64dcb1b78c6 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/test_suite.go @@ -0,0 +1,283 @@ +package internal + +import ( + "errors" + "math/rand" + "os" + "path" + "path/filepath" + "regexp" + "strings" + + "github.com/onsi/ginkgo/v2/types" +) + +const TIMEOUT_ELAPSED_FAILURE_REASON = "Suite did not run because the timeout elapsed" +const PRIOR_FAILURES_FAILURE_REASON = "Suite did not run because prior suites failed and --keep-going is not set" +const EMPTY_SKIP_FAILURE_REASON = "Suite did not run go test reported that no test files were found" + +type TestSuiteState uint + +const ( + TestSuiteStateInvalid TestSuiteState = iota + + TestSuiteStateUncompiled + TestSuiteStateCompiled + + TestSuiteStatePassed + + TestSuiteStateSkippedDueToEmptyCompilation + TestSuiteStateSkippedByFilter + TestSuiteStateSkippedDueToPriorFailures + + TestSuiteStateFailed + TestSuiteStateFailedDueToTimeout + TestSuiteStateFailedToCompile +) + +var TestSuiteStateFailureStates = []TestSuiteState{TestSuiteStateFailed, TestSuiteStateFailedDueToTimeout, TestSuiteStateFailedToCompile} + +func (state TestSuiteState) Is(states ...TestSuiteState) bool { + for _, suiteState := range states { + if suiteState == state { + return true + } + } + + return false +} + +type TestSuite struct { + Path string + PackageName string + IsGinkgo bool + + Precompiled bool + PathToCompiledTest string + CompilationError error + + HasProgrammaticFocus bool + State TestSuiteState +} + +func (ts TestSuite) AbsPath() string { + path, _ := filepath.Abs(ts.Path) + return path +} + +func (ts TestSuite) NamespacedName() string { + name := relPath(ts.Path) + name = strings.TrimLeft(name, "."+string(filepath.Separator)) + name = strings.ReplaceAll(name, string(filepath.Separator), "_") + name = strings.ReplaceAll(name, " ", "_") + if name == "" { + return ts.PackageName + } + return name +} + +type TestSuites []TestSuite + +func (ts TestSuites) AnyHaveProgrammaticFocus() bool { + for _, suite := range ts { + if suite.HasProgrammaticFocus { + return true + } + } + + return false +} + +func (ts TestSuites) ThatAreGinkgoSuites() TestSuites { + out := TestSuites{} + for _, suite := range ts { + if suite.IsGinkgo { + out = append(out, suite) + } + } + return out +} + +func (ts TestSuites) CountWithState(states ...TestSuiteState) int { + n := 0 + for _, suite := range ts { + if suite.State.Is(states...) { + n += 1 + } + } + + return n +} + +func (ts TestSuites) WithState(states ...TestSuiteState) TestSuites { + out := TestSuites{} + for _, suite := range ts { + if suite.State.Is(states...) 
{ + out = append(out, suite) + } + } + + return out +} + +func (ts TestSuites) WithoutState(states ...TestSuiteState) TestSuites { + out := TestSuites{} + for _, suite := range ts { + if !suite.State.Is(states...) { + out = append(out, suite) + } + } + + return out +} + +func (ts TestSuites) ShuffledCopy(seed int64) TestSuites { + out := make(TestSuites, len(ts)) + permutation := rand.New(rand.NewSource(seed)).Perm(len(ts)) + for i, j := range permutation { + out[i] = ts[j] + } + return out +} + +func FindSuites(args []string, cliConfig types.CLIConfig, allowPrecompiled bool) TestSuites { + suites := TestSuites{} + + if len(args) > 0 { + for _, arg := range args { + if allowPrecompiled { + suite, err := precompiledTestSuite(arg) + if err == nil { + suites = append(suites, suite) + continue + } + } + recurseForSuite := cliConfig.Recurse + if strings.HasSuffix(arg, "/...") && arg != "/..." { + arg = arg[:len(arg)-4] + recurseForSuite = true + } + suites = append(suites, suitesInDir(arg, recurseForSuite)...) + } + } else { + suites = suitesInDir(".", cliConfig.Recurse) + } + + if cliConfig.SkipPackage != "" { + skipFilters := strings.Split(cliConfig.SkipPackage, ",") + for idx := range suites { + for _, skipFilter := range skipFilters { + if strings.Contains(suites[idx].Path, skipFilter) { + suites[idx].State = TestSuiteStateSkippedByFilter + break + } + } + } + } + + return suites +} + +func precompiledTestSuite(path string) (TestSuite, error) { + info, err := os.Stat(path) + if err != nil { + return TestSuite{}, err + } + + if info.IsDir() { + return TestSuite{}, errors.New("this is a directory, not a file") + } + + if filepath.Ext(path) != ".test" && filepath.Ext(path) != ".exe" { + return TestSuite{}, errors.New("this is not a .test binary") + } + + if filepath.Ext(path) == ".test" && info.Mode()&0111 == 0 { + return TestSuite{}, errors.New("this is not executable") + } + + dir := relPath(filepath.Dir(path)) + packageName := strings.TrimSuffix(filepath.Base(path), ".exe") + packageName = strings.TrimSuffix(packageName, ".test") + + path, err = filepath.Abs(path) + if err != nil { + return TestSuite{}, err + } + + return TestSuite{ + Path: dir, + PackageName: packageName, + IsGinkgo: true, + Precompiled: true, + PathToCompiledTest: path, + State: TestSuiteStateCompiled, + }, nil +} + +func suitesInDir(dir string, recurse bool) TestSuites { + suites := TestSuites{} + + if path.Base(dir) == "vendor" { + return suites + } + + files, _ := os.ReadDir(dir) + re := regexp.MustCompile(`^[^._].*_test\.go$`) + for _, file := range files { + if !file.IsDir() && re.Match([]byte(file.Name())) { + suite := TestSuite{ + Path: relPath(dir), + PackageName: packageNameForSuite(dir), + IsGinkgo: filesHaveGinkgoSuite(dir, files), + State: TestSuiteStateUncompiled, + } + suites = append(suites, suite) + break + } + } + + if recurse { + re = regexp.MustCompile(`^[._]`) + for _, file := range files { + if file.IsDir() && !re.Match([]byte(file.Name())) { + suites = append(suites, suitesInDir(dir+"/"+file.Name(), recurse)...) + } + } + } + + return suites +} + +func relPath(dir string) string { + dir, _ = filepath.Abs(dir) + cwd, _ := os.Getwd() + dir, _ = filepath.Rel(cwd, filepath.Clean(dir)) + + if string(dir[0]) != "." { + dir = "." 
+ string(filepath.Separator) + dir + } + + return dir +} + +func packageNameForSuite(dir string) string { + path, _ := filepath.Abs(dir) + return filepath.Base(path) +} + +func filesHaveGinkgoSuite(dir string, files []os.DirEntry) bool { + reTestFile := regexp.MustCompile(`_test\.go$`) + reGinkgo := regexp.MustCompile(`package ginkgo|\/ginkgo"|\/ginkgo\/v2"|\/ginkgo\/v2/dsl/`) + + for _, file := range files { + if !file.IsDir() && reTestFile.Match([]byte(file.Name())) { + contents, _ := os.ReadFile(dir + "/" + file.Name()) + if reGinkgo.Match(contents) { + return true + } + } + } + + return false +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/utils.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/utils.go new file mode 100644 index 00000000000..bd9ca7d51e8 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/utils.go @@ -0,0 +1,86 @@ +package internal + +import ( + "fmt" + "io" + "os" + "os/exec" + + "github.com/onsi/ginkgo/v2/formatter" + "github.com/onsi/ginkgo/v2/ginkgo/command" +) + +func FileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func CopyFile(src string, dest string) error { + srcFile, err := os.Open(src) + if err != nil { + return err + } + + srcStat, err := srcFile.Stat() + if err != nil { + return err + } + + if _, err := os.Stat(dest); err == nil { + os.Remove(dest) + } + + destFile, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE, srcStat.Mode()) + if err != nil { + return err + } + + _, err = io.Copy(destFile, srcFile) + if err != nil { + return err + } + + if err := srcFile.Close(); err != nil { + return err + } + return destFile.Close() +} + +func GoFmt(path string) { + out, err := exec.Command("go", "fmt", path).CombinedOutput() + if err != nil { + command.AbortIfError(fmt.Sprintf("Could not fmt:\n%s\n", string(out)), err) + } +} + +func PluralizedWord(singular, plural string, count int) string { + if count == 1 { + return singular + } + return plural +} + +func FailedSuitesReport(suites TestSuites, f formatter.Formatter) string { + out := "" + out += "There were failures detected in the following suites:\n" + + maxPackageNameLength := 0 + for _, suite := range suites.WithState(TestSuiteStateFailureStates...) 
{ + if len(suite.PackageName) > maxPackageNameLength { + maxPackageNameLength = len(suite.PackageName) + } + } + + packageNameFormatter := fmt.Sprintf("%%%ds", maxPackageNameLength) + for _, suite := range suites { + switch suite.State { + case TestSuiteStateFailed: + out += f.Fi(1, "{{red}}"+packageNameFormatter+" {{gray}}%s{{/}}\n", suite.PackageName, suite.Path) + case TestSuiteStateFailedToCompile: + out += f.Fi(1, "{{red}}"+packageNameFormatter+" {{gray}}%s {{magenta}}[Compilation failure]{{/}}\n", suite.PackageName, suite.Path) + case TestSuiteStateFailedDueToTimeout: + out += f.Fi(1, "{{red}}"+packageNameFormatter+" {{gray}}%s {{orange}}[%s]{{/}}\n", suite.PackageName, suite.Path, TIMEOUT_ELAPSED_FAILURE_REASON) + } + } + return out +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/verify_version.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/verify_version.go new file mode 100644 index 00000000000..9da1bab3db7 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/verify_version.go @@ -0,0 +1,54 @@ +package internal + +import ( + "fmt" + "os/exec" + "regexp" + "strings" + + "github.com/onsi/ginkgo/v2/formatter" + "github.com/onsi/ginkgo/v2/types" +) + +var versiorRe = regexp.MustCompile(`v(\d+\.\d+\.\d+)`) + +func VerifyCLIAndFrameworkVersion(suites TestSuites) { + cliVersion := types.VERSION + mismatches := map[string][]string{} + + for _, suite := range suites { + cmd := exec.Command("go", "list", "-m", "github.com/onsi/ginkgo/v2") + cmd.Dir = suite.Path + output, err := cmd.CombinedOutput() + if err != nil { + continue + } + components := strings.Split(string(output), " ") + if len(components) != 2 { + continue + } + matches := versiorRe.FindStringSubmatch(components[1]) + if matches == nil || len(matches) != 2 { + continue + } + libraryVersion := matches[1] + if cliVersion != libraryVersion { + mismatches[libraryVersion] = append(mismatches[libraryVersion], suite.PackageName) + } + } + + if len(mismatches) == 0 { + return + } + + fmt.Println(formatter.F("{{red}}{{bold}}Ginkgo detected a version mismatch between the Ginkgo CLI and the version of Ginkgo imported by your packages:{{/}}")) + + fmt.Println(formatter.Fi(1, "Ginkgo CLI Version:")) + fmt.Println(formatter.Fi(2, "{{bold}}%s{{/}}", cliVersion)) + fmt.Println(formatter.Fi(1, "Mismatched package versions found:")) + for version, packages := range mismatches { + fmt.Println(formatter.Fi(2, "{{bold}}%s{{/}} used by %s", version, strings.Join(packages, ", "))) + } + fmt.Println("") + fmt.Println(formatter.Fiw(1, formatter.COLS, "{{gray}}Ginkgo will continue to attempt to run but you may see errors (including flag parsing errors) and should either update your go.mod or your version of the Ginkgo CLI to match.\n\nTo install the matching version of the CLI run\n {{bold}}go install github.com/onsi/ginkgo/v2/ginkgo{{/}}{{gray}}\nfrom a path that contains a go.mod file. 
Alternatively you can use\n {{bold}}go run github.com/onsi/ginkgo/v2/ginkgo{{/}}{{gray}}\nfrom a path that contains a go.mod file to invoke the matching version of the Ginkgo CLI.\n\nIf you are attempting to test multiple packages that each have a different version of the Ginkgo library with a single Ginkgo CLI that is currently unsupported.\n{{/}}")) +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/labels/labels_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/labels/labels_command.go new file mode 100644 index 00000000000..6c61f09d1b5 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/labels/labels_command.go @@ -0,0 +1,123 @@ +package labels + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "sort" + "strconv" + "strings" + + "github.com/onsi/ginkgo/v2/ginkgo/command" + "github.com/onsi/ginkgo/v2/ginkgo/internal" + "github.com/onsi/ginkgo/v2/types" + "golang.org/x/tools/go/ast/inspector" +) + +func BuildLabelsCommand() command.Command { + var cliConfig = types.NewDefaultCLIConfig() + + flags, err := types.BuildLabelsCommandFlagSet(&cliConfig) + if err != nil { + panic(err) + } + + return command.Command{ + Name: "labels", + Usage: "ginkgo labels ", + Flags: flags, + ShortDoc: "List labels detected in the passed-in packages (or the package in the current directory if left blank).", + DocLink: "spec-labels", + Command: func(args []string, _ []string) { + ListLabels(args, cliConfig) + }, + } +} + +func ListLabels(args []string, cliConfig types.CLIConfig) { + suites := internal.FindSuites(args, cliConfig, false).WithoutState(internal.TestSuiteStateSkippedByFilter) + if len(suites) == 0 { + command.AbortWith("Found no test suites") + } + for _, suite := range suites { + labels := fetchLabelsFromPackage(suite.Path) + if len(labels) == 0 { + fmt.Printf("%s: No labels found\n", suite.PackageName) + } else { + fmt.Printf("%s: [%s]\n", suite.PackageName, strings.Join(labels, ", ")) + } + } +} + +func fetchLabelsFromPackage(packagePath string) []string { + fset := token.NewFileSet() + parsedPackages, err := parser.ParseDir(fset, packagePath, nil, 0) + command.AbortIfError("Failed to parse package source:", err) + + files := []*ast.File{} + hasTestPackage := false + for key, pkg := range parsedPackages { + if strings.HasSuffix(key, "_test") { + hasTestPackage = true + for _, file := range pkg.Files { + files = append(files, file) + } + } + } + if !hasTestPackage { + for _, pkg := range parsedPackages { + for _, file := range pkg.Files { + files = append(files, file) + } + } + } + + seen := map[string]bool{} + labels := []string{} + ispr := inspector.New(files) + ispr.Preorder([]ast.Node{&ast.CallExpr{}}, func(n ast.Node) { + potentialLabels := fetchLabels(n.(*ast.CallExpr)) + for _, label := range potentialLabels { + if !seen[label] { + seen[label] = true + labels = append(labels, strconv.Quote(label)) + } + } + }) + + sort.Strings(labels) + return labels +} + +func fetchLabels(callExpr *ast.CallExpr) []string { + out := []string{} + switch expr := callExpr.Fun.(type) { + case *ast.Ident: + if expr.Name != "Label" { + return out + } + case *ast.SelectorExpr: + if expr.Sel.Name != "Label" { + return out + } + default: + return out + } + for _, arg := range callExpr.Args { + switch expr := arg.(type) { + case *ast.BasicLit: + if expr.Kind == token.STRING { + unquoted, err := strconv.Unquote(expr.Value) + if err != nil { + unquoted = expr.Value + } + validated, err := types.ValidateAndCleanupLabel(unquoted, types.CodeLocation{}) + if err == nil { + out = append(out, 
validated) + } + } + } + } + return out +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/main.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/main.go new file mode 100644 index 00000000000..e9abb27d8bf --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "fmt" + "os" + + "github.com/onsi/ginkgo/v2/ginkgo/build" + "github.com/onsi/ginkgo/v2/ginkgo/command" + "github.com/onsi/ginkgo/v2/ginkgo/generators" + "github.com/onsi/ginkgo/v2/ginkgo/labels" + "github.com/onsi/ginkgo/v2/ginkgo/outline" + "github.com/onsi/ginkgo/v2/ginkgo/run" + "github.com/onsi/ginkgo/v2/ginkgo/unfocus" + "github.com/onsi/ginkgo/v2/ginkgo/watch" + "github.com/onsi/ginkgo/v2/types" +) + +var program command.Program + +func GenerateCommands() []command.Command { + return []command.Command{ + watch.BuildWatchCommand(), + build.BuildBuildCommand(), + generators.BuildBootstrapCommand(), + generators.BuildGenerateCommand(), + labels.BuildLabelsCommand(), + outline.BuildOutlineCommand(), + unfocus.BuildUnfocusCommand(), + BuildVersionCommand(), + } +} + +func main() { + program = command.Program{ + Name: "ginkgo", + Heading: fmt.Sprintf("Ginkgo Version %s", types.VERSION), + Commands: GenerateCommands(), + DefaultCommand: run.BuildRunCommand(), + DeprecatedCommands: []command.DeprecatedCommand{ + {Name: "convert", Deprecation: types.Deprecations.Convert()}, + {Name: "blur", Deprecation: types.Deprecations.Blur()}, + {Name: "nodot", Deprecation: types.Deprecations.Nodot()}, + }, + } + + program.RunAndExit(os.Args) +} + +func BuildVersionCommand() command.Command { + return command.Command{ + Name: "version", + Usage: "ginkgo version", + ShortDoc: "Print Ginkgo's version", + Command: func(_ []string, _ []string) { + fmt.Printf("Ginkgo Version %s\n", types.VERSION) + }, + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/outline/ginkgo.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/ginkgo.go similarity index 82% rename from vendor/github.com/onsi/ginkgo/ginkgo/outline/ginkgo.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/ginkgo.go index ce6b7fcd767..c197bb6862f 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/outline/ginkgo.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/ginkgo.go @@ -131,7 +131,7 @@ func absoluteOffsetsForNode(fset *token.FileSet, n ast.Node) (start, end int) { // ginkgoNodeFromCallExpr derives an outline entry from a go AST subtree // corresponding to a Ginkgo container or spec. 
-func ginkgoNodeFromCallExpr(fset *token.FileSet, ce *ast.CallExpr, ginkgoPackageName, tablePackageName *string) (*ginkgoNode, bool) { +func ginkgoNodeFromCallExpr(fset *token.FileSet, ce *ast.CallExpr, ginkgoPackageName *string) (*ginkgoNode, bool) { packageName, identName, ok := packageAndIdentNamesFromCallExpr(ce) if !ok { return nil, false @@ -142,56 +142,31 @@ func ginkgoNodeFromCallExpr(fset *token.FileSet, ce *ast.CallExpr, ginkgoPackage n.Start, n.End = absoluteOffsetsForNode(fset, ce) n.Nodes = make([]*ginkgoNode, 0) switch identName { - case "It", "Measure", "Specify": + case "It", "Specify", "Entry": n.Spec = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName - case "Entry": - n.Spec = true - n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) - return &n, tablePackageName != nil && *tablePackageName == packageName - case "FIt", "FMeasure", "FSpecify": + case "FIt", "FSpecify", "FEntry": n.Spec = true n.Focused = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName - case "FEntry": - n.Spec = true - n.Focused = true - n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) - return &n, tablePackageName != nil && *tablePackageName == packageName - case "PIt", "PMeasure", "PSpecify", "XIt", "XMeasure", "XSpecify": + case "PIt", "PSpecify", "XIt", "XSpecify", "PEntry", "XEntry": n.Spec = true n.Pending = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName - case "PEntry", "XEntry": - n.Spec = true - n.Pending = true - n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) - return &n, tablePackageName != nil && *tablePackageName == packageName - case "Context", "Describe", "When": + case "Context", "Describe", "When", "DescribeTable": n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName - case "DescribeTable": - n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) - return &n, tablePackageName != nil && *tablePackageName == packageName - case "FContext", "FDescribe", "FWhen": + case "FContext", "FDescribe", "FWhen", "FDescribeTable": n.Focused = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName - case "FDescribeTable": - n.Focused = true - n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) - return &n, tablePackageName != nil && *tablePackageName == packageName - case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen": + case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen", "PDescribeTable", "XDescribeTable": n.Pending = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName - case "PDescribeTable", "XDescribeTable": - n.Pending = true - n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) - return &n, tablePackageName != nil && *tablePackageName == packageName case "By": n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/outline/import.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/import.go similarity index 97% rename from vendor/github.com/onsi/ginkgo/ginkgo/outline/import.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/import.go index 
4328ab39103..67ec5ab7579 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/outline/import.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/import.go @@ -47,7 +47,7 @@ func packageNameForImport(f *ast.File, path string) *string { // or nil otherwise. func importSpec(f *ast.File, path string) *ast.ImportSpec { for _, s := range f.Imports { - if importPath(s) == path { + if strings.HasPrefix(importPath(s), path) { return s } } diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/outline/outline.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline.go similarity index 85% rename from vendor/github.com/onsi/ginkgo/ginkgo/outline/outline.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline.go index 242e6a1091b..4b45e762748 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/outline/outline.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline.go @@ -12,18 +12,14 @@ import ( const ( // ginkgoImportPath is the well-known ginkgo import path - ginkgoImportPath = "github.com/onsi/ginkgo" - - // tableImportPath is the well-known table extension import path - tableImportPath = "github.com/onsi/ginkgo/extensions/table" + ginkgoImportPath = "github.com/onsi/ginkgo/v2" ) // FromASTFile returns an outline for a Ginkgo test source file func FromASTFile(fset *token.FileSet, src *ast.File) (*outline, error) { ginkgoPackageName := packageNameForImport(src, ginkgoImportPath) - tablePackageName := packageNameForImport(src, tableImportPath) - if ginkgoPackageName == nil && tablePackageName == nil { - return nil, fmt.Errorf("file does not import %q or %q", ginkgoImportPath, tableImportPath) + if ginkgoPackageName == nil { + return nil, fmt.Errorf("file does not import %q", ginkgoImportPath) } root := ginkgoNode{} @@ -38,7 +34,7 @@ func FromASTFile(fset *token.FileSet, src *ast.File) (*outline, error) { // ast.CallExpr, this should never happen panic(fmt.Errorf("node starting at %d, ending at %d is not an *ast.CallExpr", node.Pos(), node.End())) } - gn, ok := ginkgoNodeFromCallExpr(fset, ce, ginkgoPackageName, tablePackageName) + gn, ok := ginkgoNodeFromCallExpr(fset, ce, ginkgoPackageName) if !ok { // Node is not a Ginkgo spec or container, continue return true diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline_command.go new file mode 100644 index 00000000000..36698d46a4b --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline_command.go @@ -0,0 +1,98 @@ +package outline + +import ( + "encoding/json" + "fmt" + "go/parser" + "go/token" + "os" + + "github.com/onsi/ginkgo/v2/ginkgo/command" + "github.com/onsi/ginkgo/v2/types" +) + +const ( + // indentWidth is the width used by the 'indent' output + indentWidth = 4 + // stdinAlias is a portable alias for stdin. This convention is used in + // other CLIs, e.g., kubectl. 
+ stdinAlias = "-" + usageCommand = "ginkgo outline " +) + +type outlineConfig struct { + Format string +} + +func BuildOutlineCommand() command.Command { + conf := outlineConfig{ + Format: "csv", + } + flags, err := types.NewGinkgoFlagSet( + types.GinkgoFlags{ + {Name: "format", KeyPath: "Format", + Usage: "Format of outline", + UsageArgument: "one of 'csv', 'indent', or 'json'", + UsageDefaultValue: conf.Format, + }, + }, + &conf, + types.GinkgoFlagSections{}, + ) + if err != nil { + panic(err) + } + + return command.Command{ + Name: "outline", + Usage: "ginkgo outline ", + ShortDoc: "Create an outline of Ginkgo symbols for a file", + Documentation: "To read from stdin, use: `ginkgo outline -`", + DocLink: "creating-an-outline-of-specs", + Flags: flags, + Command: func(args []string, _ []string) { + outlineFile(args, conf.Format) + }, + } +} + +func outlineFile(args []string, format string) { + if len(args) != 1 { + command.AbortWithUsage("outline expects exactly one argument") + } + + filename := args[0] + var src *os.File + if filename == stdinAlias { + src = os.Stdin + } else { + var err error + src, err = os.Open(filename) + command.AbortIfError("Failed to open file:", err) + } + + fset := token.NewFileSet() + + parsedSrc, err := parser.ParseFile(fset, filename, src, 0) + command.AbortIfError("Failed to parse source:", err) + + o, err := FromASTFile(fset, parsedSrc) + command.AbortIfError("Failed to create outline:", err) + + var oerr error + switch format { + case "csv": + _, oerr = fmt.Print(o) + case "indent": + _, oerr = fmt.Print(o.StringIndent(indentWidth)) + case "json": + b, err := json.Marshal(o) + if err != nil { + println(fmt.Sprintf("error marshalling to json: %s", err)) + } + _, oerr = fmt.Println(string(b)) + default: + command.AbortWith("Format %s not accepted", format) + } + command.AbortIfError("Failed to write outline:", oerr) +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go new file mode 100644 index 00000000000..aaed4d570e9 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go @@ -0,0 +1,232 @@ +package run + +import ( + "fmt" + "os" + "strings" + "time" + + "github.com/onsi/ginkgo/v2/formatter" + "github.com/onsi/ginkgo/v2/ginkgo/command" + "github.com/onsi/ginkgo/v2/ginkgo/internal" + "github.com/onsi/ginkgo/v2/internal/interrupt_handler" + "github.com/onsi/ginkgo/v2/types" +) + +func BuildRunCommand() command.Command { + var suiteConfig = types.NewDefaultSuiteConfig() + var reporterConfig = types.NewDefaultReporterConfig() + var cliConfig = types.NewDefaultCLIConfig() + var goFlagsConfig = types.NewDefaultGoFlagsConfig() + + flags, err := types.BuildRunCommandFlagSet(&suiteConfig, &reporterConfig, &cliConfig, &goFlagsConfig) + if err != nil { + panic(err) + } + + interruptHandler := interrupt_handler.NewInterruptHandler(nil) + interrupt_handler.SwallowSigQuit() + + return command.Command{ + Name: "run", + Flags: flags, + Usage: "ginkgo run -- ", + ShortDoc: "Run the tests in the passed in (or the package in the current directory if left blank)", + Documentation: "Any arguments after -- will be passed to the test.", + DocLink: "running-tests", + Command: func(args []string, additionalArgs []string) { + var errors []error + cliConfig, goFlagsConfig, errors = types.VetAndInitializeCLIAndGoConfig(cliConfig, goFlagsConfig) + command.AbortIfErrors("Ginkgo detected configuration issues:", errors) + + runner := &SpecRunner{ + cliConfig: cliConfig, + 
goFlagsConfig: goFlagsConfig, + suiteConfig: suiteConfig, + reporterConfig: reporterConfig, + flags: flags, + + interruptHandler: interruptHandler, + } + + runner.RunSpecs(args, additionalArgs) + }, + } +} + +type SpecRunner struct { + suiteConfig types.SuiteConfig + reporterConfig types.ReporterConfig + cliConfig types.CLIConfig + goFlagsConfig types.GoFlagsConfig + flags types.GinkgoFlagSet + + interruptHandler *interrupt_handler.InterruptHandler +} + +func (r *SpecRunner) RunSpecs(args []string, additionalArgs []string) { + suites := internal.FindSuites(args, r.cliConfig, true) + skippedSuites := suites.WithState(internal.TestSuiteStateSkippedByFilter) + suites = suites.WithoutState(internal.TestSuiteStateSkippedByFilter) + + internal.VerifyCLIAndFrameworkVersion(suites) + + if len(skippedSuites) > 0 { + fmt.Println("Will skip:") + for _, skippedSuite := range skippedSuites { + fmt.Println(" " + skippedSuite.Path) + } + } + + if len(skippedSuites) > 0 && len(suites) == 0 { + command.AbortGracefullyWith("All tests skipped! Exiting...") + } + + if len(suites) == 0 { + command.AbortWith("Found no test suites") + } + + if len(suites) > 1 && !r.flags.WasSet("succinct") && r.reporterConfig.Verbosity().LT(types.VerbosityLevelVerbose) { + r.reporterConfig.Succinct = true + } + + t := time.Now() + var endTime time.Time + if r.suiteConfig.Timeout > 0 { + endTime = t.Add(r.suiteConfig.Timeout) + } + + iteration := 0 +OUTER_LOOP: + for { + if !r.flags.WasSet("seed") { + r.suiteConfig.RandomSeed = time.Now().Unix() + } + if r.cliConfig.RandomizeSuites && len(suites) > 1 { + suites = suites.ShuffledCopy(r.suiteConfig.RandomSeed) + } + + opc := internal.NewOrderedParallelCompiler(r.cliConfig.ComputedNumCompilers()) + opc.StartCompiling(suites, r.goFlagsConfig) + + SUITE_LOOP: + for { + suiteIdx, suite := opc.Next() + if suiteIdx >= len(suites) { + break SUITE_LOOP + } + suites[suiteIdx] = suite + + if r.interruptHandler.Status().Interrupted() { + opc.StopAndDrain() + break OUTER_LOOP + } + + if suites[suiteIdx].State.Is(internal.TestSuiteStateSkippedDueToEmptyCompilation) { + fmt.Printf("Skipping %s (no test files)\n", suite.Path) + continue SUITE_LOOP + } + + if suites[suiteIdx].State.Is(internal.TestSuiteStateFailedToCompile) { + fmt.Println(suites[suiteIdx].CompilationError.Error()) + if !r.cliConfig.KeepGoing { + opc.StopAndDrain() + } + continue SUITE_LOOP + } + + if suites.CountWithState(internal.TestSuiteStateFailureStates...) > 0 && !r.cliConfig.KeepGoing { + suites[suiteIdx].State = internal.TestSuiteStateSkippedDueToPriorFailures + opc.StopAndDrain() + continue SUITE_LOOP + } + + if !endTime.IsZero() { + r.suiteConfig.Timeout = endTime.Sub(time.Now()) + if r.suiteConfig.Timeout <= 0 { + suites[suiteIdx].State = internal.TestSuiteStateFailedDueToTimeout + opc.StopAndDrain() + continue SUITE_LOOP + } + } + + suites[suiteIdx] = internal.RunCompiledSuite(suites[suiteIdx], r.suiteConfig, r.reporterConfig, r.cliConfig, r.goFlagsConfig, additionalArgs) + } + + if suites.CountWithState(internal.TestSuiteStateFailureStates...) 
> 0 { + if iteration > 0 { + fmt.Printf("\nTests failed on attempt #%d\n\n", iteration+1) + } + break OUTER_LOOP + } + + if r.cliConfig.UntilItFails { + fmt.Printf("\nAll tests passed...\nWill keep running them until they fail.\nThis was attempt #%d\n%s\n", iteration+1, orcMessage(iteration+1)) + } else if r.cliConfig.Repeat > 0 && iteration < r.cliConfig.Repeat { + fmt.Printf("\nAll tests passed...\nThis was attempt %d of %d.\n", iteration+1, r.cliConfig.Repeat+1) + } else { + break OUTER_LOOP + } + iteration += 1 + } + + internal.Cleanup(r.goFlagsConfig, suites...) + + messages, err := internal.FinalizeProfilesAndReportsForSuites(suites, r.cliConfig, r.suiteConfig, r.reporterConfig, r.goFlagsConfig) + command.AbortIfError("could not finalize profiles:", err) + for _, message := range messages { + fmt.Println(message) + } + + fmt.Printf("\nGinkgo ran %d %s in %s\n", len(suites), internal.PluralizedWord("suite", "suites", len(suites)), time.Since(t)) + + if suites.CountWithState(internal.TestSuiteStateFailureStates...) == 0 { + if suites.AnyHaveProgrammaticFocus() && strings.TrimSpace(os.Getenv("GINKGO_EDITOR_INTEGRATION")) == "" { + fmt.Printf("Test Suite Passed\n") + fmt.Printf("Detected Programmatic Focus - setting exit status to %d\n", types.GINKGO_FOCUS_EXIT_CODE) + command.Abort(command.AbortDetails{ExitCode: types.GINKGO_FOCUS_EXIT_CODE}) + } else { + fmt.Printf("Test Suite Passed\n") + command.Abort(command.AbortDetails{}) + } + } else { + fmt.Fprintln(formatter.ColorableStdOut, "") + if len(suites) > 1 && suites.CountWithState(internal.TestSuiteStateFailureStates...) > 0 { + fmt.Fprintln(formatter.ColorableStdOut, + internal.FailedSuitesReport(suites, formatter.NewWithNoColorBool(r.reporterConfig.NoColor))) + } + fmt.Printf("Test Suite Failed\n") + command.Abort(command.AbortDetails{ExitCode: 1}) + } +} + +func orcMessage(iteration int) string { + if iteration < 10 { + return "" + } else if iteration < 30 { + return []string{ + "If at first you succeed...", + "...try, try again.", + "Looking good!", + "Still good...", + "I think your tests are fine....", + "Yep, still passing", + "Oh boy, here I go testin' again!", + "Even the gophers are getting bored", + "Did you try -race?", + "Maybe you should stop now?", + "I'm getting tired...", + "What if I just made you a sandwich?", + "Hit ^C, hit ^C, please hit ^C", + "Make it stop. Please!", + "Come on! Enough is enough!", + "Dave, this conversation can serve no purpose anymore. Goodbye.", + "Just what do you think you're doing, Dave? ", + "I, Sisyphus", + "Insanity: doing the same thing over and over again and expecting different results. -Einstein", + "I guess Einstein never tried to churn butter", + }[iteration-10] + "\n" + } else { + return "No, seriously... 
you can probably stop now.\n" + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/unfocus_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/unfocus/unfocus_command.go similarity index 67% rename from vendor/github.com/onsi/ginkgo/ginkgo/unfocus_command.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/unfocus/unfocus_command.go index d9dfb6e44c1..7dd2943948a 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/unfocus_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/unfocus/unfocus_command.go @@ -1,34 +1,33 @@ -package main +package unfocus import ( "bytes" - "flag" "fmt" "go/ast" "go/parser" "go/token" "io" - "io/ioutil" "os" "path/filepath" "strings" "sync" + + "github.com/onsi/ginkgo/v2/ginkgo/command" ) -func BuildUnfocusCommand() *Command { - return &Command{ - Name: "unfocus", - AltName: "blur", - FlagSet: flag.NewFlagSet("unfocus", flag.ExitOnError), - UsageCommand: "ginkgo unfocus (or ginkgo blur)", - Usage: []string{ - "Recursively unfocuses any focused tests under the current directory", +func BuildUnfocusCommand() command.Command { + return command.Command{ + Name: "unfocus", + Usage: "ginkgo unfocus", + ShortDoc: "Recursively unfocus any focused tests under the current directory", + DocLink: "filtering-specs", + Command: func(_ []string, _ []string) { + unfocusSpecs() }, - Command: unfocusSpecs, } } -func unfocusSpecs([]string, []string) { +func unfocusSpecs() { fmt.Println("Scanning for focus...") goFiles := make(chan string) @@ -54,7 +53,7 @@ func unfocusSpecs([]string, []string) { } func unfocusDir(goFiles chan string, path string) { - files, err := ioutil.ReadDir(path) + files, err := os.ReadDir(path) if err != nil { fmt.Println(err.Error()) return @@ -79,13 +78,13 @@ func shouldProcessFile(basename string) bool { } func unfocusFile(path string) { - data, err := ioutil.ReadFile(path) + data, err := os.ReadFile(path) if err != nil { fmt.Printf("error reading file '%s': %s\n", path, err.Error()) return } - ast, err := parser.ParseFile(token.NewFileSet(), path, bytes.NewReader(data), 0) + ast, err := parser.ParseFile(token.NewFileSet(), path, bytes.NewReader(data), parser.ParseComments) if err != nil { fmt.Printf("error parsing file '%s': %s\n", path, err.Error()) return @@ -112,7 +111,7 @@ func unfocusFile(path string) { } func writeBackup(path string, data []byte) (string, error) { - t, err := ioutil.TempFile(filepath.Dir(path), filepath.Base(path)) + t, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)) if err != nil { return "", fmt.Errorf("error creating temporary file: %w", err) @@ -126,7 +125,7 @@ func writeBackup(path string, data []byte) (string, error) { return t.Name(), nil } -func updateFile(path string, data []byte, eliminations []int64) error { +func updateFile(path string, data []byte, eliminations [][]int64) error { to, err := os.Create(path) if err != nil { return fmt.Errorf("error opening file for writing '%s': %w\n", path, err) @@ -135,14 +134,15 @@ func updateFile(path string, data []byte, eliminations []int64) error { from := bytes.NewReader(data) var cursor int64 - for _, byteToEliminate := range eliminations { - if _, err := io.CopyN(to, from, byteToEliminate-cursor); err != nil { + for _, eliminationRange := range eliminations { + positionToEliminate, lengthToEliminate := eliminationRange[0]-1, eliminationRange[1] + if _, err := io.CopyN(to, from, positionToEliminate-cursor); err != nil { return fmt.Errorf("error copying data: %w", err) } - cursor = byteToEliminate + 1 + cursor = positionToEliminate + lengthToEliminate - if _, 
err := from.Seek(1, io.SeekCurrent); err != nil { + if _, err := from.Seek(lengthToEliminate, io.SeekCurrent); err != nil { return fmt.Errorf("error seeking to position in buffer: %w", err) } } @@ -154,16 +154,22 @@ func updateFile(path string, data []byte, eliminations []int64) error { return nil } -func scanForFocus(file *ast.File) (eliminations []int64) { +func scanForFocus(file *ast.File) (eliminations [][]int64) { ast.Inspect(file, func(n ast.Node) bool { if c, ok := n.(*ast.CallExpr); ok { if i, ok := c.Fun.(*ast.Ident); ok { if isFocus(i.Name) { - eliminations = append(eliminations, int64(i.Pos()-file.Pos())) + eliminations = append(eliminations, []int64{int64(i.Pos()), 1}) } } } + if i, ok := n.(*ast.Ident); ok { + if i.Name == "Focus" { + eliminations = append(eliminations, []int64{int64(i.Pos()), 6}) + } + } + return true }) @@ -172,7 +178,7 @@ func scanForFocus(file *ast.File) (eliminations []int64) { func isFocus(name string) bool { switch name { - case "FDescribe", "FContext", "FIt", "FMeasure", "FDescribeTable", "FEntry", "FSpecify", "FWhen": + case "FDescribe", "FContext", "FIt", "FDescribeTable", "FEntry", "FSpecify", "FWhen": return true default: return false diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/delta.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/delta.go similarity index 100% rename from vendor/github.com/onsi/ginkgo/ginkgo/watch/delta.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/delta.go diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/delta_tracker.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/delta_tracker.go similarity index 85% rename from vendor/github.com/onsi/ginkgo/ginkgo/watch/delta_tracker.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/delta_tracker.go index a628303d729..26418ac62ed 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/watch/delta_tracker.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/delta_tracker.go @@ -5,10 +5,10 @@ import ( "regexp" - "github.com/onsi/ginkgo/ginkgo/testsuite" + "github.com/onsi/ginkgo/v2/ginkgo/internal" ) -type SuiteErrors map[testsuite.TestSuite]error +type SuiteErrors map[internal.TestSuite]error type DeltaTracker struct { maxDepth int @@ -26,7 +26,7 @@ func NewDeltaTracker(maxDepth int, watchRegExp *regexp.Regexp) *DeltaTracker { } } -func (d *DeltaTracker) Delta(suites []testsuite.TestSuite) (delta Delta, errors SuiteErrors) { +func (d *DeltaTracker) Delta(suites internal.TestSuites) (delta Delta, errors SuiteErrors) { errors = SuiteErrors{} delta.ModifiedPackages = d.packageHashes.CheckForChanges() @@ -65,7 +65,7 @@ func (d *DeltaTracker) Delta(suites []testsuite.TestSuite) (delta Delta, errors return delta, errors } -func (d *DeltaTracker) WillRun(suite testsuite.TestSuite) error { +func (d *DeltaTracker) WillRun(suite internal.TestSuite) error { s, ok := d.suites[suite.Path] if !ok { return fmt.Errorf("unknown suite %s", suite.Path) diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/dependencies.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/dependencies.go similarity index 100% rename from vendor/github.com/onsi/ginkgo/ginkgo/watch/dependencies.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/dependencies.go diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hash.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/package_hash.go similarity index 92% rename from vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hash.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/package_hash.go index 
67e2c1c329d..e9f7ec0cb3b 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hash.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/package_hash.go @@ -2,7 +2,6 @@ package watch import ( "fmt" - "io/ioutil" "os" "regexp" "time" @@ -63,15 +62,20 @@ func (p *PackageHash) CheckForChanges() bool { } func (p *PackageHash) computeHashes() (codeHash string, codeModifiedTime time.Time, testHash string, testModifiedTime time.Time, deleted bool) { - infos, err := ioutil.ReadDir(p.path) + entries, err := os.ReadDir(p.path) if err != nil { deleted = true return } - for _, info := range infos { - if info.IsDir() { + for _, entry := range entries { + if entry.IsDir() { + continue + } + + info, err := entry.Info() + if err != nil { continue } diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hashes.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/package_hashes.go similarity index 100% rename from vendor/github.com/onsi/ginkgo/ginkgo/watch/package_hashes.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/package_hashes.go diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/watch/suite.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/suite.go similarity index 90% rename from vendor/github.com/onsi/ginkgo/ginkgo/watch/suite.go rename to vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/suite.go index 5deaba7cb6d..53272df7e5a 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/watch/suite.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/suite.go @@ -5,18 +5,18 @@ import ( "math" "time" - "github.com/onsi/ginkgo/ginkgo/testsuite" + "github.com/onsi/ginkgo/v2/ginkgo/internal" ) type Suite struct { - Suite testsuite.TestSuite + Suite internal.TestSuite RunTime time.Time Dependencies Dependencies sharedPackageHashes *PackageHashes } -func NewSuite(suite testsuite.TestSuite, maxDepth int, sharedPackageHashes *PackageHashes) (*Suite, error) { +func NewSuite(suite internal.TestSuite, maxDepth int, sharedPackageHashes *PackageHashes) (*Suite, error) { deps, err := NewDependencies(suite.Path, maxDepth) if err != nil { return nil, err diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go new file mode 100644 index 00000000000..bde4193ce71 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go @@ -0,0 +1,192 @@ +package watch + +import ( + "fmt" + "regexp" + "time" + + "github.com/onsi/ginkgo/v2/formatter" + "github.com/onsi/ginkgo/v2/ginkgo/command" + "github.com/onsi/ginkgo/v2/ginkgo/internal" + "github.com/onsi/ginkgo/v2/internal/interrupt_handler" + "github.com/onsi/ginkgo/v2/types" +) + +func BuildWatchCommand() command.Command { + var suiteConfig = types.NewDefaultSuiteConfig() + var reporterConfig = types.NewDefaultReporterConfig() + var cliConfig = types.NewDefaultCLIConfig() + var goFlagsConfig = types.NewDefaultGoFlagsConfig() + + flags, err := types.BuildWatchCommandFlagSet(&suiteConfig, &reporterConfig, &cliConfig, &goFlagsConfig) + if err != nil { + panic(err) + } + interruptHandler := interrupt_handler.NewInterruptHandler(nil) + interrupt_handler.SwallowSigQuit() + + return command.Command{ + Name: "watch", + Flags: flags, + Usage: "ginkgo watch -- ", + ShortDoc: "Watch the passed in and runs their tests whenever changes occur.", + Documentation: "Any arguments after -- will be passed to the test.", + DocLink: "watching-for-changes", + Command: func(args []string, additionalArgs []string) { + var errors []error + cliConfig, goFlagsConfig, errors = 
types.VetAndInitializeCLIAndGoConfig(cliConfig, goFlagsConfig) + command.AbortIfErrors("Ginkgo detected configuration issues:", errors) + + watcher := &SpecWatcher{ + cliConfig: cliConfig, + goFlagsConfig: goFlagsConfig, + suiteConfig: suiteConfig, + reporterConfig: reporterConfig, + flags: flags, + + interruptHandler: interruptHandler, + } + + watcher.WatchSpecs(args, additionalArgs) + }, + } +} + +type SpecWatcher struct { + suiteConfig types.SuiteConfig + reporterConfig types.ReporterConfig + cliConfig types.CLIConfig + goFlagsConfig types.GoFlagsConfig + flags types.GinkgoFlagSet + + interruptHandler *interrupt_handler.InterruptHandler +} + +func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) { + suites := internal.FindSuites(args, w.cliConfig, false).WithoutState(internal.TestSuiteStateSkippedByFilter) + + internal.VerifyCLIAndFrameworkVersion(suites) + + if len(suites) == 0 { + command.AbortWith("Found no test suites") + } + + fmt.Printf("Identified %d test %s. Locating dependencies to a depth of %d (this may take a while)...\n", len(suites), internal.PluralizedWord("suite", "suites", len(suites)), w.cliConfig.Depth) + deltaTracker := NewDeltaTracker(w.cliConfig.Depth, regexp.MustCompile(w.cliConfig.WatchRegExp)) + delta, errors := deltaTracker.Delta(suites) + + fmt.Printf("Watching %d %s:\n", len(delta.NewSuites), internal.PluralizedWord("suite", "suites", len(delta.NewSuites))) + for _, suite := range delta.NewSuites { + fmt.Println(" " + suite.Description()) + } + + for suite, err := range errors { + fmt.Printf("Failed to watch %s: %s\n", suite.PackageName, err) + } + + if len(suites) == 1 { + w.updateSeed() + w.compileAndRun(suites[0], additionalArgs) + } + + ticker := time.NewTicker(time.Second) + + for { + select { + case <-ticker.C: + suites := internal.FindSuites(args, w.cliConfig, false).WithoutState(internal.TestSuiteStateSkippedByFilter) + delta, _ := deltaTracker.Delta(suites) + coloredStream := formatter.ColorableStdOut + + suites = internal.TestSuites{} + + if len(delta.NewSuites) > 0 { + fmt.Fprintln(coloredStream, formatter.F("{{green}}Detected %d new %s:{{/}}", len(delta.NewSuites), internal.PluralizedWord("suite", "suites", len(delta.NewSuites)))) + for _, suite := range delta.NewSuites { + suites = append(suites, suite.Suite) + fmt.Fprintln(coloredStream, formatter.Fi(1, "%s", suite.Description())) + } + } + + modifiedSuites := delta.ModifiedSuites() + if len(modifiedSuites) > 0 { + fmt.Fprintln(coloredStream, formatter.F("{{green}}Detected changes in:{{/}}")) + for _, pkg := range delta.ModifiedPackages { + fmt.Fprintln(coloredStream, formatter.Fi(1, "%s", pkg)) + } + fmt.Fprintln(coloredStream, formatter.F("{{green}}Will run %d %s:{{/}}", len(modifiedSuites), internal.PluralizedWord("suite", "suites", len(modifiedSuites)))) + for _, suite := range modifiedSuites { + suites = append(suites, suite.Suite) + fmt.Fprintln(coloredStream, formatter.Fi(1, "%s", suite.Description())) + } + fmt.Fprintln(coloredStream, "") + } + + if len(suites) == 0 { + break + } + + w.updateSeed() + w.computeSuccinctMode(len(suites)) + for idx := range suites { + if w.interruptHandler.Status().Interrupted() { + return + } + deltaTracker.WillRun(suites[idx]) + suites[idx] = w.compileAndRun(suites[idx], additionalArgs) + } + color := "{{green}}" + if suites.CountWithState(internal.TestSuiteStateFailureStates...) > 0 { + color = "{{red}}" + } + fmt.Fprintln(coloredStream, formatter.F(color+"\nDone. 
Resuming watch...{{/}}")) + + messages, err := internal.FinalizeProfilesAndReportsForSuites(suites, w.cliConfig, w.suiteConfig, w.reporterConfig, w.goFlagsConfig) + command.AbortIfError("could not finalize profiles:", err) + for _, message := range messages { + fmt.Println(message) + } + case <-w.interruptHandler.Status().Channel: + return + } + } +} + +func (w *SpecWatcher) compileAndRun(suite internal.TestSuite, additionalArgs []string) internal.TestSuite { + suite = internal.CompileSuite(suite, w.goFlagsConfig) + if suite.State.Is(internal.TestSuiteStateFailedToCompile) { + fmt.Println(suite.CompilationError.Error()) + return suite + } + if w.interruptHandler.Status().Interrupted() { + return suite + } + suite = internal.RunCompiledSuite(suite, w.suiteConfig, w.reporterConfig, w.cliConfig, w.goFlagsConfig, additionalArgs) + internal.Cleanup(w.goFlagsConfig, suite) + return suite +} + +func (w *SpecWatcher) computeSuccinctMode(numSuites int) { + if w.reporterConfig.Verbosity().GTE(types.VerbosityLevelVerbose) { + w.reporterConfig.Succinct = false + return + } + + if w.flags.WasSet("succinct") { + return + } + + if numSuites == 1 { + w.reporterConfig.Succinct = false + } + + if numSuites > 1 { + w.reporterConfig.Succinct = true + } +} + +func (w *SpecWatcher) updateSeed() { + if !w.flags.WasSet("seed") { + w.suiteConfig.RandomSeed = time.Now().Unix() + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go b/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go new file mode 100644 index 00000000000..ac6f5104083 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go @@ -0,0 +1,162 @@ +package interrupt_handler + +import ( + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/onsi/ginkgo/v2/internal/parallel_support" +) + +const ABORT_POLLING_INTERVAL = 500 * time.Millisecond + +type InterruptCause uint + +const ( + InterruptCauseInvalid InterruptCause = iota + InterruptCauseSignal + InterruptCauseAbortByOtherProcess +) + +type InterruptLevel uint + +const ( + InterruptLevelUninterrupted InterruptLevel = iota + InterruptLevelCleanupAndReport + InterruptLevelReportOnly + InterruptLevelBailOut +) + +func (ic InterruptCause) String() string { + switch ic { + case InterruptCauseSignal: + return "Interrupted by User" + case InterruptCauseAbortByOtherProcess: + return "Interrupted by Other Ginkgo Process" + } + return "INVALID_INTERRUPT_CAUSE" +} + +type InterruptStatus struct { + Channel chan interface{} + Level InterruptLevel + Cause InterruptCause +} + +func (s InterruptStatus) Interrupted() bool { + return s.Level != InterruptLevelUninterrupted +} + +func (s InterruptStatus) Message() string { + return s.Cause.String() +} + +func (s InterruptStatus) ShouldIncludeProgressReport() bool { + return s.Cause != InterruptCauseAbortByOtherProcess +} + +type InterruptHandlerInterface interface { + Status() InterruptStatus +} + +type InterruptHandler struct { + c chan interface{} + lock *sync.Mutex + level InterruptLevel + cause InterruptCause + client parallel_support.Client + stop chan interface{} + signals []os.Signal +} + +func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *InterruptHandler { + if len(signals) == 0 { + signals = []os.Signal{os.Interrupt, syscall.SIGTERM} + } + handler := &InterruptHandler{ + c: make(chan interface{}), + lock: &sync.Mutex{}, + stop: make(chan interface{}), + client: client, + signals: signals, + } + 
handler.registerForInterrupts() + return handler +} + +func (handler *InterruptHandler) Stop() { + close(handler.stop) +} + +func (handler *InterruptHandler) registerForInterrupts() { + // os signal handling + signalChannel := make(chan os.Signal, 1) + signal.Notify(signalChannel, handler.signals...) + + // cross-process abort handling + var abortChannel chan interface{} + if handler.client != nil { + abortChannel = make(chan interface{}) + go func() { + pollTicker := time.NewTicker(ABORT_POLLING_INTERVAL) + for { + select { + case <-pollTicker.C: + if handler.client.ShouldAbort() { + close(abortChannel) + pollTicker.Stop() + return + } + case <-handler.stop: + pollTicker.Stop() + return + } + } + }() + } + + go func(abortChannel chan interface{}) { + var interruptCause InterruptCause + for { + select { + case <-signalChannel: + interruptCause = InterruptCauseSignal + case <-abortChannel: + interruptCause = InterruptCauseAbortByOtherProcess + case <-handler.stop: + signal.Stop(signalChannel) + return + } + abortChannel = nil + + handler.lock.Lock() + oldLevel := handler.level + handler.cause = interruptCause + if handler.level == InterruptLevelUninterrupted { + handler.level = InterruptLevelCleanupAndReport + } else if handler.level == InterruptLevelCleanupAndReport { + handler.level = InterruptLevelReportOnly + } else if handler.level == InterruptLevelReportOnly { + handler.level = InterruptLevelBailOut + } + if handler.level != oldLevel { + close(handler.c) + handler.c = make(chan interface{}) + } + handler.lock.Unlock() + } + }(abortChannel) +} + +func (handler *InterruptHandler) Status() InterruptStatus { + handler.lock.Lock() + defer handler.lock.Unlock() + + return InterruptStatus{ + Level: handler.level, + Channel: handler.c, + Cause: handler.cause, + } +} diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_unix.go b/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/sigquit_swallower_unix.go similarity index 64% rename from vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_unix.go rename to vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/sigquit_swallower_unix.go index 43c18544a8d..bf0de496dc3 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_unix.go +++ b/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/sigquit_swallower_unix.go @@ -1,6 +1,7 @@ +//go:build freebsd || openbsd || netbsd || dragonfly || darwin || linux || solaris // +build freebsd openbsd netbsd dragonfly darwin linux solaris -package interrupthandler +package interrupt_handler import ( "os" diff --git a/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_windows.go b/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/sigquit_swallower_windows.go similarity index 54% rename from vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_windows.go rename to vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/sigquit_swallower_windows.go index 7f4a50e1906..fcf8da8335f 100644 --- a/vendor/github.com/onsi/ginkgo/ginkgo/interrupthandler/sigquit_swallower_windows.go +++ b/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/sigquit_swallower_windows.go @@ -1,6 +1,7 @@ +//go:build windows // +build windows -package interrupthandler +package interrupt_handler func SwallowSigQuit() { //noop diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/client_server.go 
b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/client_server.go new file mode 100644 index 00000000000..b417bf5b3fb --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/client_server.go @@ -0,0 +1,70 @@ +package parallel_support + +import ( + "fmt" + "io" + "os" + "time" + + "github.com/onsi/ginkgo/v2/reporters" + "github.com/onsi/ginkgo/v2/types" +) + +type BeforeSuiteState struct { + Data []byte + State types.SpecState +} + +type ParallelIndexCounter struct { + Index int +} + +var ErrorGone = fmt.Errorf("gone") +var ErrorFailed = fmt.Errorf("failed") +var ErrorEarly = fmt.Errorf("early") + +var POLLING_INTERVAL = 50 * time.Millisecond + +type Server interface { + Start() + Close() + Address() string + RegisterAlive(node int, alive func() bool) + GetSuiteDone() chan interface{} + GetOutputDestination() io.Writer + SetOutputDestination(io.Writer) +} + +type Client interface { + Connect() bool + Close() error + + PostSuiteWillBegin(report types.Report) error + PostDidRun(report types.SpecReport) error + PostSuiteDidEnd(report types.Report) error + PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error + BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error) + BlockUntilNonprimaryProcsHaveFinished() error + BlockUntilAggregatedNonprimaryProcsReport() (types.Report, error) + FetchNextCounter() (int, error) + PostAbort() error + ShouldAbort() bool + PostEmitProgressReport(report types.ProgressReport) error + Write(p []byte) (int, error) +} + +func NewServer(parallelTotal int, reporter reporters.Reporter) (Server, error) { + if os.Getenv("GINKGO_PARALLEL_PROTOCOL") == "HTTP" { + return newHttpServer(parallelTotal, reporter) + } else { + return newRPCServer(parallelTotal, reporter) + } +} + +func NewClient(serverHost string) Client { + if os.Getenv("GINKGO_PARALLEL_PROTOCOL") == "HTTP" { + return newHttpClient(serverHost) + } else { + return newRPCClient(serverHost) + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_client.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_client.go new file mode 100644 index 00000000000..ad9932f2a9a --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_client.go @@ -0,0 +1,156 @@ +package parallel_support + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/onsi/ginkgo/v2/types" +) + +type httpClient struct { + serverHost string +} + +func newHttpClient(serverHost string) *httpClient { + return &httpClient{ + serverHost: serverHost, + } +} + +func (client *httpClient) Connect() bool { + resp, err := http.Get(client.serverHost + "/up") + if err != nil { + return false + } + resp.Body.Close() + return resp.StatusCode == http.StatusOK +} + +func (client *httpClient) Close() error { + return nil +} + +func (client *httpClient) post(path string, data interface{}) error { + var body io.Reader + if data != nil { + encoded, err := json.Marshal(data) + if err != nil { + return err + } + body = bytes.NewBuffer(encoded) + } + resp, err := http.Post(client.serverHost+path, "application/json", body) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("received unexpected status code %d", resp.StatusCode) + } + return nil +} + +func (client *httpClient) poll(path string, data interface{}) error { + for { + resp, err := http.Get(client.serverHost + path) + if err != nil { + return err + } + if resp.StatusCode 
== http.StatusTooEarly { + resp.Body.Close() + time.Sleep(POLLING_INTERVAL) + continue + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusGone { + return ErrorGone + } + if resp.StatusCode == http.StatusFailedDependency { + return ErrorFailed + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("received unexpected status code %d", resp.StatusCode) + } + if data != nil { + return json.NewDecoder(resp.Body).Decode(data) + } + return nil + } +} + +func (client *httpClient) PostSuiteWillBegin(report types.Report) error { + return client.post("/suite-will-begin", report) +} + +func (client *httpClient) PostDidRun(report types.SpecReport) error { + return client.post("/did-run", report) +} + +func (client *httpClient) PostSuiteDidEnd(report types.Report) error { + return client.post("/suite-did-end", report) +} + +func (client *httpClient) PostEmitProgressReport(report types.ProgressReport) error { + return client.post("/progress-report", report) +} + +func (client *httpClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error { + beforeSuiteState := BeforeSuiteState{ + State: state, + Data: data, + } + return client.post("/before-suite-completed", beforeSuiteState) +} + +func (client *httpClient) BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error) { + var beforeSuiteState BeforeSuiteState + err := client.poll("/before-suite-state", &beforeSuiteState) + if err == ErrorGone { + return types.SpecStateInvalid, nil, types.GinkgoErrors.SynchronizedBeforeSuiteDisappearedOnProc1() + } + return beforeSuiteState.State, beforeSuiteState.Data, err +} + +func (client *httpClient) BlockUntilNonprimaryProcsHaveFinished() error { + return client.poll("/have-nonprimary-procs-finished", nil) +} + +func (client *httpClient) BlockUntilAggregatedNonprimaryProcsReport() (types.Report, error) { + var report types.Report + err := client.poll("/aggregated-nonprimary-procs-report", &report) + if err == ErrorGone { + return types.Report{}, types.GinkgoErrors.AggregatedReportUnavailableDueToNodeDisappearing() + } + return report, err +} + +func (client *httpClient) FetchNextCounter() (int, error) { + var counter ParallelIndexCounter + err := client.poll("/counter", &counter) + return counter.Index, err +} + +func (client *httpClient) PostAbort() error { + return client.post("/abort", nil) +} + +func (client *httpClient) ShouldAbort() bool { + err := client.poll("/abort", nil) + if err == ErrorGone { + return true + } + return false +} + +func (client *httpClient) Write(p []byte) (int, error) { + resp, err := http.Post(client.serverHost+"/emit-output", "text/plain;charset=UTF-8 ", bytes.NewReader(p)) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return 0, fmt.Errorf("failed to emit output") + } + return len(p), err +} diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_server.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_server.go new file mode 100644 index 00000000000..fa3ac682a0f --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_server.go @@ -0,0 +1,223 @@ +/* + +The remote package provides the pieces to allow Ginkgo test suites to report to remote listeners. +This is used, primarily, to enable streaming parallel test output but has, in principal, broader applications (e.g. streaming test output to a browser). 
+ +*/ + +package parallel_support + +import ( + "encoding/json" + "io" + "net" + "net/http" + + "github.com/onsi/ginkgo/v2/reporters" + "github.com/onsi/ginkgo/v2/types" +) + +/* +httpServer spins up on an automatically selected port and listens for communication from the forwarding reporter. +It then forwards that communication to attached reporters. +*/ +type httpServer struct { + listener net.Listener + handler *ServerHandler +} + +//Create a new server, automatically selecting a port +func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + return &httpServer{ + listener: listener, + handler: newServerHandler(parallelTotal, reporter), + }, nil +} + +//Start the server. You don't need to `go s.Start()`, just `s.Start()` +func (server *httpServer) Start() { + httpServer := &http.Server{} + mux := http.NewServeMux() + httpServer.Handler = mux + + //streaming endpoints + mux.HandleFunc("/suite-will-begin", server.specSuiteWillBegin) + mux.HandleFunc("/did-run", server.didRun) + mux.HandleFunc("/suite-did-end", server.specSuiteDidEnd) + mux.HandleFunc("/emit-output", server.emitOutput) + mux.HandleFunc("/progress-report", server.emitProgressReport) + + //synchronization endpoints + mux.HandleFunc("/before-suite-completed", server.handleBeforeSuiteCompleted) + mux.HandleFunc("/before-suite-state", server.handleBeforeSuiteState) + mux.HandleFunc("/have-nonprimary-procs-finished", server.handleHaveNonprimaryProcsFinished) + mux.HandleFunc("/aggregated-nonprimary-procs-report", server.handleAggregatedNonprimaryProcsReport) + mux.HandleFunc("/counter", server.handleCounter) + mux.HandleFunc("/up", server.handleUp) + mux.HandleFunc("/abort", server.handleAbort) + + go httpServer.Serve(server.listener) +} + +//Stop the server +func (server *httpServer) Close() { + server.listener.Close() +} + +//The address the server can be reached it. Pass this into the `ForwardingReporter`. 
+func (server *httpServer) Address() string { + return "http://" + server.listener.Addr().String() +} + +func (server *httpServer) GetSuiteDone() chan interface{} { + return server.handler.done +} + +func (server *httpServer) GetOutputDestination() io.Writer { + return server.handler.outputDestination +} + +func (server *httpServer) SetOutputDestination(w io.Writer) { + server.handler.outputDestination = w +} + +func (server *httpServer) RegisterAlive(node int, alive func() bool) { + server.handler.registerAlive(node, alive) +} + +// +// Streaming Endpoints +// + +//The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters` +func (server *httpServer) decode(writer http.ResponseWriter, request *http.Request, object interface{}) bool { + defer request.Body.Close() + if json.NewDecoder(request.Body).Decode(object) != nil { + writer.WriteHeader(http.StatusBadRequest) + return false + } + return true +} + +func (server *httpServer) handleError(err error, writer http.ResponseWriter) bool { + if err == nil { + return false + } + switch err { + case ErrorEarly: + writer.WriteHeader(http.StatusTooEarly) + case ErrorGone: + writer.WriteHeader(http.StatusGone) + case ErrorFailed: + writer.WriteHeader(http.StatusFailedDependency) + default: + writer.WriteHeader(http.StatusInternalServerError) + } + return true +} + +func (server *httpServer) specSuiteWillBegin(writer http.ResponseWriter, request *http.Request) { + var report types.Report + if !server.decode(writer, request, &report) { + return + } + + server.handleError(server.handler.SpecSuiteWillBegin(report, voidReceiver), writer) +} + +func (server *httpServer) didRun(writer http.ResponseWriter, request *http.Request) { + var report types.SpecReport + if !server.decode(writer, request, &report) { + return + } + + server.handleError(server.handler.DidRun(report, voidReceiver), writer) +} + +func (server *httpServer) specSuiteDidEnd(writer http.ResponseWriter, request *http.Request) { + var report types.Report + if !server.decode(writer, request, &report) { + return + } + server.handleError(server.handler.SpecSuiteDidEnd(report, voidReceiver), writer) +} + +func (server *httpServer) emitOutput(writer http.ResponseWriter, request *http.Request) { + output, err := io.ReadAll(request.Body) + if err != nil { + writer.WriteHeader(http.StatusInternalServerError) + return + } + var n int + server.handleError(server.handler.EmitOutput(output, &n), writer) +} + +func (server *httpServer) emitProgressReport(writer http.ResponseWriter, request *http.Request) { + var report types.ProgressReport + if !server.decode(writer, request, &report) { + return + } + server.handleError(server.handler.EmitProgressReport(report, voidReceiver), writer) +} + +func (server *httpServer) handleBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) { + var beforeSuiteState BeforeSuiteState + if !server.decode(writer, request, &beforeSuiteState) { + return + } + + server.handleError(server.handler.BeforeSuiteCompleted(beforeSuiteState, voidReceiver), writer) +} + +func (server *httpServer) handleBeforeSuiteState(writer http.ResponseWriter, request *http.Request) { + var beforeSuiteState BeforeSuiteState + if server.handleError(server.handler.BeforeSuiteState(voidSender, &beforeSuiteState), writer) { + return + } + json.NewEncoder(writer).Encode(beforeSuiteState) +} + +func (server *httpServer) handleHaveNonprimaryProcsFinished(writer http.ResponseWriter, request *http.Request) { + if 
server.handleError(server.handler.HaveNonprimaryProcsFinished(voidSender, voidReceiver), writer) { + return + } + writer.WriteHeader(http.StatusOK) +} + +func (server *httpServer) handleAggregatedNonprimaryProcsReport(writer http.ResponseWriter, request *http.Request) { + var aggregatedReport types.Report + if server.handleError(server.handler.AggregatedNonprimaryProcsReport(voidSender, &aggregatedReport), writer) { + return + } + json.NewEncoder(writer).Encode(aggregatedReport) +} + +func (server *httpServer) handleCounter(writer http.ResponseWriter, request *http.Request) { + var n int + if server.handleError(server.handler.Counter(voidSender, &n), writer) { + return + } + json.NewEncoder(writer).Encode(ParallelIndexCounter{Index: n}) +} + +func (server *httpServer) handleUp(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) +} + +func (server *httpServer) handleAbort(writer http.ResponseWriter, request *http.Request) { + if request.Method == "GET" { + var shouldAbort bool + server.handler.ShouldAbort(voidSender, &shouldAbort) + if shouldAbort { + writer.WriteHeader(http.StatusGone) + } else { + writer.WriteHeader(http.StatusOK) + } + } else { + server.handler.Abort(voidSender, voidReceiver) + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_client.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_client.go new file mode 100644 index 00000000000..fe93cc2b9a8 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_client.go @@ -0,0 +1,123 @@ +package parallel_support + +import ( + "net/rpc" + "time" + + "github.com/onsi/ginkgo/v2/types" +) + +type rpcClient struct { + serverHost string + client *rpc.Client +} + +func newRPCClient(serverHost string) *rpcClient { + return &rpcClient{ + serverHost: serverHost, + } +} + +func (client *rpcClient) Connect() bool { + var err error + if client.client != nil { + return true + } + client.client, err = rpc.DialHTTPPath("tcp", client.serverHost, "/") + if err != nil { + client.client = nil + return false + } + return true +} + +func (client *rpcClient) Close() error { + return client.client.Close() +} + +func (client *rpcClient) poll(method string, data interface{}) error { + for { + err := client.client.Call(method, voidSender, data) + if err == nil { + return nil + } + switch err.Error() { + case ErrorEarly.Error(): + time.Sleep(POLLING_INTERVAL) + case ErrorGone.Error(): + return ErrorGone + case ErrorFailed.Error(): + return ErrorFailed + default: + return err + } + } +} + +func (client *rpcClient) PostSuiteWillBegin(report types.Report) error { + return client.client.Call("Server.SpecSuiteWillBegin", report, voidReceiver) +} + +func (client *rpcClient) PostDidRun(report types.SpecReport) error { + return client.client.Call("Server.DidRun", report, voidReceiver) +} + +func (client *rpcClient) PostSuiteDidEnd(report types.Report) error { + return client.client.Call("Server.SpecSuiteDidEnd", report, voidReceiver) +} + +func (client *rpcClient) Write(p []byte) (int, error) { + var n int + err := client.client.Call("Server.EmitOutput", p, &n) + return n, err +} + +func (client *rpcClient) PostEmitProgressReport(report types.ProgressReport) error { + return client.client.Call("Server.EmitProgressReport", report, voidReceiver) +} + +func (client *rpcClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error { + beforeSuiteState := BeforeSuiteState{ + State: state, + Data: data, + } + return 
client.client.Call("Server.BeforeSuiteCompleted", beforeSuiteState, voidReceiver) +} + +func (client *rpcClient) BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error) { + var beforeSuiteState BeforeSuiteState + err := client.poll("Server.BeforeSuiteState", &beforeSuiteState) + if err == ErrorGone { + return types.SpecStateInvalid, nil, types.GinkgoErrors.SynchronizedBeforeSuiteDisappearedOnProc1() + } + return beforeSuiteState.State, beforeSuiteState.Data, err +} + +func (client *rpcClient) BlockUntilNonprimaryProcsHaveFinished() error { + return client.poll("Server.HaveNonprimaryProcsFinished", voidReceiver) +} + +func (client *rpcClient) BlockUntilAggregatedNonprimaryProcsReport() (types.Report, error) { + var report types.Report + err := client.poll("Server.AggregatedNonprimaryProcsReport", &report) + if err == ErrorGone { + return types.Report{}, types.GinkgoErrors.AggregatedReportUnavailableDueToNodeDisappearing() + } + return report, err +} + +func (client *rpcClient) FetchNextCounter() (int, error) { + var counter int + err := client.client.Call("Server.Counter", voidSender, &counter) + return counter, err +} + +func (client *rpcClient) PostAbort() error { + return client.client.Call("Server.Abort", voidSender, voidReceiver) +} + +func (client *rpcClient) ShouldAbort() bool { + var shouldAbort bool + client.client.Call("Server.ShouldAbort", voidSender, &shouldAbort) + return shouldAbort +} diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_server.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_server.go new file mode 100644 index 00000000000..2620fd562d3 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_server.go @@ -0,0 +1,75 @@ +/* + +The remote package provides the pieces to allow Ginkgo test suites to report to remote listeners. +This is used, primarily, to enable streaming parallel test output but has, in principal, broader applications (e.g. streaming test output to a browser). + +*/ + +package parallel_support + +import ( + "io" + "net" + "net/http" + "net/rpc" + + "github.com/onsi/ginkgo/v2/reporters" +) + +/* +RPCServer spins up on an automatically selected port and listens for communication from the forwarding reporter. +It then forwards that communication to attached reporters. +*/ +type RPCServer struct { + listener net.Listener + handler *ServerHandler +} + +//Create a new server, automatically selecting a port +func newRPCServer(parallelTotal int, reporter reporters.Reporter) (*RPCServer, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + return &RPCServer{ + listener: listener, + handler: newServerHandler(parallelTotal, reporter), + }, nil +} + +//Start the server. You don't need to `go s.Start()`, just `s.Start()` +func (server *RPCServer) Start() { + rpcServer := rpc.NewServer() + rpcServer.RegisterName("Server", server.handler) //register the handler's methods as the server + + httpServer := &http.Server{} + httpServer.Handler = rpcServer + + go httpServer.Serve(server.listener) +} + +//Stop the server +func (server *RPCServer) Close() { + server.listener.Close() +} + +//The address the server can be reached it. Pass this into the `ForwardingReporter`. 
+func (server *RPCServer) Address() string { + return server.listener.Addr().String() +} + +func (server *RPCServer) GetSuiteDone() chan interface{} { + return server.handler.done +} + +func (server *RPCServer) GetOutputDestination() io.Writer { + return server.handler.outputDestination +} + +func (server *RPCServer) SetOutputDestination(w io.Writer) { + server.handler.outputDestination = w +} + +func (server *RPCServer) RegisterAlive(node int, alive func() bool) { + server.handler.registerAlive(node, alive) +} diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/server_handler.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/server_handler.go new file mode 100644 index 00000000000..7c6e67b9601 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/server_handler.go @@ -0,0 +1,209 @@ +package parallel_support + +import ( + "io" + "os" + "sync" + + "github.com/onsi/ginkgo/v2/reporters" + "github.com/onsi/ginkgo/v2/types" +) + +type Void struct{} + +var voidReceiver *Void = &Void{} +var voidSender Void + +// ServerHandler is an RPC-compatible handler that is shared between the http server and the rpc server. +// It handles all the business logic to avoid duplication between the two servers + +type ServerHandler struct { + done chan interface{} + outputDestination io.Writer + reporter reporters.Reporter + alives []func() bool + lock *sync.Mutex + beforeSuiteState BeforeSuiteState + parallelTotal int + counter int + counterLock *sync.Mutex + shouldAbort bool + + numSuiteDidBegins int + numSuiteDidEnds int + aggregatedReport types.Report + reportHoldingArea []types.SpecReport +} + +func newServerHandler(parallelTotal int, reporter reporters.Reporter) *ServerHandler { + return &ServerHandler{ + reporter: reporter, + lock: &sync.Mutex{}, + counterLock: &sync.Mutex{}, + alives: make([]func() bool, parallelTotal), + beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid}, + parallelTotal: parallelTotal, + outputDestination: os.Stdout, + done: make(chan interface{}), + } +} + +func (handler *ServerHandler) SpecSuiteWillBegin(report types.Report, _ *Void) error { + handler.lock.Lock() + defer handler.lock.Unlock() + + handler.numSuiteDidBegins += 1 + + // all summaries are identical, so it's fine to simply emit the last one of these + if handler.numSuiteDidBegins == handler.parallelTotal { + handler.reporter.SuiteWillBegin(report) + + for _, summary := range handler.reportHoldingArea { + handler.reporter.WillRun(summary) + handler.reporter.DidRun(summary) + } + + handler.reportHoldingArea = nil + } + + return nil +} + +func (handler *ServerHandler) DidRun(report types.SpecReport, _ *Void) error { + handler.lock.Lock() + defer handler.lock.Unlock() + + if handler.numSuiteDidBegins == handler.parallelTotal { + handler.reporter.WillRun(report) + handler.reporter.DidRun(report) + } else { + handler.reportHoldingArea = append(handler.reportHoldingArea, report) + } + + return nil +} + +func (handler *ServerHandler) SpecSuiteDidEnd(report types.Report, _ *Void) error { + handler.lock.Lock() + defer handler.lock.Unlock() + + handler.numSuiteDidEnds += 1 + if handler.numSuiteDidEnds == 1 { + handler.aggregatedReport = report + } else { + handler.aggregatedReport = handler.aggregatedReport.Add(report) + } + + if handler.numSuiteDidEnds == handler.parallelTotal { + handler.reporter.SuiteDidEnd(handler.aggregatedReport) + close(handler.done) + } + + return nil +} + +func (handler *ServerHandler) EmitOutput(output []byte, n *int) error 
{ + var err error + *n, err = handler.outputDestination.Write(output) + return err +} + +func (handler *ServerHandler) EmitProgressReport(report types.ProgressReport, _ *Void) error { + handler.lock.Lock() + defer handler.lock.Unlock() + handler.reporter.EmitProgressReport(report) + return nil +} + +func (handler *ServerHandler) registerAlive(proc int, alive func() bool) { + handler.lock.Lock() + defer handler.lock.Unlock() + handler.alives[proc-1] = alive +} + +func (handler *ServerHandler) procIsAlive(proc int) bool { + handler.lock.Lock() + defer handler.lock.Unlock() + alive := handler.alives[proc-1] + if alive == nil { + return true + } + return alive() +} + +func (handler *ServerHandler) haveNonprimaryProcsFinished() bool { + for i := 2; i <= handler.parallelTotal; i++ { + if handler.procIsAlive(i) { + return false + } + } + return true +} + +func (handler *ServerHandler) BeforeSuiteCompleted(beforeSuiteState BeforeSuiteState, _ *Void) error { + handler.lock.Lock() + defer handler.lock.Unlock() + handler.beforeSuiteState = beforeSuiteState + + return nil +} + +func (handler *ServerHandler) BeforeSuiteState(_ Void, beforeSuiteState *BeforeSuiteState) error { + proc1IsAlive := handler.procIsAlive(1) + handler.lock.Lock() + defer handler.lock.Unlock() + if handler.beforeSuiteState.State == types.SpecStateInvalid { + if proc1IsAlive { + return ErrorEarly + } else { + return ErrorGone + } + } + *beforeSuiteState = handler.beforeSuiteState + return nil +} + +func (handler *ServerHandler) HaveNonprimaryProcsFinished(_ Void, _ *Void) error { + if handler.haveNonprimaryProcsFinished() { + return nil + } else { + return ErrorEarly + } +} + +func (handler *ServerHandler) AggregatedNonprimaryProcsReport(_ Void, report *types.Report) error { + if handler.haveNonprimaryProcsFinished() { + handler.lock.Lock() + defer handler.lock.Unlock() + if handler.numSuiteDidEnds == handler.parallelTotal-1 { + *report = handler.aggregatedReport + return nil + } else { + return ErrorGone + } + } else { + return ErrorEarly + } +} + +func (handler *ServerHandler) Counter(_ Void, counter *int) error { + handler.counterLock.Lock() + defer handler.counterLock.Unlock() + *counter = handler.counter + handler.counter++ + return nil +} + +func (handler *ServerHandler) Abort(_ Void, _ *Void) error { + handler.lock.Lock() + defer handler.lock.Unlock() + handler.shouldAbort = true + return nil +} + +func (handler *ServerHandler) ShouldAbort(_ Void, shouldAbort *bool) error { + handler.lock.Lock() + defer handler.lock.Unlock() + *shouldAbort = handler.shouldAbort + return nil +} diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/default_reporter.go b/vendor/github.com/onsi/ginkgo/v2/reporters/default_reporter.go new file mode 100644 index 00000000000..d09488c28a3 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/default_reporter.go @@ -0,0 +1,642 @@ +/* +Ginkgo's Default Reporter + +A number of command line flags are available to tweak Ginkgo's default output. 
+ +These are documented [here](http://onsi.github.io/ginkgo/#running_tests) +*/ +package reporters + +import ( + "fmt" + "io" + "runtime" + "strings" + "time" + + "github.com/onsi/ginkgo/v2/formatter" + "github.com/onsi/ginkgo/v2/types" +) + +type DefaultReporter struct { + conf types.ReporterConfig + writer io.Writer + + // managing the emission stream + lastChar string + lastEmissionWasDelimiter bool + + // rendering + specDenoter string + retryDenoter string + formatter formatter.Formatter +} + +func NewDefaultReporterUnderTest(conf types.ReporterConfig, writer io.Writer) *DefaultReporter { + reporter := NewDefaultReporter(conf, writer) + reporter.formatter = formatter.New(formatter.ColorModePassthrough) + + return reporter +} + +func NewDefaultReporter(conf types.ReporterConfig, writer io.Writer) *DefaultReporter { + reporter := &DefaultReporter{ + conf: conf, + writer: writer, + + lastChar: "\n", + lastEmissionWasDelimiter: false, + + specDenoter: "•", + retryDenoter: "↺", + formatter: formatter.NewWithNoColorBool(conf.NoColor), + } + if runtime.GOOS == "windows" { + reporter.specDenoter = "+" + reporter.retryDenoter = "R" + } + + return reporter +} + +/* The Reporter Interface */ + +func (r *DefaultReporter) SuiteWillBegin(report types.Report) { + if r.conf.Verbosity().Is(types.VerbosityLevelSuccinct) { + r.emit(r.f("[%d] {{bold}}%s{{/}} ", report.SuiteConfig.RandomSeed, report.SuiteDescription)) + if len(report.SuiteLabels) > 0 { + r.emit(r.f("{{coral}}[%s]{{/}} ", strings.Join(report.SuiteLabels, ", "))) + } + r.emit(r.f("- %d/%d specs ", report.PreRunStats.SpecsThatWillRun, report.PreRunStats.TotalSpecs)) + if report.SuiteConfig.ParallelTotal > 1 { + r.emit(r.f("- %d procs ", report.SuiteConfig.ParallelTotal)) + } + } else { + banner := r.f("Running Suite: %s - %s", report.SuiteDescription, report.SuitePath) + r.emitBlock(banner) + bannerWidth := len(banner) + if len(report.SuiteLabels) > 0 { + labels := strings.Join(report.SuiteLabels, ", ") + r.emitBlock(r.f("{{coral}}[%s]{{/}} ", labels)) + if len(labels)+2 > bannerWidth { + bannerWidth = len(labels) + 2 + } + } + r.emitBlock(strings.Repeat("=", bannerWidth)) + + out := r.f("Random Seed: {{bold}}%d{{/}}", report.SuiteConfig.RandomSeed) + if report.SuiteConfig.RandomizeAllSpecs { + out += r.f(" - will randomize all specs") + } + r.emitBlock(out) + r.emit("\n") + r.emitBlock(r.f("Will run {{bold}}%d{{/}} of {{bold}}%d{{/}} specs", report.PreRunStats.SpecsThatWillRun, report.PreRunStats.TotalSpecs)) + if report.SuiteConfig.ParallelTotal > 1 { + r.emitBlock(r.f("Running in parallel across {{bold}}%d{{/}} processes", report.SuiteConfig.ParallelTotal)) + } + } +} + +func (r *DefaultReporter) WillRun(report types.SpecReport) { + if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) { + return + } + + r.emitDelimiter() + indentation := uint(0) + if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { + r.emitBlock(r.f("{{bold}}[%s] %s{{/}}", report.LeafNodeType.String(), report.LeafNodeText)) + } else { + if len(report.ContainerHierarchyTexts) > 0 { + r.emitBlock(r.cycleJoin(report.ContainerHierarchyTexts, " ")) + indentation = 1 + } + line := r.fi(indentation, "{{bold}}%s{{/}}", report.LeafNodeText) + labels := report.Labels() + if len(labels) > 0 { + line += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels, ", ")) + } + r.emitBlock(line) + } + r.emitBlock(r.fi(indentation, "{{gray}}%s{{/}}", report.LeafNodeLocation)) +} + +func (r *DefaultReporter) DidRun(report 
types.SpecReport) { + v := r.conf.Verbosity() + var header, highlightColor string + includeRuntime, emitGinkgoWriterOutput, stream, denoter := true, true, false, r.specDenoter + succinctLocationBlock := v.Is(types.VerbosityLevelSuccinct) + + hasGW := report.CapturedGinkgoWriterOutput != "" + hasStd := report.CapturedStdOutErr != "" + hasEmittableReports := report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways) || (report.ReportEntries.HasVisibility(types.ReportEntryVisibilityFailureOrVerbose) && (!report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose))) + + if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { + denoter = fmt.Sprintf("[%s]", report.LeafNodeType) + } + + highlightColor = r.highlightColorForState(report.State) + + switch report.State { + case types.SpecStatePassed: + succinctLocationBlock = v.LT(types.VerbosityLevelVerbose) + emitGinkgoWriterOutput = (r.conf.AlwaysEmitGinkgoWriter || v.GTE(types.VerbosityLevelVerbose)) && hasGW + if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { + if v.GTE(types.VerbosityLevelVerbose) || hasStd || hasEmittableReports { + header = fmt.Sprintf("%s PASSED", denoter) + } else { + return + } + } else { + header, stream = denoter, true + if report.NumAttempts > 1 && report.MaxFlakeAttempts > 1 { + header, stream = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), false + } + if report.RunTime > r.conf.SlowSpecThreshold { + header, stream = fmt.Sprintf("%s [SLOW TEST]", header), false + } + } + if hasStd || emitGinkgoWriterOutput || hasEmittableReports { + stream = false + } + case types.SpecStatePending: + includeRuntime, emitGinkgoWriterOutput = false, false + if v.Is(types.VerbosityLevelSuccinct) { + header, stream = "P", true + } else { + header, succinctLocationBlock = "P [PENDING]", v.LT(types.VerbosityLevelVeryVerbose) + } + case types.SpecStateSkipped: + if report.Failure.Message != "" || v.Is(types.VerbosityLevelVeryVerbose) { + header = "S [SKIPPED]" + } else { + header, stream = "S", true + } + case types.SpecStateFailed: + header = fmt.Sprintf("%s [FAILED]", denoter) + case types.SpecStateTimedout: + header = fmt.Sprintf("%s [TIMEDOUT]", denoter) + case types.SpecStatePanicked: + header = fmt.Sprintf("%s! [PANICKED]", denoter) + case types.SpecStateInterrupted: + header = fmt.Sprintf("%s! [INTERRUPTED]", denoter) + case types.SpecStateAborted: + header = fmt.Sprintf("%s! 
[ABORTED]", denoter) + } + + if report.State.Is(types.SpecStateFailureStates) && report.MaxMustPassRepeatedly > 1 { + header, stream = fmt.Sprintf("%s DURING REPETITION #%d", header, report.NumAttempts), false + } + // Emit stream and return + if stream { + r.emit(r.f(highlightColor + header + "{{/}}")) + return + } + + // Emit header + r.emitDelimiter() + if includeRuntime { + header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds()) + } + r.emitBlock(r.f(highlightColor + header + "{{/}}")) + + // Emit Code Location Block + r.emitBlock(r.codeLocationBlock(report, highlightColor, succinctLocationBlock, false)) + + //Emit Stdout/Stderr Output + if hasStd { + r.emitBlock("\n") + r.emitBlock(r.fi(1, "{{gray}}Begin Captured StdOut/StdErr Output >>{{/}}")) + r.emitBlock(r.fi(2, "%s", report.CapturedStdOutErr)) + r.emitBlock(r.fi(1, "{{gray}}<< End Captured StdOut/StdErr Output{{/}}")) + } + + //Emit Captured GinkgoWriter Output + if emitGinkgoWriterOutput && hasGW { + r.emitBlock("\n") + r.emitGinkgoWriterOutput(1, report.CapturedGinkgoWriterOutput, 0) + } + + if hasEmittableReports { + r.emitBlock("\n") + r.emitBlock(r.fi(1, "{{gray}}Begin Report Entries >>{{/}}")) + reportEntries := report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways) + if !report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose) { + reportEntries = report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways, types.ReportEntryVisibilityFailureOrVerbose) + } + for _, entry := range reportEntries { + r.emitBlock(r.fi(2, "{{bold}}"+entry.Name+"{{gray}} - %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT))) + if representation := entry.StringRepresentation(); representation != "" { + r.emitBlock(r.fi(3, representation)) + } + } + r.emitBlock(r.fi(1, "{{gray}}<< End Report Entries{{/}}")) + } + + // Emit Failure Message + if !report.Failure.IsZero() { + r.emitBlock("\n") + r.EmitFailure(1, report.State, report.Failure, false) + } + + if len(report.AdditionalFailures) > 0 { + if v.GTE(types.VerbosityLevelVerbose) { + r.emitBlock("\n") + r.emitBlock(r.fi(1, "{{bold}}There were additional failures detected after the initial failure:{{/}}")) + for i, additionalFailure := range report.AdditionalFailures { + r.EmitFailure(2, additionalFailure.State, additionalFailure.Failure, true) + if i < len(report.AdditionalFailures)-1 { + r.emitBlock(r.fi(2, "{{gray}}%s{{/}}", strings.Repeat("-", 10))) + } + } + } else { + r.emitBlock("\n") + r.emitBlock(r.fi(1, "{{bold}}There were additional failures detected after the initial failure. 
Here's a summary - for full details run Ginkgo in verbose mode:{{/}}")) + for _, additionalFailure := range report.AdditionalFailures { + r.emitBlock(r.fi(2, r.highlightColorForState(additionalFailure.State)+"[%s]{{/}} in [%s] at %s", + r.humanReadableState(additionalFailure.State), + additionalFailure.Failure.FailureNodeType, + additionalFailure.Failure.Location, + )) + } + + } + } + + r.emitDelimiter() +} + +func (r *DefaultReporter) highlightColorForState(state types.SpecState) string { + switch state { + case types.SpecStatePassed: + return "{{green}}" + case types.SpecStatePending: + return "{{yellow}}" + case types.SpecStateSkipped: + return "{{cyan}}" + case types.SpecStateFailed: + return "{{red}}" + case types.SpecStateTimedout: + return "{{orange}}" + case types.SpecStatePanicked: + return "{{magenta}}" + case types.SpecStateInterrupted: + return "{{orange}}" + case types.SpecStateAborted: + return "{{coral}}" + default: + return "{{gray}}" + } +} + +func (r *DefaultReporter) humanReadableState(state types.SpecState) string { + return strings.ToUpper(state.String()) +} + +func (r *DefaultReporter) EmitFailure(indent uint, state types.SpecState, failure types.Failure, includeState bool) { + highlightColor := r.highlightColorForState(state) + if includeState { + r.emitBlock(r.fi(indent, highlightColor+"[%s]{{/}}", r.humanReadableState(state))) + } + r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.Message)) + r.emitBlock(r.fi(indent, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}}\n", failure.FailureNodeType, failure.Location)) + if failure.ForwardedPanic != "" { + r.emitBlock("\n") + r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.ForwardedPanic)) + } + + if r.conf.FullTrace || failure.ForwardedPanic != "" { + r.emitBlock("\n") + r.emitBlock(r.fi(indent, highlightColor+"Full Stack Trace{{/}}")) + r.emitBlock(r.fi(indent+1, "%s", failure.Location.FullStackTrace)) + } + + if !failure.ProgressReport.IsZero() { + r.emitBlock("\n") + r.emitProgressReport(indent, false, failure.ProgressReport) + } +} + +func (r *DefaultReporter) SuiteDidEnd(report types.Report) { + failures := report.SpecReports.WithState(types.SpecStateFailureStates) + if len(failures) > 0 { + r.emitBlock("\n\n") + if len(failures) > 1 { + r.emitBlock(r.f("{{red}}{{bold}}Summarizing %d Failures:{{/}}", len(failures))) + } else { + r.emitBlock(r.f("{{red}}{{bold}}Summarizing 1 Failure:{{/}}")) + } + for _, specReport := range failures { + highlightColor, heading := "{{red}}", "[FAIL]" + switch specReport.State { + case types.SpecStatePanicked: + highlightColor, heading = "{{magenta}}", "[PANICKED!]" + case types.SpecStateAborted: + highlightColor, heading = "{{coral}}", "[ABORTED]" + case types.SpecStateTimedout: + highlightColor, heading = "{{orange}}", "[TIMEDOUT]" + case types.SpecStateInterrupted: + highlightColor, heading = "{{orange}}", "[INTERRUPTED]" + } + locationBlock := r.codeLocationBlock(specReport, highlightColor, true, true) + r.emitBlock(r.fi(1, highlightColor+"%s{{/}} %s", heading, locationBlock)) + } + } + + //summarize the suite + if r.conf.Verbosity().Is(types.VerbosityLevelSuccinct) && report.SuiteSucceeded { + r.emit(r.f(" {{green}}SUCCESS!{{/}} %s ", report.RunTime)) + return + } + + r.emitBlock("\n") + color, status := "{{green}}{{bold}}", "SUCCESS!" + if !report.SuiteSucceeded { + color, status = "{{red}}{{bold}}", "FAIL!" 
+ } + + specs := report.SpecReports.WithLeafNodeType(types.NodeTypeIt) //exclude any suite setup nodes + r.emitBlock(r.f(color+"Ran %d of %d Specs in %.3f seconds{{/}}", + specs.CountWithState(types.SpecStatePassed)+specs.CountWithState(types.SpecStateFailureStates), + report.PreRunStats.TotalSpecs, + report.RunTime.Seconds()), + ) + + switch len(report.SpecialSuiteFailureReasons) { + case 0: + r.emit(r.f(color+"%s{{/}} -- ", status)) + case 1: + r.emit(r.f(color+"%s - %s{{/}} -- ", status, report.SpecialSuiteFailureReasons[0])) + default: + r.emitBlock(r.f(color+"%s - %s{{/}}\n", status, strings.Join(report.SpecialSuiteFailureReasons, ", "))) + } + + if len(specs) == 0 && report.SpecReports.WithLeafNodeType(types.NodeTypeBeforeSuite|types.NodeTypeSynchronizedBeforeSuite).CountWithState(types.SpecStateFailureStates) > 0 { + r.emit(r.f("{{cyan}}{{bold}}A BeforeSuite node failed so all tests were skipped.{{/}}\n")) + } else { + r.emit(r.f("{{green}}{{bold}}%d Passed{{/}} | ", specs.CountWithState(types.SpecStatePassed))) + r.emit(r.f("{{red}}{{bold}}%d Failed{{/}} | ", specs.CountWithState(types.SpecStateFailureStates))) + if specs.CountOfFlakedSpecs() > 0 { + r.emit(r.f("{{light-yellow}}{{bold}}%d Flaked{{/}} | ", specs.CountOfFlakedSpecs())) + } + if specs.CountOfRepeatedSpecs() > 0 { + r.emit(r.f("{{light-yellow}}{{bold}}%d Repeated{{/}} | ", specs.CountOfRepeatedSpecs())) + } + r.emit(r.f("{{yellow}}{{bold}}%d Pending{{/}} | ", specs.CountWithState(types.SpecStatePending))) + r.emit(r.f("{{cyan}}{{bold}}%d Skipped{{/}}\n", specs.CountWithState(types.SpecStateSkipped))) + } +} + +func (r *DefaultReporter) EmitProgressReport(report types.ProgressReport) { + r.emitDelimiter() + + if report.RunningInParallel { + r.emit(r.f("{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess)) + } + r.emitProgressReport(0, true, report) + r.emitDelimiter() +} + +func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput bool, report types.ProgressReport) { + if report.Message != "" { + r.emitBlock(r.fi(indent, report.Message+"\n")) + indent += 1 + } + if report.LeafNodeText != "" { + subjectIndent := indent + if len(report.ContainerHierarchyTexts) > 0 { + r.emit(r.fi(indent, r.cycleJoin(report.ContainerHierarchyTexts, " "))) + r.emit(" ") + subjectIndent = 0 + } + r.emit(r.fi(subjectIndent, "{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time.Sub(report.SpecStartTime).Round(time.Millisecond))) + r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.LeafNodeLocation)) + indent += 1 + } + if report.CurrentNodeType != types.NodeTypeInvalid { + r.emit(r.fi(indent, "In {{bold}}{{orange}}[%s]{{/}}", report.CurrentNodeType)) + if report.CurrentNodeText != "" && !report.CurrentNodeType.Is(types.NodeTypeIt) { + r.emit(r.f(" {{bold}}{{orange}}%s{{/}}", report.CurrentNodeText)) + } + + r.emit(r.f(" (Node Runtime: %s)\n", report.Time.Sub(report.CurrentNodeStartTime).Round(time.Millisecond))) + r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentNodeLocation)) + indent += 1 + } + if report.CurrentStepText != "" { + r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time.Sub(report.CurrentStepStartTime).Round(time.Millisecond))) + r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentStepLocation)) + indent += 1 + } + + if indent > 0 { + indent -= 1 + } + + if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" && (report.RunningInParallel || 
r.conf.Verbosity().LT(types.VerbosityLevelVerbose)) { + r.emit("\n") + r.emitGinkgoWriterOutput(indent, report.CapturedGinkgoWriterOutput, 10) + } + + if !report.SpecGoroutine().IsZero() { + r.emit("\n") + r.emit(r.fi(indent, "{{bold}}{{underline}}Spec Goroutine{{/}}\n")) + r.emitGoroutines(indent, report.SpecGoroutine()) + } + + if len(report.AdditionalReports) > 0 { + r.emit("\n") + r.emitBlock(r.fi(indent, "{{gray}}Begin Additional Progress Reports >>{{/}}")) + for i, additionalReport := range report.AdditionalReports { + r.emit(r.fi(indent+1, additionalReport)) + if i < len(report.AdditionalReports)-1 { + r.emitBlock(r.fi(indent+1, "{{gray}}%s{{/}}", strings.Repeat("-", 10))) + } + } + r.emitBlock(r.fi(indent, "{{gray}}<< End Additional Progress Reports{{/}}")) + } + + highlightedGoroutines := report.HighlightedGoroutines() + if len(highlightedGoroutines) > 0 { + r.emit("\n") + r.emit(r.fi(indent, "{{bold}}{{underline}}Goroutines of Interest{{/}}\n")) + r.emitGoroutines(indent, highlightedGoroutines...) + } + + otherGoroutines := report.OtherGoroutines() + if len(otherGoroutines) > 0 { + r.emit("\n") + r.emit(r.fi(indent, "{{gray}}{{bold}}{{underline}}Other Goroutines{{/}}\n")) + r.emitGoroutines(indent, otherGoroutines...) + } +} + +func (r *DefaultReporter) emitGinkgoWriterOutput(indent uint, output string, limit int) { + r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}")) + if limit == 0 { + r.emitBlock(r.fi(indent+1, "%s", output)) + } else { + lines := strings.Split(output, "\n") + if len(lines) <= limit { + r.emitBlock(r.fi(indent+1, "%s", output)) + } else { + r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}")) + for _, line := range lines[len(lines)-limit-1:] { + r.emitBlock(r.fi(indent+1, "%s", line)) + } + } + } + r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}")) +} + +func (r *DefaultReporter) emitGoroutines(indent uint, goroutines ...types.Goroutine) { + for idx, g := range goroutines { + color := "{{gray}}" + if g.HasHighlights() { + color = "{{orange}}" + } + r.emit(r.fi(indent, color+"goroutine %d [%s]{{/}}\n", g.ID, g.State)) + for _, fc := range g.Stack { + if fc.Highlight { + r.emit(r.fi(indent, color+"{{bold}}> %s{{/}}\n", fc.Function)) + r.emit(r.fi(indent+2, color+"{{bold}}%s:%d{{/}}\n", fc.Filename, fc.Line)) + r.emitSource(indent+3, fc) + } else { + r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", fc.Function)) + r.emit(r.fi(indent+2, "{{gray}}%s:%d{{/}}\n", fc.Filename, fc.Line)) + } + } + + if idx+1 < len(goroutines) { + r.emit("\n") + } + } +} + +func (r *DefaultReporter) emitSource(indent uint, fc types.FunctionCall) { + lines := fc.Source + if len(lines) == 0 { + return + } + + lTrim := 100000 + for _, line := range lines { + lTrimLine := len(line) - len(strings.TrimLeft(line, " \t")) + if lTrimLine < lTrim && len(line) > 0 { + lTrim = lTrimLine + } + } + if lTrim == 100000 { + lTrim = 0 + } + + for idx, line := range lines { + if len(line) > lTrim { + line = line[lTrim:] + } + if idx == fc.SourceHighlight { + r.emit(r.fi(indent, "{{bold}}{{orange}}> %s{{/}}\n", line)) + } else { + r.emit(r.fi(indent, "| %s\n", line)) + } + } +} + +/* Emitting to the writer */ +func (r *DefaultReporter) emit(s string) { + if len(s) > 0 { + r.lastChar = s[len(s)-1:] + r.lastEmissionWasDelimiter = false + r.writer.Write([]byte(s)) + } +} + +func (r *DefaultReporter) emitBlock(s string) { + if len(s) > 0 { + if r.lastChar != "\n" { + r.emit("\n") + } + r.emit(s) + if r.lastChar != "\n" { + r.emit("\n") + } + } +} + +func (r 
*DefaultReporter) emitDelimiter() { + if r.lastEmissionWasDelimiter { + return + } + r.emitBlock(r.f("{{gray}}%s{{/}}", strings.Repeat("-", 30))) + r.lastEmissionWasDelimiter = true +} + +/* Rendering text */ +func (r *DefaultReporter) f(format string, args ...interface{}) string { + return r.formatter.F(format, args...) +} + +func (r *DefaultReporter) fi(indentation uint, format string, args ...interface{}) string { + return r.formatter.Fi(indentation, format, args...) +} + +func (r *DefaultReporter) cycleJoin(elements []string, joiner string) string { + return r.formatter.CycleJoin(elements, joiner, []string{"{{/}}", "{{gray}}"}) +} + +func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, succinct bool, usePreciseFailureLocation bool) string { + texts, locations, labels := []string{}, []types.CodeLocation{}, [][]string{} + texts, locations, labels = append(texts, report.ContainerHierarchyTexts...), append(locations, report.ContainerHierarchyLocations...), append(labels, report.ContainerHierarchyLabels...) + if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { + texts = append(texts, r.f("[%s] %s", report.LeafNodeType, report.LeafNodeText)) + } else { + texts = append(texts, report.LeafNodeText) + } + labels = append(labels, report.LeafNodeLabels) + locations = append(locations, report.LeafNodeLocation) + + failureLocation := report.Failure.FailureNodeLocation + if usePreciseFailureLocation { + failureLocation = report.Failure.Location + } + + switch report.Failure.FailureNodeContext { + case types.FailureNodeAtTopLevel: + texts = append([]string{r.f(highlightColor+"{{bold}}TOP-LEVEL [%s]{{/}}", report.Failure.FailureNodeType)}, texts...) + locations = append([]types.CodeLocation{failureLocation}, locations...) + labels = append([][]string{{}}, labels...) + case types.FailureNodeInContainer: + i := report.Failure.FailureNodeContainerIndex + texts[i] = r.f(highlightColor+"{{bold}}%s [%s]{{/}}", texts[i], report.Failure.FailureNodeType) + locations[i] = failureLocation + case types.FailureNodeIsLeafNode: + i := len(texts) - 1 + texts[i] = r.f(highlightColor+"{{bold}}[%s] %s{{/}}", report.LeafNodeType, report.LeafNodeText) + locations[i] = failureLocation + } + + out := "" + if succinct { + out += r.f("%s", r.cycleJoin(texts, " ")) + flattenedLabels := report.Labels() + if len(flattenedLabels) > 0 { + out += r.f(" {{coral}}[%s]{{/}}", strings.Join(flattenedLabels, ", ")) + } + out += "\n" + if usePreciseFailureLocation { + out += r.f("{{gray}}%s{{/}}", failureLocation) + } else { + out += r.f("{{gray}}%s{{/}}", locations[len(locations)-1]) + } + } else { + for i := range texts { + out += r.fi(uint(i), "%s", texts[i]) + if len(labels[i]) > 0 { + out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", ")) + } + out += "\n" + out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i]) + } + } + return out +} diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/deprecated_reporter.go b/vendor/github.com/onsi/ginkgo/v2/reporters/deprecated_reporter.go new file mode 100644 index 00000000000..89d30076bfa --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/deprecated_reporter.go @@ -0,0 +1,149 @@ +package reporters + +import ( + "github.com/onsi/ginkgo/v2/config" + "github.com/onsi/ginkgo/v2/types" +) + +// Deprecated: DeprecatedReporter was how Ginkgo V1 provided support for CustomReporters +// this has been removed in V2. 
+// Please read the documentation at: +// https://onsi.github.io/ginkgo/MIGRATING_TO_V2#removed-custom-reporters +// for Ginkgo's new behavior and for a migration path. +type DeprecatedReporter interface { + SuiteWillBegin(config config.GinkgoConfigType, summary *types.SuiteSummary) + BeforeSuiteDidRun(setupSummary *types.SetupSummary) + SpecWillRun(specSummary *types.SpecSummary) + SpecDidComplete(specSummary *types.SpecSummary) + AfterSuiteDidRun(setupSummary *types.SetupSummary) + SuiteDidEnd(summary *types.SuiteSummary) +} + +// ReportViaDeprecatedReporter takes a V1 custom reporter and a V2 report and +// calls the custom reporter's methods with appropriately transformed data from the V2 report. +// +// ReportViaDeprecatedReporter should be called in a `ReportAfterSuite()` +// +// Deprecated: ReportViaDeprecatedReporter method exists to help developer bridge between deprecated V1 functionality and the new +// reporting support in V2. It will be removed in a future minor version of Ginkgo. +func ReportViaDeprecatedReporter(reporter DeprecatedReporter, report types.Report) { + conf := config.DeprecatedGinkgoConfigType{ + RandomSeed: report.SuiteConfig.RandomSeed, + RandomizeAllSpecs: report.SuiteConfig.RandomizeAllSpecs, + FocusStrings: report.SuiteConfig.FocusStrings, + SkipStrings: report.SuiteConfig.SkipStrings, + FailOnPending: report.SuiteConfig.FailOnPending, + FailFast: report.SuiteConfig.FailFast, + FlakeAttempts: report.SuiteConfig.FlakeAttempts, + EmitSpecProgress: report.SuiteConfig.EmitSpecProgress, + DryRun: report.SuiteConfig.DryRun, + ParallelNode: report.SuiteConfig.ParallelProcess, + ParallelTotal: report.SuiteConfig.ParallelTotal, + SyncHost: report.SuiteConfig.ParallelHost, + StreamHost: report.SuiteConfig.ParallelHost, + } + + summary := &types.DeprecatedSuiteSummary{ + SuiteDescription: report.SuiteDescription, + SuiteID: report.SuitePath, + + NumberOfSpecsBeforeParallelization: report.PreRunStats.TotalSpecs, + NumberOfTotalSpecs: report.PreRunStats.TotalSpecs, + NumberOfSpecsThatWillBeRun: report.PreRunStats.SpecsThatWillRun, + } + + reporter.SuiteWillBegin(conf, summary) + + for _, spec := range report.SpecReports { + switch spec.LeafNodeType { + case types.NodeTypeBeforeSuite, types.NodeTypeSynchronizedBeforeSuite: + setupSummary := &types.DeprecatedSetupSummary{ + ComponentType: spec.LeafNodeType, + CodeLocation: spec.LeafNodeLocation, + State: spec.State, + RunTime: spec.RunTime, + Failure: failureFor(spec), + CapturedOutput: spec.CombinedOutput(), + SuiteID: report.SuitePath, + } + reporter.BeforeSuiteDidRun(setupSummary) + case types.NodeTypeAfterSuite, types.NodeTypeSynchronizedAfterSuite: + setupSummary := &types.DeprecatedSetupSummary{ + ComponentType: spec.LeafNodeType, + CodeLocation: spec.LeafNodeLocation, + State: spec.State, + RunTime: spec.RunTime, + Failure: failureFor(spec), + CapturedOutput: spec.CombinedOutput(), + SuiteID: report.SuitePath, + } + reporter.AfterSuiteDidRun(setupSummary) + case types.NodeTypeIt: + componentTexts, componentCodeLocations := []string{}, []types.CodeLocation{} + componentTexts = append(componentTexts, spec.ContainerHierarchyTexts...) + componentCodeLocations = append(componentCodeLocations, spec.ContainerHierarchyLocations...) 
+ componentTexts = append(componentTexts, spec.LeafNodeText) + componentCodeLocations = append(componentCodeLocations, spec.LeafNodeLocation) + + specSummary := &types.DeprecatedSpecSummary{ + ComponentTexts: componentTexts, + ComponentCodeLocations: componentCodeLocations, + State: spec.State, + RunTime: spec.RunTime, + Failure: failureFor(spec), + NumberOfSamples: spec.NumAttempts, + CapturedOutput: spec.CombinedOutput(), + SuiteID: report.SuitePath, + } + reporter.SpecWillRun(specSummary) + reporter.SpecDidComplete(specSummary) + + switch spec.State { + case types.SpecStatePending: + summary.NumberOfPendingSpecs += 1 + case types.SpecStateSkipped: + summary.NumberOfSkippedSpecs += 1 + case types.SpecStateFailed, types.SpecStatePanicked, types.SpecStateInterrupted: + summary.NumberOfFailedSpecs += 1 + case types.SpecStatePassed: + summary.NumberOfPassedSpecs += 1 + if spec.NumAttempts > 1 { + summary.NumberOfFlakedSpecs += 1 + } + } + } + } + + summary.SuiteSucceeded = report.SuiteSucceeded + summary.RunTime = report.RunTime + + reporter.SuiteDidEnd(summary) +} + +func failureFor(spec types.SpecReport) types.DeprecatedSpecFailure { + if spec.Failure.IsZero() { + return types.DeprecatedSpecFailure{} + } + + index := 0 + switch spec.Failure.FailureNodeContext { + case types.FailureNodeInContainer: + index = spec.Failure.FailureNodeContainerIndex + case types.FailureNodeAtTopLevel: + index = -1 + case types.FailureNodeIsLeafNode: + index = len(spec.ContainerHierarchyTexts) - 1 + if spec.LeafNodeText != "" { + index += 1 + } + } + + return types.DeprecatedSpecFailure{ + Message: spec.Failure.Message, + Location: spec.Failure.Location, + ForwardedPanic: spec.Failure.ForwardedPanic, + ComponentIndex: index, + ComponentType: spec.Failure.FailureNodeType, + ComponentCodeLocation: spec.Failure.FailureNodeLocation, + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/json_report.go b/vendor/github.com/onsi/ginkgo/v2/reporters/json_report.go new file mode 100644 index 00000000000..7f96c450fe9 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/json_report.go @@ -0,0 +1,60 @@ +package reporters + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/onsi/ginkgo/v2/types" +) + +//GenerateJSONReport produces a JSON-formatted report at the passed in destination +func GenerateJSONReport(report types.Report, destination string) error { + f, err := os.Create(destination) + if err != nil { + return err + } + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + err = enc.Encode([]types.Report{ + report, + }) + if err != nil { + return err + } + return f.Close() +} + +//MergeJSONReports produces a single JSON-formatted report at the passed in destination by merging the JSON-formatted reports provided in sources +//It skips over reports that fail to decode but reports on them via the returned messages []string +func MergeAndCleanupJSONReports(sources []string, destination string) ([]string, error) { + messages := []string{} + allReports := []types.Report{} + for _, source := range sources { + reports := []types.Report{} + data, err := os.ReadFile(source) + if err != nil { + messages = append(messages, fmt.Sprintf("Could not open %s:\n%s", source, err.Error())) + continue + } + err = json.Unmarshal(data, &reports) + if err != nil { + messages = append(messages, fmt.Sprintf("Could not decode %s:\n%s", source, err.Error())) + continue + } + os.Remove(source) + allReports = append(allReports, reports...) 
+ } + + f, err := os.Create(destination) + if err != nil { + return messages, err + } + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + err = enc.Encode(allReports) + if err != nil { + return messages, err + } + return messages, f.Close() +} diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/junit_report.go b/vendor/github.com/onsi/ginkgo/v2/reporters/junit_report.go new file mode 100644 index 00000000000..fcea6ab17bf --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/junit_report.go @@ -0,0 +1,358 @@ +/* + +JUnit XML Reporter for Ginkgo + +For usage instructions: http://onsi.github.io/ginkgo/#generating_junit_xml_output + +The schema used for the generated JUnit xml file was adapted from https://llg.cubic.org/docs/junit/ + +*/ + +package reporters + +import ( + "encoding/xml" + "fmt" + "os" + "strings" + "time" + + "github.com/onsi/ginkgo/v2/config" + "github.com/onsi/ginkgo/v2/types" +) + +type JUnitTestSuites struct { + XMLName xml.Name `xml:"testsuites"` + // Tests maps onto the total number of specs in all test suites (this includes any suite nodes such as BeforeSuite) + Tests int `xml:"tests,attr"` + // Disabled maps onto specs that are pending and/or skipped + Disabled int `xml:"disabled,attr"` + // Errors maps onto specs that panicked or were interrupted + Errors int `xml:"errors,attr"` + // Failures maps onto specs that failed + Failures int `xml:"failures,attr"` + // Time is the time in seconds to execute all test suites + Time float64 `xml:"time,attr"` + + //The set of all test suites + TestSuites []JUnitTestSuite `xml:"testsuite"` +} + +type JUnitTestSuite struct { + // Name maps onto the description of the test suite - maps onto Report.SuiteDescription + Name string `xml:"name,attr"` + // Package maps onto the absolute path to the test suite - maps onto Report.SuitePath + Package string `xml:"package,attr"` + // Tests maps onto the total number of specs in the test suite (this includes any suite nodes such as BeforeSuite) + Tests int `xml:"tests,attr"` + // Disabled maps onto specs that are pending + Disabled int `xml:"disabled,attr"` + // Skiped maps onto specs that are skipped + Skipped int `xml:"skipped,attr"` + // Errors maps onto specs that panicked or were interrupted + Errors int `xml:"errors,attr"` + // Failures maps onto specs that failed + Failures int `xml:"failures,attr"` + // Time is the time in seconds to execute all the test suite - maps onto Report.RunTime + Time float64 `xml:"time,attr"` + // Timestamp is the ISO 8601 formatted start-time of the suite - maps onto Report.StartTime + Timestamp string `xml:"timestamp,attr"` + + //Properties captures the information stored in the rest of the Report type (including SuiteConfig) as key-value pairs + Properties JUnitProperties `xml:"properties"` + + //TestCases capture the individual specs + TestCases []JUnitTestCase `xml:"testcase"` +} + +type JUnitProperties struct { + Properties []JUnitProperty `xml:"property"` +} + +func (jup JUnitProperties) WithName(name string) string { + for _, property := range jup.Properties { + if property.Name == name { + return property.Value + } + } + return "" +} + +type JUnitProperty struct { + Name string `xml:"name,attr"` + Value string `xml:"value,attr"` +} + +type JUnitTestCase struct { + // Name maps onto the full text of the spec - equivalent to "[SpecReport.LeafNodeType] SpecReport.FullText()" + Name string `xml:"name,attr"` + // Classname maps onto the name of the test suite - equivalent to Report.SuiteDescription + Classname string `xml:"classname,attr"` 
+ // Status maps onto the string representation of SpecReport.State + Status string `xml:"status,attr"` + // Time is the time in seconds to execute the spec - maps onto SpecReport.RunTime + Time float64 `xml:"time,attr"` + //Skipped is populated with a message if the test was skipped or pending + Skipped *JUnitSkipped `xml:"skipped,omitempty"` + //Error is populated if the test panicked or was interrupted + Error *JUnitError `xml:"error,omitempty"` + //Failure is populated if the test failed + Failure *JUnitFailure `xml:"failure,omitempty"` + //SystemOut maps onto any captured stdout/stderr output - maps onto SpecReport.CapturedStdOutErr + SystemOut string `xml:"system-out,omitempty"` + //SystemOut maps onto any captured GinkgoWriter output - maps onto SpecReport.CapturedGinkgoWriterOutput + SystemErr string `xml:"system-err,omitempty"` +} + +type JUnitSkipped struct { + // Message maps onto "pending" if the test was marked pending, "skipped" if the test was marked skipped, and "skipped - REASON" if the user called Skip(REASON) + Message string `xml:"message,attr"` +} + +type JUnitError struct { + //Message maps onto the panic/exception thrown - equivalent to SpecReport.Failure.ForwardedPanic - or to "interrupted" + Message string `xml:"message,attr"` + //Type is one of "panicked" or "interrupted" + Type string `xml:"type,attr"` + //Description maps onto the captured stack trace for a panic, or the failure message for an interrupt which will include the dump of running goroutines + Description string `xml:",chardata"` +} + +type JUnitFailure struct { + //Message maps onto the failure message - equivalent to SpecReport.Failure.Message + Message string `xml:"message,attr"` + //Type is "failed" + Type string `xml:"type,attr"` + //Description maps onto the location and stack trace of the failure + Description string `xml:",chardata"` +} + +func GenerateJUnitReport(report types.Report, dst string) error { + suite := JUnitTestSuite{ + Name: report.SuiteDescription, + Package: report.SuitePath, + Time: report.RunTime.Seconds(), + Timestamp: report.StartTime.Format("2006-01-02T15:04:05"), + Properties: JUnitProperties{ + Properties: []JUnitProperty{ + {"SuiteSucceeded", fmt.Sprintf("%t", report.SuiteSucceeded)}, + {"SuiteHasProgrammaticFocus", fmt.Sprintf("%t", report.SuiteHasProgrammaticFocus)}, + {"SpecialSuiteFailureReason", strings.Join(report.SpecialSuiteFailureReasons, ",")}, + {"SuiteLabels", fmt.Sprintf("[%s]", strings.Join(report.SuiteLabels, ","))}, + {"RandomSeed", fmt.Sprintf("%d", report.SuiteConfig.RandomSeed)}, + {"RandomizeAllSpecs", fmt.Sprintf("%t", report.SuiteConfig.RandomizeAllSpecs)}, + {"LabelFilter", report.SuiteConfig.LabelFilter}, + {"FocusStrings", strings.Join(report.SuiteConfig.FocusStrings, ",")}, + {"SkipStrings", strings.Join(report.SuiteConfig.SkipStrings, ",")}, + {"FocusFiles", strings.Join(report.SuiteConfig.FocusFiles, ";")}, + {"SkipFiles", strings.Join(report.SuiteConfig.SkipFiles, ";")}, + {"FailOnPending", fmt.Sprintf("%t", report.SuiteConfig.FailOnPending)}, + {"FailFast", fmt.Sprintf("%t", report.SuiteConfig.FailFast)}, + {"FlakeAttempts", fmt.Sprintf("%d", report.SuiteConfig.FlakeAttempts)}, + {"EmitSpecProgress", fmt.Sprintf("%t", report.SuiteConfig.EmitSpecProgress)}, + {"DryRun", fmt.Sprintf("%t", report.SuiteConfig.DryRun)}, + {"ParallelTotal", fmt.Sprintf("%d", report.SuiteConfig.ParallelTotal)}, + {"OutputInterceptorMode", report.SuiteConfig.OutputInterceptorMode}, + }, + }, + } + for _, spec := range report.SpecReports { + name := 
fmt.Sprintf("[%s]", spec.LeafNodeType) + if spec.FullText() != "" { + name = name + " " + spec.FullText() + } + labels := spec.Labels() + if len(labels) > 0 { + name = name + " [" + strings.Join(labels, ", ") + "]" + } + + test := JUnitTestCase{ + Name: name, + Classname: report.SuiteDescription, + Status: spec.State.String(), + Time: spec.RunTime.Seconds(), + SystemOut: systemOutForUnstructuredReporters(spec), + SystemErr: systemErrForUnstructuredReporters(spec), + } + suite.Tests += 1 + + switch spec.State { + case types.SpecStateSkipped: + message := "skipped" + if spec.Failure.Message != "" { + message += " - " + spec.Failure.Message + } + test.Skipped = &JUnitSkipped{Message: message} + suite.Skipped += 1 + case types.SpecStatePending: + test.Skipped = &JUnitSkipped{Message: "pending"} + suite.Disabled += 1 + case types.SpecStateFailed: + test.Failure = &JUnitFailure{ + Message: spec.Failure.Message, + Type: "failed", + Description: failureDescriptionForUnstructuredReporters(spec), + } + suite.Failures += 1 + case types.SpecStateTimedout: + test.Failure = &JUnitFailure{ + Message: spec.Failure.Message, + Type: "timedout", + Description: failureDescriptionForUnstructuredReporters(spec), + } + suite.Failures += 1 + case types.SpecStateInterrupted: + test.Error = &JUnitError{ + Message: spec.Failure.Message, + Type: "interrupted", + Description: failureDescriptionForUnstructuredReporters(spec), + } + suite.Errors += 1 + case types.SpecStateAborted: + test.Failure = &JUnitFailure{ + Message: spec.Failure.Message, + Type: "aborted", + Description: failureDescriptionForUnstructuredReporters(spec), + } + suite.Errors += 1 + case types.SpecStatePanicked: + test.Error = &JUnitError{ + Message: spec.Failure.ForwardedPanic, + Type: "panicked", + Description: failureDescriptionForUnstructuredReporters(spec), + } + suite.Errors += 1 + } + + suite.TestCases = append(suite.TestCases, test) + } + + junitReport := JUnitTestSuites{ + Tests: suite.Tests, + Disabled: suite.Disabled + suite.Skipped, + Errors: suite.Errors, + Failures: suite.Failures, + Time: suite.Time, + TestSuites: []JUnitTestSuite{suite}, + } + + f, err := os.Create(dst) + if err != nil { + return err + } + f.WriteString(xml.Header) + encoder := xml.NewEncoder(f) + encoder.Indent(" ", " ") + encoder.Encode(junitReport) + + return f.Close() +} + +func MergeAndCleanupJUnitReports(sources []string, dst string) ([]string, error) { + messages := []string{} + mergedReport := JUnitTestSuites{} + for _, source := range sources { + report := JUnitTestSuites{} + f, err := os.Open(source) + if err != nil { + messages = append(messages, fmt.Sprintf("Could not open %s:\n%s", source, err.Error())) + continue + } + err = xml.NewDecoder(f).Decode(&report) + if err != nil { + messages = append(messages, fmt.Sprintf("Could not decode %s:\n%s", source, err.Error())) + continue + } + os.Remove(source) + + mergedReport.Tests += report.Tests + mergedReport.Disabled += report.Disabled + mergedReport.Errors += report.Errors + mergedReport.Failures += report.Failures + mergedReport.Time += report.Time + mergedReport.TestSuites = append(mergedReport.TestSuites, report.TestSuites...) 
+ } + + f, err := os.Create(dst) + if err != nil { + return messages, err + } + f.WriteString(xml.Header) + encoder := xml.NewEncoder(f) + encoder.Indent(" ", " ") + encoder.Encode(mergedReport) + + return messages, f.Close() +} + +func failureDescriptionForUnstructuredReporters(spec types.SpecReport) string { + out := &strings.Builder{} + out.WriteString(spec.Failure.Location.String() + "\n") + out.WriteString(spec.Failure.Location.FullStackTrace) + if !spec.Failure.ProgressReport.IsZero() { + out.WriteString("\n") + NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(spec.Failure.ProgressReport) + } + if len(spec.AdditionalFailures) > 0 { + out.WriteString("\nThere were additional failures detected after the initial failure:\n") + for i, additionalFailure := range spec.AdditionalFailures { + NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitFailure(0, additionalFailure.State, additionalFailure.Failure, true) + if i < len(spec.AdditionalFailures)-1 { + out.WriteString("----------\n") + } + } + } + return out.String() +} + +func systemErrForUnstructuredReporters(spec types.SpecReport) string { + out := &strings.Builder{} + gw := spec.CapturedGinkgoWriterOutput + cursor := 0 + for _, pr := range spec.ProgressReports { + if cursor < pr.GinkgoWriterOffset { + if pr.GinkgoWriterOffset < len(gw) { + out.WriteString(gw[cursor:pr.GinkgoWriterOffset]) + cursor = pr.GinkgoWriterOffset + } else if cursor < len(gw) { + out.WriteString(gw[cursor:]) + cursor = len(gw) + } + } + NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(pr) + } + + if cursor < len(gw) { + out.WriteString(gw[cursor:]) + } + + return out.String() +} + +func systemOutForUnstructuredReporters(spec types.SpecReport) string { + systemOut := spec.CapturedStdOutErr + if len(spec.ReportEntries) > 0 { + systemOut += "\nReport Entries:\n" + for i, entry := range spec.ReportEntries { + systemOut += fmt.Sprintf("%s\n%s\n%s\n", entry.Name, entry.Location, entry.Time.Format(time.RFC3339Nano)) + if representation := entry.StringRepresentation(); representation != "" { + systemOut += representation + "\n" + } + if i+1 < len(spec.ReportEntries) { + systemOut += "--\n" + } + } + } + return systemOut +} + +// Deprecated JUnitReporter (so folks can still compile their suites) +type JUnitReporter struct{} + +func NewJUnitReporter(_ string) *JUnitReporter { return &JUnitReporter{} } +func (reporter *JUnitReporter) SuiteWillBegin(_ config.GinkgoConfigType, _ *types.SuiteSummary) {} +func (reporter *JUnitReporter) BeforeSuiteDidRun(_ *types.SetupSummary) {} +func (reporter *JUnitReporter) SpecWillRun(_ *types.SpecSummary) {} +func (reporter *JUnitReporter) SpecDidComplete(_ *types.SpecSummary) {} +func (reporter *JUnitReporter) AfterSuiteDidRun(_ *types.SetupSummary) {} +func (reporter *JUnitReporter) SuiteDidEnd(_ *types.SuiteSummary) {} diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/reporter.go b/vendor/github.com/onsi/ginkgo/v2/reporters/reporter.go new file mode 100644 index 00000000000..f79f005dbe0 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/reporter.go @@ -0,0 +1,21 @@ +package reporters + +import ( + "github.com/onsi/ginkgo/v2/types" +) + +type Reporter interface { + SuiteWillBegin(report types.Report) + WillRun(report types.SpecReport) + DidRun(report types.SpecReport) + SuiteDidEnd(report types.Report) + EmitProgressReport(progressReport types.ProgressReport) +} + +type NoopReporter struct{} + +func (n NoopReporter) SuiteWillBegin(report 
types.Report) {} +func (n NoopReporter) WillRun(report types.SpecReport) {} +func (n NoopReporter) DidRun(report types.SpecReport) {} +func (n NoopReporter) SuiteDidEnd(report types.Report) {} +func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {} diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/teamcity_report.go b/vendor/github.com/onsi/ginkgo/v2/reporters/teamcity_report.go new file mode 100644 index 00000000000..c1863496dc7 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/teamcity_report.go @@ -0,0 +1,101 @@ +/* + +TeamCity Reporter for Ginkgo + +Makes use of TeamCity's support for Service Messages +http://confluence.jetbrains.com/display/TCD7/Build+Script+Interaction+with+TeamCity#BuildScriptInteractionwithTeamCity-ReportingTests +*/ + +package reporters + +import ( + "fmt" + "os" + "strings" + + "github.com/onsi/ginkgo/v2/types" +) + +func tcEscape(s string) string { + s = strings.ReplaceAll(s, "|", "||") + s = strings.ReplaceAll(s, "'", "|'") + s = strings.ReplaceAll(s, "\n", "|n") + s = strings.ReplaceAll(s, "\r", "|r") + s = strings.ReplaceAll(s, "[", "|[") + s = strings.ReplaceAll(s, "]", "|]") + return s +} + +func GenerateTeamcityReport(report types.Report, dst string) error { + f, err := os.Create(dst) + if err != nil { + return err + } + + name := report.SuiteDescription + labels := report.SuiteLabels + if len(labels) > 0 { + name = name + " [" + strings.Join(labels, ", ") + "]" + } + fmt.Fprintf(f, "##teamcity[testSuiteStarted name='%s']\n", tcEscape(name)) + for _, spec := range report.SpecReports { + name := fmt.Sprintf("[%s]", spec.LeafNodeType) + if spec.FullText() != "" { + name = name + " " + spec.FullText() + } + labels := spec.Labels() + if len(labels) > 0 { + name = name + " [" + strings.Join(labels, ", ") + "]" + } + + name = tcEscape(name) + fmt.Fprintf(f, "##teamcity[testStarted name='%s']\n", name) + switch spec.State { + case types.SpecStatePending: + fmt.Fprintf(f, "##teamcity[testIgnored name='%s' message='pending']\n", name) + case types.SpecStateSkipped: + message := "skipped" + if spec.Failure.Message != "" { + message += " - " + spec.Failure.Message + } + fmt.Fprintf(f, "##teamcity[testIgnored name='%s' message='%s']\n", name, tcEscape(message)) + case types.SpecStateFailed: + details := failureDescriptionForUnstructuredReporters(spec) + fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='failed - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) + case types.SpecStatePanicked: + details := failureDescriptionForUnstructuredReporters(spec) + fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='panicked - %s' details='%s']\n", name, tcEscape(spec.Failure.ForwardedPanic), tcEscape(details)) + case types.SpecStateTimedout: + details := failureDescriptionForUnstructuredReporters(spec) + fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='timedout - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) + case types.SpecStateInterrupted: + details := failureDescriptionForUnstructuredReporters(spec) + fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='interrupted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) + case types.SpecStateAborted: + details := failureDescriptionForUnstructuredReporters(spec) + fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='aborted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) + } + + fmt.Fprintf(f, "##teamcity[testStdOut name='%s' 
out='%s']\n", name, tcEscape(systemOutForUnstructuredReporters(spec))) + fmt.Fprintf(f, "##teamcity[testStdErr name='%s' out='%s']\n", name, tcEscape(systemErrForUnstructuredReporters(spec))) + fmt.Fprintf(f, "##teamcity[testFinished name='%s' duration='%d']\n", name, int(spec.RunTime.Seconds()*1000.0)) + } + fmt.Fprintf(f, "##teamcity[testSuiteFinished name='%s']\n", tcEscape(report.SuiteDescription)) + + return f.Close() +} + +func MergeAndCleanupTeamcityReports(sources []string, dst string) ([]string, error) { + messages := []string{} + merged := []byte{} + for _, source := range sources { + data, err := os.ReadFile(source) + if err != nil { + messages = append(messages, fmt.Sprintf("Could not open %s:\n%s", source, err.Error())) + continue + } + os.Remove(source) + merged = append(merged, data...) + } + return messages, os.WriteFile(dst, merged, 0666) +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/code_location.go b/vendor/github.com/onsi/ginkgo/v2/types/code_location.go new file mode 100644 index 00000000000..e4e9e38c6e4 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/code_location.go @@ -0,0 +1,91 @@ +package types +import ( + "fmt" + "os" + "regexp" + "runtime" + "runtime/debug" + "strings" +) + +type CodeLocation struct { + FileName string `json:",omitempty"` + LineNumber int `json:",omitempty"` + FullStackTrace string `json:",omitempty"` + CustomMessage string `json:",omitempty"` +} + +func (codeLocation CodeLocation) String() string { + if codeLocation.CustomMessage != "" { + return codeLocation.CustomMessage + } + return fmt.Sprintf("%s:%d", codeLocation.FileName, codeLocation.LineNumber) +} + +func (codeLocation CodeLocation) ContentsOfLine() string { + if codeLocation.CustomMessage != "" { + return "" + } + contents, err := os.ReadFile(codeLocation.FileName) + if err != nil { + return "" + } + lines := strings.Split(string(contents), "\n") + if len(lines) < codeLocation.LineNumber { + return "" + } + return lines[codeLocation.LineNumber-1] +} + +func NewCustomCodeLocation(message string) CodeLocation { + return CodeLocation{ + CustomMessage: message, + } +} + +func NewCodeLocation(skip int) CodeLocation { + _, file, line, _ := runtime.Caller(skip + 1) + return CodeLocation{FileName: file, LineNumber: line} +} + +func NewCodeLocationWithStackTrace(skip int) CodeLocation { + _, file, line, _ := runtime.Caller(skip + 1) + stackTrace := PruneStack(string(debug.Stack()), skip+1) + return CodeLocation{FileName: file, LineNumber: line, FullStackTrace: stackTrace} +} + +// PruneStack removes references to functions that are internal to Ginkgo +// and the Go runtime from a stack string and a certain number of stack entries +// at the beginning of the stack. The stack string has the format +// as returned by runtime/debug.Stack. The leading goroutine information is +// optional and always removed if present. Beware that runtime/debug.Stack +// adds itself as first entry, so typically skip must be >= 1 to remove that +// entry. +func PruneStack(fullStackTrace string, skip int) string { + stack := strings.Split(fullStackTrace, "\n") + // Ensure that the even entries are the method names and the + // odd entries the source code information. + if len(stack) > 0 && strings.HasPrefix(stack[0], "goroutine ") { + // Ignore "goroutine 29 [running]:" line. + stack = stack[1:] + } + // The "+1" is for skipping over the initial entry, which is + // runtime/debug.Stack() itself. 
+ if len(stack) > 2*(skip+1) { + stack = stack[2*(skip+1):] + } + prunedStack := []string{} + if os.Getenv("GINKGO_PRUNE_STACK") == "FALSE" { + prunedStack = stack + } else { + re := regexp.MustCompile(`\/ginkgo\/|\/pkg\/testing\/|\/pkg\/runtime\/`) + for i := 0; i < len(stack)/2; i++ { + // We filter out based on the source code file name. + if !re.Match([]byte(stack[i*2+1])) { + prunedStack = append(prunedStack, stack[i*2]) + prunedStack = append(prunedStack, stack[i*2+1]) + } + } + } + return strings.Join(prunedStack, "\n") +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/config.go b/vendor/github.com/onsi/ginkgo/v2/types/config.go new file mode 100644 index 00000000000..f016c5c1ffc --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/config.go @@ -0,0 +1,740 @@ +/* +Ginkgo accepts a number of configuration options. +These are documented [here](http://onsi.github.io/ginkgo/#the-ginkgo-cli) +*/ + +package types + +import ( + "flag" + "os" + "runtime" + "strconv" + "strings" + "time" +) + +// Configuration controlling how an individual test suite is run +type SuiteConfig struct { + RandomSeed int64 + RandomizeAllSpecs bool + FocusStrings []string + SkipStrings []string + FocusFiles []string + SkipFiles []string + LabelFilter string + FailOnPending bool + FailFast bool + FlakeAttempts int + EmitSpecProgress bool + DryRun bool + PollProgressAfter time.Duration + PollProgressInterval time.Duration + Timeout time.Duration + OutputInterceptorMode string + SourceRoots []string + GracePeriod time.Duration + + ParallelProcess int + ParallelTotal int + ParallelHost string +} + +func NewDefaultSuiteConfig() SuiteConfig { + return SuiteConfig{ + RandomSeed: time.Now().Unix(), + Timeout: time.Hour, + ParallelProcess: 1, + ParallelTotal: 1, + GracePeriod: 30 * time.Second, + } +} + +type VerbosityLevel uint + +const ( + VerbosityLevelSuccinct VerbosityLevel = iota + VerbosityLevelNormal + VerbosityLevelVerbose + VerbosityLevelVeryVerbose +) + +func (vl VerbosityLevel) GT(comp VerbosityLevel) bool { + return vl > comp +} + +func (vl VerbosityLevel) GTE(comp VerbosityLevel) bool { + return vl >= comp +} + +func (vl VerbosityLevel) Is(comp VerbosityLevel) bool { + return vl == comp +} + +func (vl VerbosityLevel) LTE(comp VerbosityLevel) bool { + return vl <= comp +} + +func (vl VerbosityLevel) LT(comp VerbosityLevel) bool { + return vl < comp +} + +// Configuration for Ginkgo's reporter +type ReporterConfig struct { + NoColor bool + SlowSpecThreshold time.Duration + Succinct bool + Verbose bool + VeryVerbose bool + FullTrace bool + AlwaysEmitGinkgoWriter bool + + JSONReport string + JUnitReport string + TeamcityReport string +} + +func (rc ReporterConfig) Verbosity() VerbosityLevel { + if rc.Succinct { + return VerbosityLevelSuccinct + } else if rc.Verbose { + return VerbosityLevelVerbose + } else if rc.VeryVerbose { + return VerbosityLevelVeryVerbose + } + return VerbosityLevelNormal +} + +func (rc ReporterConfig) WillGenerateReport() bool { + return rc.JSONReport != "" || rc.JUnitReport != "" || rc.TeamcityReport != "" +} + +func NewDefaultReporterConfig() ReporterConfig { + return ReporterConfig{ + SlowSpecThreshold: 5 * time.Second, + } +} + +// Configuration for the Ginkgo CLI +type CLIConfig struct { + //for build, run, and watch + Recurse bool + SkipPackage string + RequireSuite bool + NumCompilers int + + //for run and watch only + Procs int + Parallel bool + AfterRunHook string + OutputDir string + KeepSeparateCoverprofiles bool + KeepSeparateReports bool + + //for run only + 
KeepGoing bool + UntilItFails bool + Repeat int + RandomizeSuites bool + + //for watch only + Depth int + WatchRegExp string +} + +func NewDefaultCLIConfig() CLIConfig { + return CLIConfig{ + Depth: 1, + WatchRegExp: `\.go$`, + } +} + +func (g CLIConfig) ComputedProcs() int { + if g.Procs > 0 { + return g.Procs + } + + n := 1 + if g.Parallel { + n = runtime.NumCPU() + if n > 4 { + n = n - 1 + } + } + return n +} + +func (g CLIConfig) ComputedNumCompilers() int { + if g.NumCompilers > 0 { + return g.NumCompilers + } + + return runtime.NumCPU() +} + +// Configuration for the Ginkgo CLI capturing available go flags +// A subset of Go flags are exposed by Ginkgo. Some are available at compile time (e.g. ginkgo build) and others only at run time (e.g. ginkgo run - which has both build and run time flags). +// More details can be found at: +// https://docs.google.com/spreadsheets/d/1zkp-DS4hU4sAJl5eHh1UmgwxCPQhf3s5a8fbiOI8tJU/ +type GoFlagsConfig struct { + //build-time flags for code-and-performance analysis + Race bool + Cover bool + CoverMode string + CoverPkg string + Vet string + + //run-time flags for code-and-performance analysis + BlockProfile string + BlockProfileRate int + CoverProfile string + CPUProfile string + MemProfile string + MemProfileRate int + MutexProfile string + MutexProfileFraction int + Trace string + + //build-time flags for building + A bool + ASMFlags string + BuildMode string + Compiler string + GCCGoFlags string + GCFlags string + InstallSuffix string + LDFlags string + LinkShared bool + Mod string + N bool + ModFile string + ModCacheRW bool + MSan bool + PkgDir string + Tags string + TrimPath bool + ToolExec string + Work bool + X bool +} + +func NewDefaultGoFlagsConfig() GoFlagsConfig { + return GoFlagsConfig{} +} + +func (g GoFlagsConfig) BinaryMustBePreserved() bool { + return g.BlockProfile != "" || g.CPUProfile != "" || g.MemProfile != "" || g.MutexProfile != "" +} + +// Configuration that were deprecated in 2.0 +type deprecatedConfig struct { + DebugParallel bool + NoisySkippings bool + NoisyPendings bool + RegexScansFilePath bool + SlowSpecThresholdWithFLoatUnits float64 + Stream bool + Notify bool +} + +// Flags + +// Flags sections used by both the CLI and the Ginkgo test process +var FlagSections = GinkgoFlagSections{ + {Key: "multiple-suites", Style: "{{dark-green}}", Heading: "Running Multiple Test Suites"}, + {Key: "order", Style: "{{green}}", Heading: "Controlling Test Order"}, + {Key: "parallel", Style: "{{yellow}}", Heading: "Controlling Test Parallelism"}, + {Key: "low-level-parallel", Style: "{{yellow}}", Heading: "Controlling Test Parallelism", + Description: "These are set by the Ginkgo CLI, {{red}}{{bold}}do not set them manually{{/}} via go test.\nUse ginkgo -p or ginkgo -procs=N instead."}, + {Key: "filter", Style: "{{cyan}}", Heading: "Filtering Tests"}, + {Key: "failure", Style: "{{red}}", Heading: "Failure Handling"}, + {Key: "output", Style: "{{magenta}}", Heading: "Controlling Output Formatting"}, + {Key: "code-and-coverage-analysis", Style: "{{orange}}", Heading: "Code and Coverage Analysis"}, + {Key: "performance-analysis", Style: "{{coral}}", Heading: "Performance Analysis"}, + {Key: "debug", Style: "{{blue}}", Heading: "Debugging Tests", + Description: "In addition to these flags, Ginkgo supports a few debugging environment variables. To change the parallel server protocol set {{blue}}GINKGO_PARALLEL_PROTOCOL{{/}} to {{bold}}HTTP{{/}}. 
To avoid pruning callstacks set {{blue}}GINKGO_PRUNE_STACK{{/}} to {{bold}}FALSE{{/}}."}, + {Key: "watch", Style: "{{light-yellow}}", Heading: "Controlling Ginkgo Watch"}, + {Key: "misc", Style: "{{light-gray}}", Heading: "Miscellaneous"}, + {Key: "go-build", Style: "{{light-gray}}", Heading: "Go Build Flags", Succinct: true, + Description: "These flags are inherited from go build. Run {{bold}}ginkgo help build{{/}} for more detailed flag documentation."}, +} + +// SuiteConfigFlags provides flags for the Ginkgo test process, and CLI +var SuiteConfigFlags = GinkgoFlags{ + {KeyPath: "S.RandomSeed", Name: "seed", SectionKey: "order", UsageDefaultValue: "randomly generated by Ginkgo", + Usage: "The seed used to randomize the spec suite."}, + {KeyPath: "S.RandomizeAllSpecs", Name: "randomize-all", SectionKey: "order", DeprecatedName: "randomizeAllSpecs", DeprecatedDocLink: "changed-command-line-flags", + Usage: "If set, ginkgo will randomize all specs together. By default, ginkgo only randomizes the top level Describe, Context and When containers."}, + + {KeyPath: "S.FailOnPending", Name: "fail-on-pending", SectionKey: "failure", DeprecatedName: "failOnPending", DeprecatedDocLink: "changed-command-line-flags", + Usage: "If set, ginkgo will mark the test suite as failed if any specs are pending."}, + {KeyPath: "S.FailFast", Name: "fail-fast", SectionKey: "failure", DeprecatedName: "failFast", DeprecatedDocLink: "changed-command-line-flags", + Usage: "If set, ginkgo will stop running a test suite after a failure occurs."}, + {KeyPath: "S.FlakeAttempts", Name: "flake-attempts", SectionKey: "failure", UsageDefaultValue: "0 - failed tests are not retried", DeprecatedName: "flakeAttempts", DeprecatedDocLink: "changed-command-line-flags", + Usage: "Make up to this many attempts to run each spec. If any of the attempts succeed, the suite will not be failed."}, + + {KeyPath: "S.DryRun", Name: "dry-run", SectionKey: "debug", DeprecatedName: "dryRun", DeprecatedDocLink: "changed-command-line-flags", + Usage: "If set, ginkgo will walk the test hierarchy without actually running anything. Best paired with -v."}, + {KeyPath: "S.EmitSpecProgress", Name: "progress", SectionKey: "debug", + Usage: "If set, ginkgo will emit progress information as each spec runs to the GinkgoWriter."}, + {KeyPath: "S.PollProgressAfter", Name: "poll-progress-after", SectionKey: "debug", UsageDefaultValue: "0", + Usage: "Emit node progress reports periodically if node hasn't completed after this duration."}, + {KeyPath: "S.PollProgressInterval", Name: "poll-progress-interval", SectionKey: "debug", UsageDefaultValue: "10s", + Usage: "The rate at which to emit node progress reports after poll-progress-after has elapsed."}, + {KeyPath: "S.SourceRoots", Name: "source-root", SectionKey: "debug", + Usage: "The location to look for source code when generating progress reports. 
You can pass multiple --source-root flags."}, + {KeyPath: "S.Timeout", Name: "timeout", SectionKey: "debug", UsageDefaultValue: "1h", + Usage: "Test suite fails if it does not complete within the specified timeout."}, + {KeyPath: "S.GracePeriod", Name: "grace-period", SectionKey: "debug", UsageDefaultValue: "30s", + Usage: "When interrupted, Ginkgo will wait for GracePeriod for the current running node to exit before moving on to the next one."}, + {KeyPath: "S.OutputInterceptorMode", Name: "output-interceptor-mode", SectionKey: "debug", UsageArgument: "dup, swap, or none", + Usage: "If set, ginkgo will use the specified output interception strategy when running in parallel. Defaults to dup on unix and swap on windows."}, + + {KeyPath: "S.LabelFilter", Name: "label-filter", SectionKey: "filter", UsageArgument: "expression", + Usage: "If set, ginkgo will only run specs with labels that match the label-filter. The passed-in expression can include boolean operations (!, &&, ||, ','), groupings via '()', and regular expressions '/regexp/'. e.g. '(cat || dog) && !fruit'"}, + {KeyPath: "S.FocusStrings", Name: "focus", SectionKey: "filter", + Usage: "If set, ginkgo will only run specs that match this regular expression. Can be specified multiple times, values are ORed."}, + {KeyPath: "S.SkipStrings", Name: "skip", SectionKey: "filter", + Usage: "If set, ginkgo will only run specs that do not match this regular expression. Can be specified multiple times, values are ORed."}, + {KeyPath: "S.FocusFiles", Name: "focus-file", SectionKey: "filter", UsageArgument: "file (regexp) | file:line | file:lineA-lineB | file:line,line,line", + Usage: "If set, ginkgo will only run specs in matching files. Can be specified multiple times, values are ORed."}, + {KeyPath: "S.SkipFiles", Name: "skip-file", SectionKey: "filter", UsageArgument: "file (regexp) | file:line | file:lineA-lineB | file:line,line,line", + Usage: "If set, ginkgo will skip specs in matching files. Can be specified multiple times, values are ORed."}, + + {KeyPath: "D.RegexScansFilePath", DeprecatedName: "regexScansFilePath", DeprecatedDocLink: "removed--regexscansfilepath", DeprecatedVersion: "2.0.0"}, + {KeyPath: "D.DebugParallel", DeprecatedName: "debug", DeprecatedDocLink: "removed--debug", DeprecatedVersion: "2.0.0"}, +} + +// ParallelConfigFlags provides flags for the Ginkgo test process (not the CLI) +var ParallelConfigFlags = GinkgoFlags{ + {KeyPath: "S.ParallelProcess", Name: "parallel.process", SectionKey: "low-level-parallel", UsageDefaultValue: "1", + Usage: "This worker process's (one-indexed) process number. For running specs in parallel."}, + {KeyPath: "S.ParallelTotal", Name: "parallel.total", SectionKey: "low-level-parallel", UsageDefaultValue: "1", + Usage: "The total number of worker processes. 
For running specs in parallel."}, + {KeyPath: "S.ParallelHost", Name: "parallel.host", SectionKey: "low-level-parallel", UsageDefaultValue: "set by Ginkgo CLI", + Usage: "The address for the server that will synchronize the processes."}, +} + +// ReporterConfigFlags provides flags for the Ginkgo test process, and CLI +var ReporterConfigFlags = GinkgoFlags{ + {KeyPath: "R.NoColor", Name: "no-color", SectionKey: "output", DeprecatedName: "noColor", DeprecatedDocLink: "changed-command-line-flags", + Usage: "If set, suppress color output in default reporter."}, + {KeyPath: "R.SlowSpecThreshold", Name: "slow-spec-threshold", SectionKey: "output", UsageArgument: "duration", UsageDefaultValue: "5s", + Usage: "Specs that take longer to run than this threshold are flagged as slow by the default reporter."}, + {KeyPath: "R.Verbose", Name: "v", SectionKey: "output", + Usage: "If set, emits more output including GinkgoWriter contents."}, + {KeyPath: "R.VeryVerbose", Name: "vv", SectionKey: "output", + Usage: "If set, emits with maximal verbosity - includes skipped and pending tests."}, + {KeyPath: "R.Succinct", Name: "succinct", SectionKey: "output", + Usage: "If set, default reporter prints out a very succinct report"}, + {KeyPath: "R.FullTrace", Name: "trace", SectionKey: "output", + Usage: "If set, default reporter prints out the full stack trace when a failure occurs"}, + {KeyPath: "R.AlwaysEmitGinkgoWriter", Name: "always-emit-ginkgo-writer", SectionKey: "output", DeprecatedName: "reportPassed", DeprecatedDocLink: "renamed--reportpassed", + Usage: "If set, default reporter prints out captured output of passed tests."}, + + {KeyPath: "R.JSONReport", Name: "json-report", UsageArgument: "filename.json", SectionKey: "output", + Usage: "If set, Ginkgo will generate a JSON-formatted test report at the specified location."}, + {KeyPath: "R.JUnitReport", Name: "junit-report", UsageArgument: "filename.xml", SectionKey: "output", DeprecatedName: "reportFile", DeprecatedDocLink: "improved-reporting-infrastructure", + Usage: "If set, Ginkgo will generate a conformant junit test report in the specified file."}, + {KeyPath: "R.TeamcityReport", Name: "teamcity-report", UsageArgument: "filename", SectionKey: "output", + Usage: "If set, Ginkgo will generate a Teamcity-formatted test report at the specified location."}, + + {KeyPath: "D.SlowSpecThresholdWithFLoatUnits", DeprecatedName: "slowSpecThreshold", DeprecatedDocLink: "changed--slowspecthreshold", + Usage: "use --slow-spec-threshold instead and pass in a duration string (e.g. '5s', not '5.0')"}, + {KeyPath: "D.NoisyPendings", DeprecatedName: "noisyPendings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"}, + {KeyPath: "D.NoisySkippings", DeprecatedName: "noisySkippings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"}, +} + +// BuildTestSuiteFlagSet attaches to the CommandLine flagset and provides flags for the Ginkgo test process +func BuildTestSuiteFlagSet(suiteConfig *SuiteConfig, reporterConfig *ReporterConfig) (GinkgoFlagSet, error) { + flags := SuiteConfigFlags.CopyAppend(ParallelConfigFlags...).CopyAppend(ReporterConfigFlags...) 
+ flags = flags.WithPrefix("ginkgo") + bindings := map[string]interface{}{ + "S": suiteConfig, + "R": reporterConfig, + "D": &deprecatedConfig{}, + } + extraGoFlagsSection := GinkgoFlagSection{Style: "{{gray}}", Heading: "Go test flags"} + + return NewAttachedGinkgoFlagSet(flag.CommandLine, flags, bindings, FlagSections, extraGoFlagsSection) +} + +// VetConfig validates that the Ginkgo test process' configuration is sound +func VetConfig(flagSet GinkgoFlagSet, suiteConfig SuiteConfig, reporterConfig ReporterConfig) []error { + errors := []error{} + + if flagSet.WasSet("count") || flagSet.WasSet("test.count") { + flag := flagSet.Lookup("count") + if flag == nil { + flag = flagSet.Lookup("test.count") + } + count, err := strconv.Atoi(flag.Value.String()) + if err != nil || count != 1 { + errors = append(errors, GinkgoErrors.InvalidGoFlagCount()) + } + } + + if flagSet.WasSet("parallel") || flagSet.WasSet("test.parallel") { + errors = append(errors, GinkgoErrors.InvalidGoFlagParallel()) + } + + if suiteConfig.ParallelTotal < 1 { + errors = append(errors, GinkgoErrors.InvalidParallelTotalConfiguration()) + } + + if suiteConfig.ParallelProcess > suiteConfig.ParallelTotal || suiteConfig.ParallelProcess < 1 { + errors = append(errors, GinkgoErrors.InvalidParallelProcessConfiguration()) + } + + if suiteConfig.ParallelTotal > 1 && suiteConfig.ParallelHost == "" { + errors = append(errors, GinkgoErrors.MissingParallelHostConfiguration()) + } + + if suiteConfig.DryRun && suiteConfig.ParallelTotal > 1 { + errors = append(errors, GinkgoErrors.DryRunInParallelConfiguration()) + } + + if suiteConfig.GracePeriod <= 0 { + errors = append(errors, GinkgoErrors.GracePeriodCannotBeZero()) + } + + if len(suiteConfig.FocusFiles) > 0 { + _, err := ParseFileFilters(suiteConfig.FocusFiles) + if err != nil { + errors = append(errors, err) + } + } + + if len(suiteConfig.SkipFiles) > 0 { + _, err := ParseFileFilters(suiteConfig.SkipFiles) + if err != nil { + errors = append(errors, err) + } + } + + if suiteConfig.LabelFilter != "" { + _, err := ParseLabelFilter(suiteConfig.LabelFilter) + if err != nil { + errors = append(errors, err) + } + } + + switch strings.ToLower(suiteConfig.OutputInterceptorMode) { + case "", "dup", "swap", "none": + default: + errors = append(errors, GinkgoErrors.InvalidOutputInterceptorModeConfiguration(suiteConfig.OutputInterceptorMode)) + } + + numVerbosity := 0 + for _, v := range []bool{reporterConfig.Succinct, reporterConfig.Verbose, reporterConfig.VeryVerbose} { + if v { + numVerbosity++ + } + } + if numVerbosity > 1 { + errors = append(errors, GinkgoErrors.ConflictingVerbosityConfiguration()) + } + + return errors +} + +// GinkgoCLISharedFlags provides flags shared by the Ginkgo CLI's build, watch, and run commands +var GinkgoCLISharedFlags = GinkgoFlags{ + {KeyPath: "C.Recurse", Name: "r", SectionKey: "multiple-suites", + Usage: "If set, ginkgo finds and runs test suites under the current directory recursively."}, + {KeyPath: "C.SkipPackage", Name: "skip-package", SectionKey: "multiple-suites", DeprecatedName: "skipPackage", DeprecatedDocLink: "changed-command-line-flags", + UsageArgument: "comma-separated list of packages", + Usage: "A comma-separated list of package names to be skipped. 
If any part of the package's path matches, that package is ignored."}, + {KeyPath: "C.RequireSuite", Name: "require-suite", SectionKey: "failure", DeprecatedName: "requireSuite", DeprecatedDocLink: "changed-command-line-flags", + Usage: "If set, Ginkgo fails if there are ginkgo tests in a directory but no invocation of RunSpecs."}, + {KeyPath: "C.NumCompilers", Name: "compilers", SectionKey: "multiple-suites", UsageDefaultValue: "0 (will autodetect)", + Usage: "When running multiple packages, the number of concurrent compilations to perform."}, +} + +// GinkgoCLIRunAndWatchFlags provides flags shared by the Ginkgo CLI's run and watch commands (but not the build command) +var GinkgoCLIRunAndWatchFlags = GinkgoFlags{ + {KeyPath: "C.Procs", Name: "procs", SectionKey: "parallel", UsageDefaultValue: "1 (run in series)", + Usage: "The number of parallel test nodes to run."}, + {KeyPath: "C.Procs", Name: "nodes", SectionKey: "parallel", UsageDefaultValue: "1 (run in series)", + Usage: "--nodes is an alias for --procs"}, + {KeyPath: "C.Parallel", Name: "p", SectionKey: "parallel", + Usage: "If set, ginkgo will run in parallel with an auto-detected number of nodes."}, + {KeyPath: "C.AfterRunHook", Name: "after-run-hook", SectionKey: "misc", DeprecatedName: "afterSuiteHook", DeprecatedDocLink: "changed-command-line-flags", + Usage: "Command to run when a test suite completes."}, + {KeyPath: "C.OutputDir", Name: "output-dir", SectionKey: "output", UsageArgument: "directory", DeprecatedName: "outputdir", DeprecatedDocLink: "improved-profiling-support", + Usage: "A location to place all generated profiles and reports."}, + {KeyPath: "C.KeepSeparateCoverprofiles", Name: "keep-separate-coverprofiles", SectionKey: "code-and-coverage-analysis", + Usage: "If set, Ginkgo does not merge coverprofiles into one monolithic coverprofile. The coverprofiles will remain in their respective package directories or in -output-dir if set."}, + {KeyPath: "C.KeepSeparateReports", Name: "keep-separate-reports", SectionKey: "output", + Usage: "If set, Ginkgo does not merge per-suite reports (e.g. -json-report) into one monolithic report for the entire testrun. The reports will remain in their respective package directories or in -output-dir if set."}, + + {KeyPath: "D.Stream", DeprecatedName: "stream", DeprecatedDocLink: "removed--stream", DeprecatedVersion: "2.0.0"}, + {KeyPath: "D.Notify", DeprecatedName: "notify", DeprecatedDocLink: "removed--notify", DeprecatedVersion: "2.0.0"}, +} + +// GinkgoCLIRunFlags provides flags for Ginkgo CLI's run command that aren't shared by any other commands +var GinkgoCLIRunFlags = GinkgoFlags{ + {KeyPath: "C.KeepGoing", Name: "keep-going", SectionKey: "multiple-suites", DeprecatedName: "keepGoing", DeprecatedDocLink: "changed-command-line-flags", + Usage: "If set, failures from earlier test suites do not prevent later test suites from running."}, + {KeyPath: "C.UntilItFails", Name: "until-it-fails", SectionKey: "debug", DeprecatedName: "untilItFails", DeprecatedDocLink: "changed-command-line-flags", + Usage: "If set, ginkgo will keep rerunning test suites until a failure occurs."}, + {KeyPath: "C.Repeat", Name: "repeat", SectionKey: "debug", UsageArgument: "n", UsageDefaultValue: "0 - i.e. no repetition, run only once", + Usage: "The number of times to re-run a test-suite. Useful for debugging flaky tests. 
If set to N the suite will be run N+1 times and will be required to pass each time."}, + {KeyPath: "C.RandomizeSuites", Name: "randomize-suites", SectionKey: "order", DeprecatedName: "randomizeSuites", DeprecatedDocLink: "changed-command-line-flags", + Usage: "If set, ginkgo will randomize the order in which test suites run."}, +} + +// GinkgoCLIRunFlags provides flags for Ginkgo CLI's watch command that aren't shared by any other commands +var GinkgoCLIWatchFlags = GinkgoFlags{ + {KeyPath: "C.Depth", Name: "depth", SectionKey: "watch", + Usage: "Ginkgo will watch dependencies down to this depth in the dependency tree."}, + {KeyPath: "C.WatchRegExp", Name: "watch-regexp", SectionKey: "watch", DeprecatedName: "watchRegExp", DeprecatedDocLink: "changed-command-line-flags", + UsageArgument: "Regular Expression", + UsageDefaultValue: `\.go$`, + Usage: "Only files matching this regular expression will be watched for changes."}, +} + +// GoBuildFlags provides flags for the Ginkgo CLI build, run, and watch commands that capture go's build-time flags. These are passed to go test -c by the ginkgo CLI +var GoBuildFlags = GinkgoFlags{ + {KeyPath: "Go.Race", Name: "race", SectionKey: "code-and-coverage-analysis", + Usage: "enable data race detection. Supported only on linux/amd64, freebsd/amd64, darwin/amd64, windows/amd64, linux/ppc64le and linux/arm64 (only for 48-bit VMA)."}, + {KeyPath: "Go.Vet", Name: "vet", UsageArgument: "list", SectionKey: "code-and-coverage-analysis", + Usage: `Configure the invocation of "go vet" during "go test" to use the comma-separated list of vet checks. If list is empty, "go test" runs "go vet" with a curated list of checks believed to be always worth addressing. If list is "off", "go test" does not run "go vet" at all. Available checks can be found by running 'go doc cmd/vet'`}, + {KeyPath: "Go.Cover", Name: "cover", SectionKey: "code-and-coverage-analysis", + Usage: "Enable coverage analysis. Note that because coverage works by annotating the source code before compilation, compilation and test failures with coverage enabled may report line numbers that don't correspond to the original sources."}, + {KeyPath: "Go.CoverMode", Name: "covermode", UsageArgument: "set,count,atomic", SectionKey: "code-and-coverage-analysis", + Usage: `Set the mode for coverage analysis for the package[s] being tested. 'set': does this statement run? 'count': how many times does this statement run? 'atomic': like count, but correct in multithreaded tests and more expensive (must use atomic with -race). Sets -cover`}, + {KeyPath: "Go.CoverPkg", Name: "coverpkg", UsageArgument: "pattern1,pattern2,pattern3", SectionKey: "code-and-coverage-analysis", + Usage: "Apply coverage analysis in each test to packages matching the patterns. The default is for each test to analyze only the package being tested. See 'go help packages' for a description of package patterns. Sets -cover."}, + + {KeyPath: "Go.A", Name: "a", SectionKey: "go-build", + Usage: "force rebuilding of packages that are already up-to-date."}, + {KeyPath: "Go.ASMFlags", Name: "asmflags", UsageArgument: "'[pattern=]arg list'", SectionKey: "go-build", + Usage: "arguments to pass on each go tool asm invocation."}, + {KeyPath: "Go.BuildMode", Name: "buildmode", UsageArgument: "mode", SectionKey: "go-build", + Usage: "build mode to use. 
See 'go help buildmode' for more."}, + {KeyPath: "Go.Compiler", Name: "compiler", UsageArgument: "name", SectionKey: "go-build", + Usage: "name of compiler to use, as in runtime.Compiler (gccgo or gc)."}, + {KeyPath: "Go.GCCGoFlags", Name: "gccgoflags", UsageArgument: "'[pattern=]arg list'", SectionKey: "go-build", + Usage: "arguments to pass on each gccgo compiler/linker invocation."}, + {KeyPath: "Go.GCFlags", Name: "gcflags", UsageArgument: "'[pattern=]arg list'", SectionKey: "go-build", + Usage: "arguments to pass on each go tool compile invocation."}, + {KeyPath: "Go.InstallSuffix", Name: "installsuffix", SectionKey: "go-build", + Usage: "a suffix to use in the name of the package installation directory, in order to keep output separate from default builds. If using the -race flag, the install suffix is automatically set to race or, if set explicitly, has _race appended to it. Likewise for the -msan flag. Using a -buildmode option that requires non-default compile flags has a similar effect."}, + {KeyPath: "Go.LDFlags", Name: "ldflags", UsageArgument: "'[pattern=]arg list'", SectionKey: "go-build", + Usage: "arguments to pass on each go tool link invocation."}, + {KeyPath: "Go.LinkShared", Name: "linkshared", SectionKey: "go-build", + Usage: "build code that will be linked against shared libraries previously created with -buildmode=shared."}, + {KeyPath: "Go.Mod", Name: "mod", UsageArgument: "mode (readonly, vendor, or mod)", SectionKey: "go-build", + Usage: "module download mode to use: readonly, vendor, or mod. See 'go help modules' for more."}, + {KeyPath: "Go.ModCacheRW", Name: "modcacherw", SectionKey: "go-build", + Usage: "leave newly-created directories in the module cache read-write instead of making them read-only."}, + {KeyPath: "Go.ModFile", Name: "modfile", UsageArgument: "file", SectionKey: "go-build", + Usage: `in module aware mode, read (and possibly write) an alternate go.mod file instead of the one in the module root directory. A file named go.mod must still be present in order to determine the module root directory, but it is not accessed. When -modfile is specified, an alternate go.sum file is also used: its path is derived from the -modfile flag by trimming the ".mod" extension and appending ".sum".`}, + {KeyPath: "Go.MSan", Name: "msan", SectionKey: "go-build", + Usage: "enable interoperation with memory sanitizer. Supported only on linux/amd64, linux/arm64 and only with Clang/LLVM as the host C compiler. On linux/arm64, pie build mode will be used."}, + {KeyPath: "Go.N", Name: "n", SectionKey: "go-build", + Usage: "print the commands but do not run them."}, + {KeyPath: "Go.PkgDir", Name: "pkgdir", UsageArgument: "dir", SectionKey: "go-build", + Usage: "install and load all packages from dir instead of the usual locations. For example, when building with a non-standard configuration, use -pkgdir to keep generated packages in a separate location."}, + {KeyPath: "Go.Tags", Name: "tags", UsageArgument: "tag,list", SectionKey: "go-build", + Usage: "a comma-separated list of build tags to consider satisfied during the build. For more information about build tags, see the description of build constraints in the documentation for the go/build package. (Earlier versions of Go used a space-separated list, and that form is deprecated but still recognized.)"}, + {KeyPath: "Go.TrimPath", Name: "trimpath", SectionKey: "go-build", + Usage: `remove all file system paths from the resulting executable. 
Instead of absolute file system paths, the recorded file names will begin with either "go" (for the standard library), or a module path@version (when using modules), or a plain import path (when using GOPATH).`}, + {KeyPath: "Go.ToolExec", Name: "toolexec", UsageArgument: "'cmd args'", SectionKey: "go-build", + Usage: "a program to use to invoke toolchain programs like vet and asm. For example, instead of running asm, the go command will run cmd args /path/to/asm '."}, + {KeyPath: "Go.Work", Name: "work", SectionKey: "go-build", + Usage: "print the name of the temporary work directory and do not delete it when exiting."}, + {KeyPath: "Go.X", Name: "x", SectionKey: "go-build", + Usage: "print the commands."}, +} + +// GoRunFlags provides flags for the Ginkgo CLI run, and watch commands that capture go's run-time flags. These are passed to the compiled test binary by the ginkgo CLI +var GoRunFlags = GinkgoFlags{ + {KeyPath: "Go.CoverProfile", Name: "coverprofile", UsageArgument: "file", SectionKey: "code-and-coverage-analysis", + Usage: `Write a coverage profile to the file after all tests have passed. Sets -cover.`}, + {KeyPath: "Go.BlockProfile", Name: "blockprofile", UsageArgument: "file", SectionKey: "performance-analysis", + Usage: `Write a goroutine blocking profile to the specified file when all tests are complete. Preserves test binary.`}, + {KeyPath: "Go.BlockProfileRate", Name: "blockprofilerate", UsageArgument: "rate", SectionKey: "performance-analysis", + Usage: `Control the detail provided in goroutine blocking profiles by calling runtime.SetBlockProfileRate with rate. See 'go doc runtime.SetBlockProfileRate'. The profiler aims to sample, on average, one blocking event every n nanoseconds the program spends blocked. By default, if -test.blockprofile is set without this flag, all blocking events are recorded, equivalent to -test.blockprofilerate=1.`}, + {KeyPath: "Go.CPUProfile", Name: "cpuprofile", UsageArgument: "file", SectionKey: "performance-analysis", + Usage: `Write a CPU profile to the specified file before exiting. Preserves test binary.`}, + {KeyPath: "Go.MemProfile", Name: "memprofile", UsageArgument: "file", SectionKey: "performance-analysis", + Usage: `Write an allocation profile to the file after all tests have passed. Preserves test binary.`}, + {KeyPath: "Go.MemProfileRate", Name: "memprofilerate", UsageArgument: "rate", SectionKey: "performance-analysis", + Usage: `Enable more precise (and expensive) memory allocation profiles by setting runtime.MemProfileRate. See 'go doc runtime.MemProfileRate'. To profile all memory allocations, use -test.memprofilerate=1.`}, + {KeyPath: "Go.MutexProfile", Name: "mutexprofile", UsageArgument: "file", SectionKey: "performance-analysis", + Usage: `Write a mutex contention profile to the specified file when all tests are complete. 
Preserves test binary.`}, + {KeyPath: "Go.MutexProfileFraction", Name: "mutexprofilefraction", UsageArgument: "n", SectionKey: "performance-analysis", + Usage: `if >= 0, calls runtime.SetMutexProfileFraction() Sample 1 in n stack traces of goroutines holding a contended mutex.`}, + {KeyPath: "Go.Trace", Name: "execution-trace", UsageArgument: "file", ExportAs: "trace", SectionKey: "performance-analysis", + Usage: `Write an execution trace to the specified file before exiting.`}, +} + +// VetAndInitializeCLIAndGoConfig validates that the Ginkgo CLI's configuration is sound +// It returns a potentially mutated copy of the config that rationalizes the configuration to ensure consistency for downstream consumers +func VetAndInitializeCLIAndGoConfig(cliConfig CLIConfig, goFlagsConfig GoFlagsConfig) (CLIConfig, GoFlagsConfig, []error) { + errors := []error{} + + if cliConfig.Repeat > 0 && cliConfig.UntilItFails { + errors = append(errors, GinkgoErrors.BothRepeatAndUntilItFails()) + } + + //initialize the output directory + if cliConfig.OutputDir != "" { + err := os.MkdirAll(cliConfig.OutputDir, 0777) + if err != nil { + errors = append(errors, err) + } + } + + //ensure cover mode is configured appropriately + if goFlagsConfig.CoverMode != "" || goFlagsConfig.CoverPkg != "" || goFlagsConfig.CoverProfile != "" { + goFlagsConfig.Cover = true + } + if goFlagsConfig.Cover && goFlagsConfig.CoverProfile == "" { + goFlagsConfig.CoverProfile = "coverprofile.out" + } + + return cliConfig, goFlagsConfig, errors +} + +// GenerateGoTestCompileArgs is used by the Ginkgo CLI to generate command line arguments to pass to the go test -c command when compiling the test +func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string) ([]string, error) { + // if the user has set the CoverProfile run-time flag make sure to set the build-time cover flag to make sure + // the built test binary can generate a coverprofile + if goFlagsConfig.CoverProfile != "" { + goFlagsConfig.Cover = true + } + + args := []string{"test", "-c", "-o", destination, packageToBuild} + goArgs, err := GenerateFlagArgs( + GoBuildFlags, + map[string]interface{}{ + "Go": &goFlagsConfig, + }, + ) + + if err != nil { + return []string{}, err + } + args = append(args, goArgs...) + return args, nil +} + +// GenerateGinkgoTestRunArgs is used by the Ginkgo CLI to generate command line arguments to pass to the compiled Ginkgo test binary +func GenerateGinkgoTestRunArgs(suiteConfig SuiteConfig, reporterConfig ReporterConfig, goFlagsConfig GoFlagsConfig) ([]string, error) { + var flags GinkgoFlags + flags = SuiteConfigFlags.WithPrefix("ginkgo") + flags = flags.CopyAppend(ParallelConfigFlags.WithPrefix("ginkgo")...) + flags = flags.CopyAppend(ReporterConfigFlags.WithPrefix("ginkgo")...) + flags = flags.CopyAppend(GoRunFlags.WithPrefix("test")...) 
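+ // ginkgo-specific flags carry the "ginkgo" prefix, while go's run-time flags keep the "test" prefix the compiled test binary expects (e.g. --test.coverprofile)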
+ bindings := map[string]interface{}{ + "S": &suiteConfig, + "R": &reporterConfig, + "Go": &goFlagsConfig, + } + + return GenerateFlagArgs(flags, bindings) +} + +// GenerateGoTestRunArgs is used by the Ginkgo CLI to generate command line arguments to pass to the compiled non-Ginkgo test binary +func GenerateGoTestRunArgs(goFlagsConfig GoFlagsConfig) ([]string, error) { + flags := GoRunFlags.WithPrefix("test") + bindings := map[string]interface{}{ + "Go": &goFlagsConfig, + } + + args, err := GenerateFlagArgs(flags, bindings) + if err != nil { + return args, err + } + args = append(args, "--test.v") + return args, nil +} + +// BuildRunCommandFlagSet builds the FlagSet for the `ginkgo run` command +func BuildRunCommandFlagSet(suiteConfig *SuiteConfig, reporterConfig *ReporterConfig, cliConfig *CLIConfig, goFlagsConfig *GoFlagsConfig) (GinkgoFlagSet, error) { + flags := SuiteConfigFlags + flags = flags.CopyAppend(ReporterConfigFlags...) + flags = flags.CopyAppend(GinkgoCLISharedFlags...) + flags = flags.CopyAppend(GinkgoCLIRunAndWatchFlags...) + flags = flags.CopyAppend(GinkgoCLIRunFlags...) + flags = flags.CopyAppend(GoBuildFlags...) + flags = flags.CopyAppend(GoRunFlags...) + + bindings := map[string]interface{}{ + "S": suiteConfig, + "R": reporterConfig, + "C": cliConfig, + "Go": goFlagsConfig, + "D": &deprecatedConfig{}, + } + + return NewGinkgoFlagSet(flags, bindings, FlagSections) +} + +// BuildWatchCommandFlagSet builds the FlagSet for the `ginkgo watch` command +func BuildWatchCommandFlagSet(suiteConfig *SuiteConfig, reporterConfig *ReporterConfig, cliConfig *CLIConfig, goFlagsConfig *GoFlagsConfig) (GinkgoFlagSet, error) { + flags := SuiteConfigFlags + flags = flags.CopyAppend(ReporterConfigFlags...) + flags = flags.CopyAppend(GinkgoCLISharedFlags...) + flags = flags.CopyAppend(GinkgoCLIRunAndWatchFlags...) + flags = flags.CopyAppend(GinkgoCLIWatchFlags...) + flags = flags.CopyAppend(GoBuildFlags...) + flags = flags.CopyAppend(GoRunFlags...) + + bindings := map[string]interface{}{ + "S": suiteConfig, + "R": reporterConfig, + "C": cliConfig, + "Go": goFlagsConfig, + "D": &deprecatedConfig{}, + } + + return NewGinkgoFlagSet(flags, bindings, FlagSections) +} + +// BuildBuildCommandFlagSet builds the FlagSet for the `ginkgo build` command +func BuildBuildCommandFlagSet(cliConfig *CLIConfig, goFlagsConfig *GoFlagsConfig) (GinkgoFlagSet, error) { + flags := GinkgoCLISharedFlags + flags = flags.CopyAppend(GoBuildFlags...) 
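+ // ginkgo build only needs the shared CLI flags plus go's build-time flags; run-time flags are omitted because the build command never executes the compiled test binary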
+ + bindings := map[string]interface{}{ + "C": cliConfig, + "Go": goFlagsConfig, + "D": &deprecatedConfig{}, + } + + flagSections := make(GinkgoFlagSections, len(FlagSections)) + copy(flagSections, FlagSections) + for i := range flagSections { + if flagSections[i].Key == "multiple-suites" { + flagSections[i].Heading = "Building Multiple Suites" + } + if flagSections[i].Key == "go-build" { + flagSections[i] = GinkgoFlagSection{Key: "go-build", Style: "{{/}}", Heading: "Go Build Flags", + Description: "These flags are inherited from go build."} + } + } + + return NewGinkgoFlagSet(flags, bindings, flagSections) +} + +func BuildLabelsCommandFlagSet(cliConfig *CLIConfig) (GinkgoFlagSet, error) { + flags := GinkgoCLISharedFlags.SubsetWithNames("r", "skip-package") + + bindings := map[string]interface{}{ + "C": cliConfig, + } + + flagSections := make(GinkgoFlagSections, len(FlagSections)) + copy(flagSections, FlagSections) + for i := range flagSections { + if flagSections[i].Key == "multiple-suites" { + flagSections[i].Heading = "Fetching Labels from Multiple Suites" + } + } + + return NewGinkgoFlagSet(flags, bindings, flagSections) +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/deprecated_types.go b/vendor/github.com/onsi/ginkgo/v2/types/deprecated_types.go new file mode 100644 index 00000000000..17922304b63 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/deprecated_types.go @@ -0,0 +1,141 @@ +package types + +import ( + "strconv" + "time" +) + +/* + A set of deprecations to make the transition from v1 to v2 easier for users who have written custom reporters. +*/ + +type SuiteSummary = DeprecatedSuiteSummary +type SetupSummary = DeprecatedSetupSummary +type SpecSummary = DeprecatedSpecSummary +type SpecMeasurement = DeprecatedSpecMeasurement +type SpecComponentType = NodeType +type SpecFailure = DeprecatedSpecFailure + +var ( + SpecComponentTypeInvalid = NodeTypeInvalid + SpecComponentTypeContainer = NodeTypeContainer + SpecComponentTypeIt = NodeTypeIt + SpecComponentTypeBeforeEach = NodeTypeBeforeEach + SpecComponentTypeJustBeforeEach = NodeTypeJustBeforeEach + SpecComponentTypeAfterEach = NodeTypeAfterEach + SpecComponentTypeJustAfterEach = NodeTypeJustAfterEach + SpecComponentTypeBeforeSuite = NodeTypeBeforeSuite + SpecComponentTypeSynchronizedBeforeSuite = NodeTypeSynchronizedBeforeSuite + SpecComponentTypeAfterSuite = NodeTypeAfterSuite + SpecComponentTypeSynchronizedAfterSuite = NodeTypeSynchronizedAfterSuite +) + +type DeprecatedSuiteSummary struct { + SuiteDescription string + SuiteSucceeded bool + SuiteID string + + NumberOfSpecsBeforeParallelization int + NumberOfTotalSpecs int + NumberOfSpecsThatWillBeRun int + NumberOfPendingSpecs int + NumberOfSkippedSpecs int + NumberOfPassedSpecs int + NumberOfFailedSpecs int + NumberOfFlakedSpecs int + RunTime time.Duration +} + +type DeprecatedSetupSummary struct { + ComponentType SpecComponentType + CodeLocation CodeLocation + + State SpecState + RunTime time.Duration + Failure SpecFailure + + CapturedOutput string + SuiteID string +} + +type DeprecatedSpecSummary struct { + ComponentTexts []string + ComponentCodeLocations []CodeLocation + + State SpecState + RunTime time.Duration + Failure SpecFailure + IsMeasurement bool + NumberOfSamples int + Measurements map[string]*DeprecatedSpecMeasurement + + CapturedOutput string + SuiteID string +} + +func (s DeprecatedSpecSummary) HasFailureState() bool { + return s.State.Is(SpecStateFailureStates) +} + +func (s DeprecatedSpecSummary) TimedOut() bool { + return false +} + +func (s 
DeprecatedSpecSummary) Panicked() bool { + return s.State == SpecStatePanicked +} + +func (s DeprecatedSpecSummary) Failed() bool { + return s.State == SpecStateFailed +} + +func (s DeprecatedSpecSummary) Passed() bool { + return s.State == SpecStatePassed +} + +func (s DeprecatedSpecSummary) Skipped() bool { + return s.State == SpecStateSkipped +} + +func (s DeprecatedSpecSummary) Pending() bool { + return s.State == SpecStatePending +} + +type DeprecatedSpecFailure struct { + Message string + Location CodeLocation + ForwardedPanic string + + ComponentIndex int + ComponentType SpecComponentType + ComponentCodeLocation CodeLocation +} + +type DeprecatedSpecMeasurement struct { + Name string + Info interface{} + Order int + + Results []float64 + + Smallest float64 + Largest float64 + Average float64 + StdDeviation float64 + + SmallestLabel string + LargestLabel string + AverageLabel string + Units string + Precision int +} + +func (s DeprecatedSpecMeasurement) PrecisionFmt() string { + if s.Precision == 0 { + return "%f" + } + + str := strconv.Itoa(s.Precision) + + return "%." + str + "f" +} diff --git a/vendor/github.com/onsi/ginkgo/types/deprecation_support.go b/vendor/github.com/onsi/ginkgo/v2/types/deprecation_support.go similarity index 66% rename from vendor/github.com/onsi/ginkgo/types/deprecation_support.go rename to vendor/github.com/onsi/ginkgo/v2/types/deprecation_support.go index d5a6658f35f..2948dfa0c9d 100644 --- a/vendor/github.com/onsi/ginkgo/types/deprecation_support.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/deprecation_support.go @@ -4,10 +4,10 @@ import ( "os" "strconv" "strings" + "sync" "unicode" - "github.com/onsi/ginkgo/config" - "github.com/onsi/ginkgo/formatter" + "github.com/onsi/ginkgo/v2/formatter" ) type Deprecation struct { @@ -22,20 +22,12 @@ var Deprecations = deprecations{} func (d deprecations) CustomReporter() Deprecation { return Deprecation{ - Message: "You are using a custom reporter. Support for custom reporters will likely be removed in V2. Most users were using them to generate junit or teamcity reports and this functionality will be merged into the core reporter. In addition, Ginkgo 2.0 will support emitting a JSON-formatted report that users can then manipulate to generate custom reports.\n\n{{red}}{{bold}}If this change will be impactful to you please leave a comment on {{cyan}}{{underline}}https://github.com/onsi/ginkgo/issues/711{{/}}", + Message: "Support for custom reporters has been removed in V2. Please read the documentation linked to below for Ginkgo's new behavior and for a migration path:", DocLink: "removed-custom-reporters", Version: "1.16.0", } } -func (d deprecations) V1Reporter() Deprecation { - return Deprecation{ - Message: "You are using a V1 Ginkgo Reporter. Please update your custom reporter to the new V2 Reporter interface.", - DocLink: "changed-reporter-interface", - Version: "1.16.0", - } -} - func (d deprecations) Async() Deprecation { return Deprecation{ Message: "You are passing a Done channel to a test node to test asynchronous behavior. This is deprecated in Ginkgo V2. Your test will run synchronously and the timeout will be ignored.", @@ -56,7 +48,15 @@ func (d deprecations) ParallelNode() Deprecation { return Deprecation{ Message: "GinkgoParallelNode is deprecated and will be removed in Ginkgo V2. 
Please use GinkgoParallelProcess instead.", DocLink: "renamed-ginkgoparallelnode", - Version: "1.16.5", + Version: "1.16.4", + } +} + +func (d deprecations) CurrentGinkgoTestDescription() Deprecation { + return Deprecation{ + Message: "CurrentGinkgoTestDescription() is deprecated in Ginkgo V2. Use CurrentSpecReport() instead.", + DocLink: "changed-currentginkgotestdescription", + Version: "1.16.0", } } @@ -75,13 +75,23 @@ func (d deprecations) Blur() Deprecation { } } +func (d deprecations) Nodot() Deprecation { + return Deprecation{ + Message: "The nodot command is deprecated in Ginkgo V2. Please either dot-import Ginkgo or use the package identifier in your code to references objects and types provided by Ginkgo and Gomega.", + DocLink: "removed-ginkgo-nodot", + Version: "1.16.0", + } +} + type DeprecationTracker struct { deprecations map[Deprecation][]CodeLocation + lock *sync.Mutex } func NewDeprecationTracker() *DeprecationTracker { return &DeprecationTracker{ deprecations: map[Deprecation][]CodeLocation{}, + lock: &sync.Mutex{}, } } @@ -95,6 +105,8 @@ func (d *DeprecationTracker) TrackDeprecation(deprecation Deprecation, cl ...Cod } } + d.lock.Lock() + defer d.lock.Unlock() if len(cl) == 1 { d.deprecations[deprecation] = append(d.deprecations[deprecation], cl[0]) } else { @@ -103,29 +115,27 @@ func (d *DeprecationTracker) TrackDeprecation(deprecation Deprecation, cl ...Cod } func (d *DeprecationTracker) DidTrackDeprecations() bool { + d.lock.Lock() + defer d.lock.Unlock() return len(d.deprecations) > 0 } func (d *DeprecationTracker) DeprecationsReport() string { - out := formatter.F("\n{{light-yellow}}You're using deprecated Ginkgo functionality:{{/}}\n") + d.lock.Lock() + defer d.lock.Unlock() + out := formatter.F("{{light-yellow}}You're using deprecated Ginkgo functionality:{{/}}\n") out += formatter.F("{{light-yellow}}============================================={{/}}\n") - out += formatter.F("{{bold}}{{green}}Ginkgo 2.0{{/}} is under active development and will introduce several new features, improvements, and a small handful of breaking changes.\n") - out += formatter.F("A release candidate for 2.0 is now available and 2.0 should GA in Fall 2021. 
{{bold}}Please give the RC a try and send us feedback!{{/}}\n") - out += formatter.F(" - To learn more, view the migration guide at {{cyan}}{{underline}}https://github.com/onsi/ginkgo/blob/ver2/docs/MIGRATING_TO_V2.md{{/}}\n") - out += formatter.F(" - For instructions on using the Release Candidate visit {{cyan}}{{underline}}https://github.com/onsi/ginkgo/blob/ver2/docs/MIGRATING_TO_V2.md#using-the-beta{{/}}\n") - out += formatter.F(" - To comment, chime in at {{cyan}}{{underline}}https://github.com/onsi/ginkgo/issues/711{{/}}\n\n") - for deprecation, locations := range d.deprecations { out += formatter.Fi(1, "{{yellow}}"+deprecation.Message+"{{/}}\n") if deprecation.DocLink != "" { - out += formatter.Fi(1, "{{bold}}Learn more at:{{/}} {{cyan}}{{underline}}https://github.com/onsi/ginkgo/blob/ver2/docs/MIGRATING_TO_V2.md#%s{{/}}\n", deprecation.DocLink) + out += formatter.Fi(1, "{{bold}}Learn more at:{{/}} {{cyan}}{{underline}}https://onsi.github.io/ginkgo/MIGRATING_TO_V2#%s{{/}}\n", deprecation.DocLink) } for _, location := range locations { out += formatter.Fi(2, "{{gray}}%s{{/}}\n", location) } } out += formatter.F("\n{{gray}}To silence deprecations that can be silenced set the following environment variable:{{/}}\n") - out += formatter.Fi(1, "{{gray}}ACK_GINKGO_DEPRECATIONS=%s{{/}}\n", config.VERSION) + out += formatter.Fi(1, "{{gray}}ACK_GINKGO_DEPRECATIONS=%s{{/}}\n", VERSION) return out } diff --git a/vendor/github.com/onsi/ginkgo/v2/types/enum_support.go b/vendor/github.com/onsi/ginkgo/v2/types/enum_support.go new file mode 100644 index 00000000000..1d96ae02800 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/enum_support.go @@ -0,0 +1,43 @@ +package types + +import "encoding/json" + +type EnumSupport struct { + toString map[uint]string + toEnum map[string]uint + maxEnum uint +} + +func NewEnumSupport(toString map[uint]string) EnumSupport { + toEnum, maxEnum := map[string]uint{}, uint(0) + for k, v := range toString { + toEnum[v] = k + if maxEnum < k { + maxEnum = k + } + } + return EnumSupport{toString: toString, toEnum: toEnum, maxEnum: maxEnum} +} + +func (es EnumSupport) String(e uint) string { + if e > es.maxEnum { + return es.toString[0] + } + return es.toString[e] +} + +func (es EnumSupport) UnmarshJSON(b []byte) (uint, error) { + var dec string + if err := json.Unmarshal(b, &dec); err != nil { + return 0, err + } + out := es.toEnum[dec] // if we miss we get 0 which is what we want anyway + return out, nil +} + +func (es EnumSupport) MarshJSON(e uint) ([]byte, error) { + if e == 0 || e > es.maxEnum { + return json.Marshal(nil) + } + return json.Marshal(es.toString[e]) +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/errors.go b/vendor/github.com/onsi/ginkgo/v2/types/errors.go new file mode 100644 index 00000000000..b7ed5a21ebb --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/errors.go @@ -0,0 +1,621 @@ +package types + +import ( + "fmt" + "reflect" + "strings" + + "github.com/onsi/ginkgo/v2/formatter" +) + +type GinkgoError struct { + Heading string + Message string + DocLink string + CodeLocation CodeLocation +} + +func (g GinkgoError) Error() string { + out := formatter.F("{{bold}}{{red}}%s{{/}}\n", g.Heading) + if (g.CodeLocation != CodeLocation{}) { + contentsOfLine := strings.TrimLeft(g.CodeLocation.ContentsOfLine(), "\t ") + if contentsOfLine != "" { + out += formatter.F("{{light-gray}}%s{{/}}\n", contentsOfLine) + } + out += formatter.F("{{gray}}%s{{/}}\n", g.CodeLocation) + } + if g.Message != "" { + out += formatter.Fiw(1, formatter.COLS, 
g.Message) + out += "\n\n" + } + if g.DocLink != "" { + out += formatter.Fiw(1, formatter.COLS, "{{bold}}Learn more at:{{/}} {{cyan}}{{underline}}http://onsi.github.io/ginkgo/#%s{{/}}\n", g.DocLink) + } + + return out +} + +type ginkgoErrors struct{} + +var GinkgoErrors = ginkgoErrors{} + +func (g ginkgoErrors) UncaughtGinkgoPanic(cl CodeLocation) error { + return GinkgoError{ + Heading: "Your Test Panicked", + Message: `When you, or your assertion library, calls Ginkgo's Fail(), +Ginkgo panics to prevent subsequent assertions from running. + +Normally Ginkgo rescues this panic so you shouldn't see it. + +However, if you make an assertion in a goroutine, Ginkgo can't capture the panic. +To circumvent this, you should call + + defer GinkgoRecover() + +at the top of the goroutine that caused this panic. + +Alternatively, you may have made an assertion outside of a Ginkgo +leaf node (e.g. in a container node or some out-of-band function) - please move your assertion to +an appropriate Ginkgo node (e.g. a BeforeSuite, BeforeEach, It, etc...).`, + DocLink: "mental-model-how-ginkgo-handles-failure", + CodeLocation: cl, + } +} + +func (g ginkgoErrors) RerunningSuite() error { + return GinkgoError{ + Heading: "Rerunning Suite", + Message: formatter.F(`It looks like you are calling RunSpecs more than once. Ginkgo does not support rerunning suites. If you want to rerun a suite try {{bold}}ginkgo --repeat=N{{/}} or {{bold}}ginkgo --until-it-fails{{/}}`), + DocLink: "repeating-spec-runs-and-managing-flaky-specs", + } +} + +/* Tree construction errors */ + +func (g ginkgoErrors) PushingNodeInRunPhase(nodeType NodeType, cl CodeLocation) error { + return GinkgoError{ + Heading: "Ginkgo detected an issue with your spec structure", + Message: formatter.F( + `It looks like you are trying to add a {{bold}}[%s]{{/}} node +to the Ginkgo spec tree in a leaf node {{bold}}after{{/}} the specs started running. + +To enable randomization and parallelization Ginkgo requires the spec tree +to be fully constructed up front. In practice, this means that you can +only create nodes like {{bold}}[%s]{{/}} at the top-level or within the +body of a {{bold}}Describe{{/}}, {{bold}}Context{{/}}, or {{bold}}When{{/}}.`, nodeType, nodeType), + CodeLocation: cl, + DocLink: "mental-model-how-ginkgo-traverses-the-spec-hierarchy", + } +} + +func (g ginkgoErrors) CaughtPanicDuringABuildPhase(caughtPanic interface{}, cl CodeLocation) error { + return GinkgoError{ + Heading: "Assertion or Panic detected during tree construction", + Message: formatter.F( + `Ginkgo detected a panic while constructing the spec tree. +You may be trying to make an assertion in the body of a container node +(i.e. {{bold}}Describe{{/}}, {{bold}}Context{{/}}, or {{bold}}When{{/}}). + +Please ensure all assertions are inside leaf nodes such as {{bold}}BeforeEach{{/}}, +{{bold}}It{{/}}, etc. + +{{bold}}Here's the content of the panic that was caught:{{/}} +%v`, caughtPanic), + CodeLocation: cl, + DocLink: "no-assertions-in-container-nodes", + } +} + +func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocation) error { + docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite" + if nodeType.Is(NodeTypeReportAfterSuite) { + docLink = "reporting-nodes---reportaftersuite" + } + + return GinkgoError{ + Heading: "Ginkgo detected an issue with your spec structure", + Message: formatter.F( + `It looks like you are trying to add a {{bold}}[%s]{{/}} node within a container node. 
+ +{{bold}}%s{{/}} can only be called at the top level.`, nodeType, nodeType), + CodeLocation: cl, + DocLink: docLink, + } +} + +func (g ginkgoErrors) SuiteNodeDuringRunPhase(nodeType NodeType, cl CodeLocation) error { + docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite" + if nodeType.Is(NodeTypeReportAfterSuite) { + docLink = "reporting-nodes---reportaftersuite" + } + + return GinkgoError{ + Heading: "Ginkgo detected an issue with your spec structure", + Message: formatter.F( + `It looks like you are trying to add a {{bold}}[%s]{{/}} node within a leaf node after the spec started running. + +{{bold}}%s{{/}} can only be called at the top level.`, nodeType, nodeType), + CodeLocation: cl, + DocLink: docLink, + } +} + +func (g ginkgoErrors) MultipleBeforeSuiteNodes(nodeType NodeType, cl CodeLocation, earlierNodeType NodeType, earlierCodeLocation CodeLocation) error { + return ginkgoErrorMultipleSuiteNodes("setup", nodeType, cl, earlierNodeType, earlierCodeLocation) +} + +func (g ginkgoErrors) MultipleAfterSuiteNodes(nodeType NodeType, cl CodeLocation, earlierNodeType NodeType, earlierCodeLocation CodeLocation) error { + return ginkgoErrorMultipleSuiteNodes("teardown", nodeType, cl, earlierNodeType, earlierCodeLocation) +} + +func ginkgoErrorMultipleSuiteNodes(setupOrTeardown string, nodeType NodeType, cl CodeLocation, earlierNodeType NodeType, earlierCodeLocation CodeLocation) error { + return GinkgoError{ + Heading: "Ginkgo detected an issue with your spec structure", + Message: formatter.F( + `It looks like you are trying to add a {{bold}}[%s]{{/}} node but +you already have a {{bold}}[%s]{{/}} node defined at: {{gray}}%s{{/}}. + +Ginkgo only allows you to define one suite %s node.`, nodeType, earlierNodeType, earlierCodeLocation, setupOrTeardown), + CodeLocation: cl, + DocLink: "suite-setup-and-cleanup-beforesuite-and-aftersuite", + } +} + +/* Decorator errors */ +func (g ginkgoErrors) InvalidDecoratorForNodeType(cl CodeLocation, nodeType NodeType, decorator string) error { + return GinkgoError{ + Heading: "Invalid Decorator", + Message: formatter.F(`[%s] node cannot be passed a(n) '%s' decorator`, nodeType, decorator), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) InvalidDeclarationOfFocusedAndPending(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Invalid Combination of Decorators: Focused and Pending", + Message: formatter.F(`[%s] node was decorated with both Focus and Pending. At most one is allowed.`, nodeType), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) InvalidDeclarationOfFlakeAttemptsAndMustPassRepeatedly(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Invalid Combination of Decorators: FlakeAttempts and MustPassRepeatedly", + Message: formatter.F(`[%s] node was decorated with both FlakeAttempts and MustPassRepeatedly. 
At most one is allowed.`, nodeType), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) UnknownDecorator(cl CodeLocation, nodeType NodeType, decorator interface{}) error { + return GinkgoError{ + Heading: "Unknown Decorator", + Message: formatter.F(`[%s] node was passed an unknown decorator: '%#v'`, nodeType, decorator), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) InvalidBodyTypeForContainer(t reflect.Type, cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Invalid Function", + Message: formatter.F(`[%s] node must be passed {{bold}}func(){{/}} - i.e. functions that take nothing and return nothing. You passed {{bold}}%s{{/}} instead.`, nodeType, t), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) InvalidBodyType(t reflect.Type, cl CodeLocation, nodeType NodeType) error { + mustGet := "{{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}" + if nodeType.Is(NodeTypeContainer) { + mustGet = "{{bold}}func(){{/}}" + } + return GinkgoError{ + Heading: "Invalid Function", + Message: formatter.F(`[%s] node must be passed `+mustGet+`. +You passed {{bold}}%s{{/}} instead.`, nodeType, t), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) InvalidBodyTypeForSynchronizedBeforeSuiteProc1(t reflect.Type, cl CodeLocation) error { + mustGet := "{{bold}}func() []byte{{/}}, {{bold}}func(ctx SpecContext) []byte{{/}}, or {{bold}}func(ctx context.Context) []byte{{/}}, {{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}" + return GinkgoError{ + Heading: "Invalid Function", + Message: formatter.F(`[SynchronizedBeforeSuite] node must be passed `+mustGet+` for its first function. +You passed {{bold}}%s{{/}} instead.`, t), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) InvalidBodyTypeForSynchronizedBeforeSuiteAllProcs(t reflect.Type, cl CodeLocation) error { + mustGet := "{{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}, {{bold}}func([]byte){{/}}, {{bold}}func(ctx SpecContext, []byte){{/}}, or {{bold}}func(ctx context.Context, []byte){{/}}" + return GinkgoError{ + Heading: "Invalid Function", + Message: formatter.F(`[SynchronizedBeforeSuite] node must be passed `+mustGet+` for its second function. 
+You passed {{bold}}%s{{/}} instead.`, t), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) MultipleBodyFunctions(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Multiple Functions", + Message: formatter.F(`[%s] node must be passed a single function - but more than one was passed in.`, nodeType), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) MissingBodyFunction(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Missing Functions", + Message: formatter.F(`[%s] node must be passed a single function - but none was passed in.`, nodeType), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) InvalidTimeoutOrGracePeriodForNonContextNode(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Invalid NodeTimeout SpecTimeout, or GracePeriod", + Message: formatter.F(`[%s] was passed NodeTimeout, SpecTimeout, or GracePeriod but does not have a callback that accepts a {{bold}}SpecContext{{/}} or {{bold}}context.Context{{/}}. You must accept a context to enable timeouts and grace periods`, nodeType), + CodeLocation: cl, + DocLink: "spec-timeouts-and-interruptible-nodes", + } +} + +func (g ginkgoErrors) InvalidTimeoutOrGracePeriodForNonContextCleanupNode(cl CodeLocation) error { + return GinkgoError{ + Heading: "Invalid NodeTimeout SpecTimeout, or GracePeriod", + Message: formatter.F(`[DeferCleanup] was passed NodeTimeout or GracePeriod but does not have a callback that accepts a {{bold}}SpecContext{{/}} or {{bold}}context.Context{{/}}. You must accept a context to enable timeouts and grace periods`), + CodeLocation: cl, + DocLink: "spec-timeouts-and-interruptible-nodes", + } +} + +/* Ordered Container errors */ +func (g ginkgoErrors) InvalidSerialNodeInNonSerialOrderedContainer(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Invalid Serial Node in Non-Serial Ordered Container", + Message: formatter.F(`[%s] node was decorated with Serial but occurs in an Ordered container that is not marked Serial. Move the Serial decorator to the outer-most Ordered container to mark all ordered specs within the container as serial.`, nodeType), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) SetupNodeNotInOrderedContainer(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Setup Node not in Ordered Container", + Message: fmt.Sprintf("[%s] setup nodes must appear inside an Ordered container. They cannot be nested within other containers, even containers in an ordered container.", nodeType), + CodeLocation: cl, + DocLink: "ordered-containers", + } +} + +/* DeferCleanup errors */ +func (g ginkgoErrors) DeferCleanupInvalidFunction(cl CodeLocation) error { + return GinkgoError{ + Heading: "DeferCleanup requires a valid function", + Message: "You must pass DeferCleanup a function to invoke. This function must return zero or one values - if it does return, it must return an error. 
The function can take arbitrarily many arguments and you should provide these to DeferCleanup to pass along to the function.", + CodeLocation: cl, + DocLink: "cleaning-up-our-cleanup-code-defercleanup", + } +} + +func (g ginkgoErrors) PushingCleanupNodeDuringTreeConstruction(cl CodeLocation) error { + return GinkgoError{ + Heading: "DeferCleanup must be called inside a setup or subject node", + Message: "You must call DeferCleanup inside a setup node (e.g. BeforeEach, BeforeSuite, AfterAll...) or a subject node (i.e. It). You can't call DeferCleanup at the top-level or in a container node - use the After* family of setup nodes instead.", + CodeLocation: cl, + DocLink: "cleaning-up-our-cleanup-code-defercleanup", + } +} + +func (g ginkgoErrors) PushingCleanupInReportingNode(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: fmt.Sprintf("DeferCleanup cannot be called in %s", nodeType), + Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a ReportAfterEach or ReportAfterSuite.", + CodeLocation: cl, + DocLink: "cleaning-up-our-cleanup-code-defercleanup", + } +} + +func (g ginkgoErrors) PushingCleanupInCleanupNode(cl CodeLocation) error { + return GinkgoError{ + Heading: "DeferCleanup cannot be called in a DeferCleanup callback", + Message: "Please inline your cleanup code - Ginkgo doesn't let you call DeferCleanup from within DeferCleanup", + CodeLocation: cl, + DocLink: "cleaning-up-our-cleanup-code-defercleanup", + } +} + +/* ReportEntry errors */ +func (g ginkgoErrors) TooManyReportEntryValues(cl CodeLocation, arg interface{}) error { + return GinkgoError{ + Heading: "Too Many ReportEntry Values", + Message: formatter.F(`{{bold}}AddGinkgoReport{{/}} can only be given one value. Got unexpected value: %#v`, arg), + CodeLocation: cl, + DocLink: "attaching-data-to-reports", + } +} + +func (g ginkgoErrors) AddReportEntryNotDuringRunPhase(cl CodeLocation) error { + return GinkgoError{ + Heading: "Ginkgo detected an issue with your spec structure", + Message: formatter.F(`It looks like you are calling {{bold}}AddGinkgoReport{{/}} outside of a running spec. Make sure you call {{bold}}AddGinkgoReport{{/}} inside a runnable node such as It or BeforeEach and not inside the body of a container such as Describe or Context.`), + CodeLocation: cl, + DocLink: "attaching-data-to-reports", + } +} + +/* By errors */ +func (g ginkgoErrors) ByNotDuringRunPhase(cl CodeLocation) error { + return GinkgoError{ + Heading: "Ginkgo detected an issue with your spec structure", + Message: formatter.F(`It looks like you are calling {{bold}}By{{/}} outside of a running spec. Make sure you call {{bold}}By{{/}} inside a runnable node such as It or BeforeEach and not inside the body of a container such as Describe or Context.`), + CodeLocation: cl, + DocLink: "documenting-complex-specs-by", + } +} + +/* FileFilter and SkipFilter errors */ +func (g ginkgoErrors) InvalidFileFilter(filter string) error { + return GinkgoError{ + Heading: "Invalid File Filter", + Message: fmt.Sprintf(`The provided file filter: "%s" is invalid. File filters must have the format "file", "file:lines" where "file" is a regular expression that will match against the file path and lines is a comma-separated list of integers (e.g. file:1,5,7) or line-ranges (e.g. file:1-3,5-9) or both (e.g. 
file:1,5-9)`, filter), + DocLink: "filtering-specs", + } +} + +func (g ginkgoErrors) InvalidFileFilterRegularExpression(filter string, err error) error { + return GinkgoError{ + Heading: "Invalid File Filter Regular Expression", + Message: fmt.Sprintf(`The provided file filter: "%s" included an invalid regular expression. regexp.Compile error: %s`, filter, err), + DocLink: "filtering-specs", + } +} + +/* Label Errors */ +func (g ginkgoErrors) SyntaxErrorParsingLabelFilter(input string, location int, error string) error { + var message string + if location >= 0 { + for i, r := range input { + if i == location { + message += "{{red}}{{bold}}{{underline}}" + } + message += string(r) + if i == location { + message += "{{/}}" + } + } + } else { + message = input + } + message += "\n" + error + return GinkgoError{ + Heading: "Syntax Error Parsing Label Filter", + Message: message, + DocLink: "spec-labels", + } +} + +func (g ginkgoErrors) InvalidLabel(label string, cl CodeLocation) error { + return GinkgoError{ + Heading: "Invalid Label", + Message: fmt.Sprintf("'%s' is an invalid label. Labels cannot contain of the following characters: '&|!,()/'", label), + CodeLocation: cl, + DocLink: "spec-labels", + } +} + +func (g ginkgoErrors) InvalidEmptyLabel(cl CodeLocation) error { + return GinkgoError{ + Heading: "Invalid Empty Label", + Message: "Labels cannot be empty", + CodeLocation: cl, + DocLink: "spec-labels", + } +} + +/* Table errors */ +func (g ginkgoErrors) MultipleEntryBodyFunctionsForTable(cl CodeLocation) error { + return GinkgoError{ + Heading: "DescribeTable passed multiple functions", + Message: "It looks like you are passing multiple functions into DescribeTable. Only one function can be passed in. This function will be called for each Entry in the table.", + CodeLocation: cl, + DocLink: "table-specs", + } +} + +func (g ginkgoErrors) InvalidEntryDescription(cl CodeLocation) error { + return GinkgoError{ + Heading: "Invalid Entry description", + Message: "Entry description functions must be a string, a function that accepts the entry parameters and returns a string, or nil.", + CodeLocation: cl, + DocLink: "table-specs", + } +} + +func (g ginkgoErrors) MissingParametersForTableFunction(cl CodeLocation) error { + return GinkgoError{ + Heading: fmt.Sprintf("No parameters have been passed to the Table Function"), + Message: fmt.Sprintf("The Table Function expected at least 1 parameter"), + CodeLocation: cl, + DocLink: "table-specs", + } +} + +func (g ginkgoErrors) IncorrectParameterTypeForTable(i int, name string, cl CodeLocation) error { + return GinkgoError{ + Heading: "DescribeTable passed incorrect parameter type", + Message: fmt.Sprintf("Parameter #%d passed to DescribeTable is of incorrect type <%s>", i, name), + CodeLocation: cl, + DocLink: "table-specs", + } +} + +func (g ginkgoErrors) TooFewParametersToTableFunction(expected, actual int, kind string, cl CodeLocation) error { + return GinkgoError{ + Heading: fmt.Sprintf("Too few parameters passed in to %s", kind), + Message: fmt.Sprintf("The %s expected %d parameters but you passed in %d", kind, expected, actual), + CodeLocation: cl, + DocLink: "table-specs", + } +} + +func (g ginkgoErrors) TooManyParametersToTableFunction(expected, actual int, kind string, cl CodeLocation) error { + return GinkgoError{ + Heading: fmt.Sprintf("Too many parameters passed in to %s", kind), + Message: fmt.Sprintf("The %s expected %d parameters but you passed in %d", kind, expected, actual), + CodeLocation: cl, + DocLink: "table-specs", + } +} + 
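For orientation: each constructor above returns a GinkgoError whose Error() method renders the bold heading, the offending code location, the wrapped message, and a documentation anchor. A minimal standalone sketch of that rendering (separate from the vendored file), assuming the vendored module is importable as github.com/onsi/ginkgo/v2 and using an illustrative "Table Body function" kind string:

package main

import (
	"fmt"

	"github.com/onsi/ginkgo/v2/types"
)

func main() {
	// Build one of the structured errors defined above and print it. The
	// rendered string carries the heading, this caller's code location, the
	// indented message, and a docs link anchored at "table-specs".
	cl := types.NewCodeLocation(0)
	err := types.GinkgoErrors.TooManyParametersToTableFunction(2, 3, "Table Body function", cl)
	fmt.Println(err.Error())
}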
+func (g ginkgoErrors) IncorrectParameterTypeToTableFunction(i int, expected, actual reflect.Type, kind string, cl CodeLocation) error { + return GinkgoError{ + Heading: fmt.Sprintf("Incorrect parameters type passed to %s", kind), + Message: fmt.Sprintf("The %s expected parameter #%d to be of type <%s> but you passed in <%s>", kind, i, expected, actual), + CodeLocation: cl, + DocLink: "table-specs", + } +} + +func (g ginkgoErrors) IncorrectVariadicParameterTypeToTableFunction(expected, actual reflect.Type, kind string, cl CodeLocation) error { + return GinkgoError{ + Heading: fmt.Sprintf("Incorrect parameters type passed to %s", kind), + Message: fmt.Sprintf("The %s expected its variadic parameters to be of type <%s> but you passed in <%s>", kind, expected, actual), + CodeLocation: cl, + DocLink: "table-specs", + } +} + +/* Parallel Synchronization errors */ + +func (g ginkgoErrors) AggregatedReportUnavailableDueToNodeDisappearing() error { + return GinkgoError{ + Heading: "Test Report unavailable because a Ginkgo parallel process disappeared", + Message: "The aggregated report could not be fetched for a ReportAfterSuite node. A Ginkgo parallel process disappeared before it could finish reporting.", + } +} + +func (g ginkgoErrors) SynchronizedBeforeSuiteFailedOnProc1() error { + return GinkgoError{ + Heading: "SynchronizedBeforeSuite failed on Ginkgo parallel process #1", + Message: "The first SynchronizedBeforeSuite function running on Ginkgo parallel process #1 failed. This suite will now abort.", + } +} + +func (g ginkgoErrors) SynchronizedBeforeSuiteDisappearedOnProc1() error { + return GinkgoError{ + Heading: "Process #1 disappeared before SynchronizedBeforeSuite could report back", + Message: "Ginkgo parallel process #1 disappeared before the first SynchronizedBeforeSuite function completed. This suite will now abort.", + } +} + +/* Configuration errors */ + +func (g ginkgoErrors) UnknownTypePassedToRunSpecs(value interface{}) error { + return GinkgoError{ + Heading: "Unknown Type passed to RunSpecs", + Message: fmt.Sprintf("RunSpecs() accepts labels, and configuration of type types.SuiteConfig and/or types.ReporterConfig.\n You passed in: %v", value), + } +} + +var sharedParallelErrorMessage = "It looks like you are trying to run specs in parallel with go test.\nThis is unsupported and you should use the ginkgo CLI instead." 
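For context: this shared message backs the parallel-configuration errors below, which are what the VetConfig helper added earlier in this patch returns when a test process is mis-wired for parallel runs. A rough standalone sketch (separate from the vendored file), assuming the vendored module is importable as github.com/onsi/ginkgo/v2 and that NewDefaultSuiteConfig/NewDefaultReporterConfig behave as in upstream Ginkgo:

package main

import (
	"fmt"

	"github.com/onsi/ginkgo/v2/types"
)

func main() {
	// An empty flag set suffices here: VetConfig only consults it for go
	// test's -count and -parallel flags.
	flagSet, err := types.NewGinkgoFlagSet(types.GinkgoFlags{}, nil, types.GinkgoFlagSections{})
	if err != nil {
		panic(err)
	}

	suiteConfig := types.NewDefaultSuiteConfig()
	reporterConfig := types.NewDefaultReporterConfig()

	// Two parallel processes but no parallel.host set: VetConfig returns
	// MissingParallelHostConfiguration, built from sharedParallelErrorMessage.
	suiteConfig.ParallelTotal = 2

	for _, vetErr := range types.VetConfig(flagSet, suiteConfig, reporterConfig) {
		fmt.Println(vetErr)
	}
}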
+ +func (g ginkgoErrors) InvalidParallelTotalConfiguration() error { + return GinkgoError{ + Heading: "-ginkgo.parallel.total must be >= 1", + Message: sharedParallelErrorMessage, + DocLink: "spec-parallelization", + } +} + +func (g ginkgoErrors) InvalidParallelProcessConfiguration() error { + return GinkgoError{ + Heading: "-ginkgo.parallel.process is one-indexed and must be <= ginkgo.parallel.total", + Message: sharedParallelErrorMessage, + DocLink: "spec-parallelization", + } +} + +func (g ginkgoErrors) MissingParallelHostConfiguration() error { + return GinkgoError{ + Heading: "-ginkgo.parallel.host is missing", + Message: sharedParallelErrorMessage, + DocLink: "spec-parallelization", + } +} + +func (g ginkgoErrors) UnreachableParallelHost(host string) error { + return GinkgoError{ + Heading: "Could not reach ginkgo.parallel.host:" + host, + Message: sharedParallelErrorMessage, + DocLink: "spec-parallelization", + } +} + +func (g ginkgoErrors) DryRunInParallelConfiguration() error { + return GinkgoError{ + Heading: "Ginkgo only performs -dryRun in serial mode.", + Message: "Please try running ginkgo -dryRun again, but without -p or -procs to ensure the suite is running in series.", + } +} + +func (g ginkgoErrors) GracePeriodCannotBeZero() error { + return GinkgoError{ + Heading: "Ginkgo requires a positive --grace-period.", + Message: "Please set --grace-period to a positive duration. The default is 30s.", + } +} + +func (g ginkgoErrors) ConflictingVerbosityConfiguration() error { + return GinkgoError{ + Heading: "Conflicting reporter verbosity settings.", + Message: "You can't set more than one of -v, -vv and --succinct. Please pick one!", + } +} + +func (g ginkgoErrors) InvalidOutputInterceptorModeConfiguration(value string) error { + return GinkgoError{ + Heading: fmt.Sprintf("Invalid value '%s' for --output-interceptor-mode.", value), + Message: "You must choose one of 'dup', 'swap', or 'none'.", + } +} + +func (g ginkgoErrors) InvalidGoFlagCount() error { + return GinkgoError{ + Heading: "Use of go test -count", + Message: "Ginkgo does not support using go test -count to rerun suites. Only -count=1 is allowed. To repeat suite runs, please use the ginkgo cli and `ginkgo -until-it-fails` or `ginkgo -repeat=N`.", + } +} + +func (g ginkgoErrors) InvalidGoFlagParallel() error { + return GinkgoError{ + Heading: "Use of go test -parallel", + Message: "Go test's implementation of parallelization does not actually parallelize Ginkgo specs. Please use the ginkgo cli and `ginkgo -p` or `ginkgo -procs=N` instead.", + } +} + +func (g ginkgoErrors) BothRepeatAndUntilItFails() error { + return GinkgoError{ + Heading: "--repeat and --until-it-fails are both set", + Message: "--until-it-fails directs Ginkgo to rerun specs indefinitely until they fail. --repeat directs Ginkgo to rerun specs a set number of times. You can't set both... 
which would you like?", + } +} + +/* Stack-Trace parsing errors */ + +func (g ginkgoErrors) FailedToParseStackTrace(message string) error { + return GinkgoError{ + Heading: "Failed to Parse Stack Trace", + Message: message, + } +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/file_filter.go b/vendor/github.com/onsi/ginkgo/v2/types/file_filter.go new file mode 100644 index 00000000000..cc21df71ec8 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/file_filter.go @@ -0,0 +1,106 @@ +package types + +import ( + "regexp" + "strconv" + "strings" +) + +func ParseFileFilters(filters []string) (FileFilters, error) { + ffs := FileFilters{} + for _, filter := range filters { + ff := FileFilter{} + if filter == "" { + return nil, GinkgoErrors.InvalidFileFilter(filter) + } + components := strings.Split(filter, ":") + if !(len(components) == 1 || len(components) == 2) { + return nil, GinkgoErrors.InvalidFileFilter(filter) + } + + var err error + ff.Filename, err = regexp.Compile(components[0]) + if err != nil { + return nil, err + } + if len(components) == 2 { + lineFilters := strings.Split(components[1], ",") + for _, lineFilter := range lineFilters { + components := strings.Split(lineFilter, "-") + if len(components) == 1 { + line, err := strconv.Atoi(strings.TrimSpace(components[0])) + if err != nil { + return nil, GinkgoErrors.InvalidFileFilter(filter) + } + ff.LineFilters = append(ff.LineFilters, LineFilter{line, line + 1}) + } else if len(components) == 2 { + line1, err := strconv.Atoi(strings.TrimSpace(components[0])) + if err != nil { + return nil, GinkgoErrors.InvalidFileFilter(filter) + } + line2, err := strconv.Atoi(strings.TrimSpace(components[1])) + if err != nil { + return nil, GinkgoErrors.InvalidFileFilter(filter) + } + ff.LineFilters = append(ff.LineFilters, LineFilter{line1, line2}) + } else { + return nil, GinkgoErrors.InvalidFileFilter(filter) + } + } + } + ffs = append(ffs, ff) + } + return ffs, nil +} + +type FileFilter struct { + Filename *regexp.Regexp + LineFilters LineFilters +} + +func (f FileFilter) Matches(locations []CodeLocation) bool { + for _, location := range locations { + if f.Filename.MatchString(location.FileName) && + f.LineFilters.Matches(location.LineNumber) { + return true + } + + } + return false +} + +type FileFilters []FileFilter + +func (ffs FileFilters) Matches(locations []CodeLocation) bool { + for _, ff := range ffs { + if ff.Matches(locations) { + return true + } + } + + return false +} + +type LineFilter struct { + Min int + Max int +} + +func (lf LineFilter) Matches(line int) bool { + return lf.Min <= line && line < lf.Max +} + +type LineFilters []LineFilter + +func (lfs LineFilters) Matches(line int) bool { + if len(lfs) == 0 { + return true + } + + for _, lf := range lfs { + if lf.Matches(line) { + return true + } + } + return false +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/flags.go b/vendor/github.com/onsi/ginkgo/v2/types/flags.go new file mode 100644 index 00000000000..9186ae873d0 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/flags.go @@ -0,0 +1,489 @@ +package types + +import ( + "flag" + "fmt" + "io" + "reflect" + "strings" + "time" + + "github.com/onsi/ginkgo/v2/formatter" +) + +type GinkgoFlag struct { + Name string + KeyPath string + SectionKey string + + Usage string + UsageArgument string + UsageDefaultValue string + + DeprecatedName string + DeprecatedDocLink string + DeprecatedVersion string + + ExportAs string +} + +type GinkgoFlags []GinkgoFlag + +func (f GinkgoFlags) CopyAppend(flags 
...GinkgoFlag) GinkgoFlags { + out := GinkgoFlags{} + out = append(out, f...) + out = append(out, flags...) + return out +} + +func (f GinkgoFlags) WithPrefix(prefix string) GinkgoFlags { + if prefix == "" { + return f + } + out := GinkgoFlags{} + for _, flag := range f { + if flag.Name != "" { + flag.Name = prefix + "." + flag.Name + } + if flag.DeprecatedName != "" { + flag.DeprecatedName = prefix + "." + flag.DeprecatedName + } + if flag.ExportAs != "" { + flag.ExportAs = prefix + "." + flag.ExportAs + } + out = append(out, flag) + } + return out +} + +func (f GinkgoFlags) SubsetWithNames(names ...string) GinkgoFlags { + out := GinkgoFlags{} + for _, flag := range f { + for _, name := range names { + if flag.Name == name { + out = append(out, flag) + break + } + } + } + return out +} + +type GinkgoFlagSection struct { + Key string + Style string + Succinct bool + Heading string + Description string +} + +type GinkgoFlagSections []GinkgoFlagSection + +func (gfs GinkgoFlagSections) Lookup(key string) (GinkgoFlagSection, bool) { + for _, section := range gfs { + if section.Key == key { + return section, true + } + } + + return GinkgoFlagSection{}, false +} + +type GinkgoFlagSet struct { + flags GinkgoFlags + bindings interface{} + + sections GinkgoFlagSections + extraGoFlagsSection GinkgoFlagSection + + flagSet *flag.FlagSet +} + +// Call NewGinkgoFlagSet to create GinkgoFlagSet that creates and binds to it's own *flag.FlagSet +func NewGinkgoFlagSet(flags GinkgoFlags, bindings interface{}, sections GinkgoFlagSections) (GinkgoFlagSet, error) { + return bindFlagSet(GinkgoFlagSet{ + flags: flags, + bindings: bindings, + sections: sections, + }, nil) +} + +// Call NewGinkgoFlagSet to create GinkgoFlagSet that extends an existing *flag.FlagSet +func NewAttachedGinkgoFlagSet(flagSet *flag.FlagSet, flags GinkgoFlags, bindings interface{}, sections GinkgoFlagSections, extraGoFlagsSection GinkgoFlagSection) (GinkgoFlagSet, error) { + return bindFlagSet(GinkgoFlagSet{ + flags: flags, + bindings: bindings, + sections: sections, + extraGoFlagsSection: extraGoFlagsSection, + }, flagSet) +} + +func bindFlagSet(f GinkgoFlagSet, flagSet *flag.FlagSet) (GinkgoFlagSet, error) { + if flagSet == nil { + f.flagSet = flag.NewFlagSet("", flag.ContinueOnError) + //suppress all output as Ginkgo is responsible for formatting usage + f.flagSet.SetOutput(io.Discard) + } else { + f.flagSet = flagSet + //we're piggybacking on an existing flagset (typically go test) so we have limited control + //on user feedback + f.flagSet.Usage = f.substituteUsage + } + + for _, flag := range f.flags { + name := flag.Name + + deprecatedUsage := "[DEPRECATED]" + deprecatedName := flag.DeprecatedName + if name != "" { + deprecatedUsage = fmt.Sprintf("[DEPRECATED] use --%s instead", name) + } else if flag.Usage != "" { + deprecatedUsage += " " + flag.Usage + } + + value, ok := valueAtKeyPath(f.bindings, flag.KeyPath) + if !ok { + return GinkgoFlagSet{}, fmt.Errorf("could not load KeyPath: %s", flag.KeyPath) + } + + iface, addr := value.Interface(), value.Addr().Interface() + + switch value.Type() { + case reflect.TypeOf(string("")): + if name != "" { + f.flagSet.StringVar(addr.(*string), name, iface.(string), flag.Usage) + } + if deprecatedName != "" { + f.flagSet.StringVar(addr.(*string), deprecatedName, iface.(string), deprecatedUsage) + } + case reflect.TypeOf(int64(0)): + if name != "" { + f.flagSet.Int64Var(addr.(*int64), name, iface.(int64), flag.Usage) + } + if deprecatedName != "" { + f.flagSet.Int64Var(addr.(*int64), 
deprecatedName, iface.(int64), deprecatedUsage) + } + case reflect.TypeOf(float64(0)): + if name != "" { + f.flagSet.Float64Var(addr.(*float64), name, iface.(float64), flag.Usage) + } + if deprecatedName != "" { + f.flagSet.Float64Var(addr.(*float64), deprecatedName, iface.(float64), deprecatedUsage) + } + case reflect.TypeOf(int(0)): + if name != "" { + f.flagSet.IntVar(addr.(*int), name, iface.(int), flag.Usage) + } + if deprecatedName != "" { + f.flagSet.IntVar(addr.(*int), deprecatedName, iface.(int), deprecatedUsage) + } + case reflect.TypeOf(bool(true)): + if name != "" { + f.flagSet.BoolVar(addr.(*bool), name, iface.(bool), flag.Usage) + } + if deprecatedName != "" { + f.flagSet.BoolVar(addr.(*bool), deprecatedName, iface.(bool), deprecatedUsage) + } + case reflect.TypeOf(time.Duration(0)): + if name != "" { + f.flagSet.DurationVar(addr.(*time.Duration), name, iface.(time.Duration), flag.Usage) + } + if deprecatedName != "" { + f.flagSet.DurationVar(addr.(*time.Duration), deprecatedName, iface.(time.Duration), deprecatedUsage) + } + + case reflect.TypeOf([]string{}): + if name != "" { + f.flagSet.Var(stringSliceVar{value}, name, flag.Usage) + } + if deprecatedName != "" { + f.flagSet.Var(stringSliceVar{value}, deprecatedName, deprecatedUsage) + } + default: + return GinkgoFlagSet{}, fmt.Errorf("unsupported type %T", iface) + } + } + + return f, nil +} + +func (f GinkgoFlagSet) IsZero() bool { + return f.flagSet == nil +} + +func (f GinkgoFlagSet) WasSet(name string) bool { + found := false + f.flagSet.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + + return found +} + +func (f GinkgoFlagSet) Lookup(name string) *flag.Flag { + return f.flagSet.Lookup(name) +} + +func (f GinkgoFlagSet) Parse(args []string) ([]string, error) { + if f.IsZero() { + return args, nil + } + err := f.flagSet.Parse(args) + if err != nil { + return []string{}, err + } + return f.flagSet.Args(), nil +} + +func (f GinkgoFlagSet) ValidateDeprecations(deprecationTracker *DeprecationTracker) { + if f.IsZero() { + return + } + f.flagSet.Visit(func(flag *flag.Flag) { + for _, ginkgoFlag := range f.flags { + if ginkgoFlag.DeprecatedName != "" && strings.HasSuffix(flag.Name, ginkgoFlag.DeprecatedName) { + message := fmt.Sprintf("--%s is deprecated", ginkgoFlag.DeprecatedName) + if ginkgoFlag.Name != "" { + message = fmt.Sprintf("--%s is deprecated, use --%s instead", ginkgoFlag.DeprecatedName, ginkgoFlag.Name) + } else if ginkgoFlag.Usage != "" { + message += " " + ginkgoFlag.Usage + } + + deprecationTracker.TrackDeprecation(Deprecation{ + Message: message, + DocLink: ginkgoFlag.DeprecatedDocLink, + Version: ginkgoFlag.DeprecatedVersion, + }) + } + } + }) +} + +func (f GinkgoFlagSet) Usage() string { + if f.IsZero() { + return "" + } + groupedFlags := map[GinkgoFlagSection]GinkgoFlags{} + ungroupedFlags := GinkgoFlags{} + managedFlags := map[string]bool{} + extraGoFlags := []*flag.Flag{} + + for _, flag := range f.flags { + managedFlags[flag.Name] = true + managedFlags[flag.DeprecatedName] = true + + if flag.Name == "" { + continue + } + + section, ok := f.sections.Lookup(flag.SectionKey) + if ok { + groupedFlags[section] = append(groupedFlags[section], flag) + } else { + ungroupedFlags = append(ungroupedFlags, flag) + } + } + + f.flagSet.VisitAll(func(flag *flag.Flag) { + if !managedFlags[flag.Name] { + extraGoFlags = append(extraGoFlags, flag) + } + }) + + out := "" + for _, section := range f.sections { + flags := groupedFlags[section] + if len(flags) == 0 { + continue + } + out += 
f.usageForSection(section) + if section.Succinct { + succinctFlags := []string{} + for _, flag := range flags { + if flag.Name != "" { + succinctFlags = append(succinctFlags, fmt.Sprintf("--%s", flag.Name)) + } + } + out += formatter.Fiw(1, formatter.COLS, section.Style+strings.Join(succinctFlags, ", ")+"{{/}}\n") + } else { + for _, flag := range flags { + out += f.usageForFlag(flag, section.Style) + } + } + out += "\n" + } + if len(ungroupedFlags) > 0 { + for _, flag := range ungroupedFlags { + out += f.usageForFlag(flag, "") + } + out += "\n" + } + if len(extraGoFlags) > 0 { + out += f.usageForSection(f.extraGoFlagsSection) + for _, goFlag := range extraGoFlags { + out += f.usageForGoFlag(goFlag) + } + } + + return out +} + +func (f GinkgoFlagSet) substituteUsage() { + fmt.Fprintln(f.flagSet.Output(), f.Usage()) +} + +func valueAtKeyPath(root interface{}, keyPath string) (reflect.Value, bool) { + if len(keyPath) == 0 { + return reflect.Value{}, false + } + + val := reflect.ValueOf(root) + components := strings.Split(keyPath, ".") + for _, component := range components { + val = reflect.Indirect(val) + switch val.Kind() { + case reflect.Map: + val = val.MapIndex(reflect.ValueOf(component)) + if val.Kind() == reflect.Interface { + val = reflect.ValueOf(val.Interface()) + } + case reflect.Struct: + val = val.FieldByName(component) + default: + return reflect.Value{}, false + } + if (val == reflect.Value{}) { + return reflect.Value{}, false + } + } + + return val, true +} + +func (f GinkgoFlagSet) usageForSection(section GinkgoFlagSection) string { + out := formatter.F(section.Style + "{{bold}}{{underline}}" + section.Heading + "{{/}}\n") + if section.Description != "" { + out += formatter.Fiw(0, formatter.COLS, section.Description+"\n") + } + return out +} + +func (f GinkgoFlagSet) usageForFlag(flag GinkgoFlag, style string) string { + argument := flag.UsageArgument + defValue := flag.UsageDefaultValue + if argument == "" { + value, _ := valueAtKeyPath(f.bindings, flag.KeyPath) + switch value.Type() { + case reflect.TypeOf(string("")): + argument = "string" + case reflect.TypeOf(int64(0)), reflect.TypeOf(int(0)): + argument = "int" + case reflect.TypeOf(time.Duration(0)): + argument = "duration" + case reflect.TypeOf(float64(0)): + argument = "float" + case reflect.TypeOf([]string{}): + argument = "string" + } + } + if argument != "" { + argument = "[" + argument + "] " + } + if defValue != "" { + defValue = fmt.Sprintf("(default: %s)", defValue) + } + hyphens := "--" + if len(flag.Name) == 1 { + hyphens = "-" + } + + out := formatter.Fi(1, style+"%s%s{{/}} %s{{gray}}%s{{/}}\n", hyphens, flag.Name, argument, defValue) + out += formatter.Fiw(2, formatter.COLS, "{{light-gray}}%s{{/}}\n", flag.Usage) + return out +} + +func (f GinkgoFlagSet) usageForGoFlag(goFlag *flag.Flag) string { + //Taken directly from the flag package + out := fmt.Sprintf(" -%s", goFlag.Name) + name, usage := flag.UnquoteUsage(goFlag) + if len(name) > 0 { + out += " " + name + } + if len(out) <= 4 { + out += "\t" + } else { + out += "\n \t" + } + out += strings.ReplaceAll(usage, "\n", "\n \t") + out += "\n" + return out +} + +type stringSliceVar struct { + slice reflect.Value +} + +func (ssv stringSliceVar) String() string { return "" } +func (ssv stringSliceVar) Set(s string) error { + ssv.slice.Set(reflect.AppendSlice(ssv.slice, reflect.ValueOf([]string{s}))) + return nil +} + +//given a set of GinkgoFlags and bindings, generate flag arguments suitable to be passed to an application with that set of flags configured. 
+func GenerateFlagArgs(flags GinkgoFlags, bindings interface{}) ([]string, error) { + result := []string{} + for _, flag := range flags { + name := flag.ExportAs + if name == "" { + name = flag.Name + } + if name == "" { + continue + } + + value, ok := valueAtKeyPath(bindings, flag.KeyPath) + if !ok { + return []string{}, fmt.Errorf("could not load KeyPath: %s", flag.KeyPath) + } + + iface := value.Interface() + switch value.Type() { + case reflect.TypeOf(string("")): + if iface.(string) != "" { + result = append(result, fmt.Sprintf("--%s=%s", name, iface)) + } + case reflect.TypeOf(int64(0)): + if iface.(int64) != 0 { + result = append(result, fmt.Sprintf("--%s=%d", name, iface)) + } + case reflect.TypeOf(float64(0)): + if iface.(float64) != 0 { + result = append(result, fmt.Sprintf("--%s=%f", name, iface)) + } + case reflect.TypeOf(int(0)): + if iface.(int) != 0 { + result = append(result, fmt.Sprintf("--%s=%d", name, iface)) + } + case reflect.TypeOf(bool(true)): + if iface.(bool) { + result = append(result, fmt.Sprintf("--%s", name)) + } + case reflect.TypeOf(time.Duration(0)): + if iface.(time.Duration) != time.Duration(0) { + result = append(result, fmt.Sprintf("--%s=%s", name, iface)) + } + + case reflect.TypeOf([]string{}): + strings := iface.([]string) + for _, s := range strings { + result = append(result, fmt.Sprintf("--%s=%s", name, s)) + } + default: + return []string{}, fmt.Errorf("unsupported type %T", iface) + } + } + + return result, nil +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/label_filter.go b/vendor/github.com/onsi/ginkgo/v2/types/label_filter.go new file mode 100644 index 00000000000..0403f9e6319 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/label_filter.go @@ -0,0 +1,347 @@ +package types + +import ( + "fmt" + "regexp" + "strings" +) + +var DEBUG_LABEL_FILTER_PARSING = false + +type LabelFilter func([]string) bool + +func matchLabelAction(label string) LabelFilter { + expected := strings.ToLower(label) + return func(labels []string) bool { + for i := range labels { + if strings.ToLower(labels[i]) == expected { + return true + } + } + return false + } +} + +func matchLabelRegexAction(regex *regexp.Regexp) LabelFilter { + return func(labels []string) bool { + for i := range labels { + if regex.MatchString(labels[i]) { + return true + } + } + return false + } +} + +func notAction(filter LabelFilter) LabelFilter { + return func(labels []string) bool { return !filter(labels) } +} + +func andAction(a, b LabelFilter) LabelFilter { + return func(labels []string) bool { return a(labels) && b(labels) } +} + +func orAction(a, b LabelFilter) LabelFilter { + return func(labels []string) bool { return a(labels) || b(labels) } +} + +type lfToken uint + +const ( + lfTokenInvalid lfToken = iota + + lfTokenRoot + lfTokenOpenGroup + lfTokenCloseGroup + lfTokenNot + lfTokenAnd + lfTokenOr + lfTokenRegexp + lfTokenLabel + lfTokenEOF +) + +func (l lfToken) Precedence() int { + switch l { + case lfTokenRoot, lfTokenOpenGroup: + return 0 + case lfTokenOr: + return 1 + case lfTokenAnd: + return 2 + case lfTokenNot: + return 3 + } + return -1 +} + +func (l lfToken) String() string { + switch l { + case lfTokenRoot: + return "ROOT" + case lfTokenOpenGroup: + return "(" + case lfTokenCloseGroup: + return ")" + case lfTokenNot: + return "!" 
+ case lfTokenAnd: + return "&&" + case lfTokenOr: + return "||" + case lfTokenRegexp: + return "/regexp/" + case lfTokenLabel: + return "label" + case lfTokenEOF: + return "EOF" + } + return "INVALID" +} + +type treeNode struct { + token lfToken + location int + value string + + parent *treeNode + leftNode *treeNode + rightNode *treeNode +} + +func (tn *treeNode) setRightNode(node *treeNode) { + tn.rightNode = node + node.parent = tn +} + +func (tn *treeNode) setLeftNode(node *treeNode) { + tn.leftNode = node + node.parent = tn +} + +func (tn *treeNode) firstAncestorWithPrecedenceLEQ(precedence int) *treeNode { + if tn.token.Precedence() <= precedence { + return tn + } + return tn.parent.firstAncestorWithPrecedenceLEQ(precedence) +} + +func (tn *treeNode) firstUnmatchedOpenNode() *treeNode { + if tn.token == lfTokenOpenGroup { + return tn + } + if tn.parent == nil { + return nil + } + return tn.parent.firstUnmatchedOpenNode() +} + +func (tn *treeNode) constructLabelFilter(input string) (LabelFilter, error) { + switch tn.token { + case lfTokenOpenGroup: + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, tn.location, "Mismatched '(' - could not find matching ')'.") + case lfTokenLabel: + return matchLabelAction(tn.value), nil + case lfTokenRegexp: + re, err := regexp.Compile(tn.value) + if err != nil { + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, tn.location, fmt.Sprintf("RegExp compilation error: %s", err)) + } + return matchLabelRegexAction(re), nil + } + + if tn.rightNode == nil { + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, -1, "Unexpected EOF.") + } + rightLF, err := tn.rightNode.constructLabelFilter(input) + if err != nil { + return nil, err + } + + switch tn.token { + case lfTokenRoot, lfTokenCloseGroup: + return rightLF, nil + case lfTokenNot: + return notAction(rightLF), nil + } + + if tn.leftNode == nil { + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, tn.location, fmt.Sprintf("Malformed tree - '%s' is missing left operand.", tn.token)) + } + leftLF, err := tn.leftNode.constructLabelFilter(input) + if err != nil { + return nil, err + } + + switch tn.token { + case lfTokenAnd: + return andAction(leftLF, rightLF), nil + case lfTokenOr: + return orAction(leftLF, rightLF), nil + } + + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, tn.location, fmt.Sprintf("Invalid token '%s'.", tn.token)) +} + +func (tn *treeNode) tokenString() string { + out := fmt.Sprintf("<%s", tn.token) + if tn.value != "" { + out += " | " + tn.value + } + out += ">" + return out +} + +func (tn *treeNode) toString(indent int) string { + out := tn.tokenString() + "\n" + if tn.leftNode != nil { + out += fmt.Sprintf("%s |_(L)_%s", strings.Repeat(" ", indent), tn.leftNode.toString(indent+1)) + } + if tn.rightNode != nil { + out += fmt.Sprintf("%s |_(R)_%s", strings.Repeat(" ", indent), tn.rightNode.toString(indent+1)) + } + return out +} + +func tokenize(input string) func() (*treeNode, error) { + runes, i := []rune(input), 0 + + peekIs := func(r rune) bool { + if i+1 < len(runes) { + return runes[i+1] == r + } + return false + } + + consumeUntil := func(cutset string) (string, int) { + j := i + for ; j < len(runes); j++ { + if strings.IndexRune(cutset, runes[j]) >= 0 { + break + } + } + return string(runes[i:j]), j - i + } + + return func() (*treeNode, error) { + for i < len(runes) && runes[i] == ' ' { + i += 1 + } + + if i >= len(runes) { + return &treeNode{token: lfTokenEOF}, nil + } + + node := &treeNode{location: i} + switch 
runes[i] { + case '&': + if !peekIs('&') { + return &treeNode{}, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, i, "Invalid token '&'. Did you mean '&&'?") + } + i += 2 + node.token = lfTokenAnd + case '|': + if !peekIs('|') { + return &treeNode{}, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, i, "Invalid token '|'. Did you mean '||'?") + } + i += 2 + node.token = lfTokenOr + case '!': + i += 1 + node.token = lfTokenNot + case ',': + i += 1 + node.token = lfTokenOr + case '(': + i += 1 + node.token = lfTokenOpenGroup + case ')': + i += 1 + node.token = lfTokenCloseGroup + case '/': + i += 1 + value, n := consumeUntil("/") + i += n + 1 + node.token, node.value = lfTokenRegexp, value + default: + value, n := consumeUntil("&|!,()/") + i += n + node.token, node.value = lfTokenLabel, strings.TrimSpace(value) + } + return node, nil + } +} + +func ParseLabelFilter(input string) (LabelFilter, error) { + if DEBUG_LABEL_FILTER_PARSING { + fmt.Println("\n==============") + fmt.Println("Input: ", input) + fmt.Print("Tokens: ") + } + nextToken := tokenize(input) + + root := &treeNode{token: lfTokenRoot} + current := root +LOOP: + for { + node, err := nextToken() + if err != nil { + return nil, err + } + + if DEBUG_LABEL_FILTER_PARSING { + fmt.Print(node.tokenString() + " ") + } + + switch node.token { + case lfTokenEOF: + break LOOP + case lfTokenLabel, lfTokenRegexp: + if current.rightNode != nil { + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, node.location, "Found two adjacent labels. You need an operator between them.") + } + current.setRightNode(node) + case lfTokenNot, lfTokenOpenGroup: + if current.rightNode != nil { + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, node.location, fmt.Sprintf("Invalid token '%s'.", node.token)) + } + current.setRightNode(node) + current = node + case lfTokenAnd, lfTokenOr: + if current.rightNode == nil { + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, node.location, fmt.Sprintf("Operator '%s' missing left hand operand.", node.token)) + } + nodeToStealFrom := current.firstAncestorWithPrecedenceLEQ(node.token.Precedence()) + node.setLeftNode(nodeToStealFrom.rightNode) + nodeToStealFrom.setRightNode(node) + current = node + case lfTokenCloseGroup: + firstUnmatchedOpenNode := current.firstUnmatchedOpenNode() + if firstUnmatchedOpenNode == nil { + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, node.location, "Mismatched ')' - could not find matching '('.") + } + if firstUnmatchedOpenNode == current && current.rightNode == nil { + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, node.location, "Found empty '()' group.") + } + firstUnmatchedOpenNode.token = lfTokenCloseGroup //signify the group is now closed + current = firstUnmatchedOpenNode.parent + default: + return nil, GinkgoErrors.SyntaxErrorParsingLabelFilter(input, node.location, fmt.Sprintf("Unknown token '%s'.", node.token)) + } + } + if DEBUG_LABEL_FILTER_PARSING { + fmt.Printf("\n Tree:\n%s", root.toString(0)) + } + return root.constructLabelFilter(input) +} + +func ValidateAndCleanupLabel(label string, cl CodeLocation) (string, error) { + out := strings.TrimSpace(label) + if out == "" { + return "", GinkgoErrors.InvalidEmptyLabel(cl) + } + if strings.ContainsAny(out, "&|!,()/") { + return "", GinkgoErrors.InvalidLabel(label, cl) + } + return out, nil +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/report_entry.go b/vendor/github.com/onsi/ginkgo/v2/types/report_entry.go new file mode 100644 index 
00000000000..798bedc0395 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/report_entry.go @@ -0,0 +1,186 @@ +package types + +import ( + "encoding/json" + "fmt" + "time" +) + +//ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports +//and across the network connection when running in parallel +type ReportEntryValue struct { + raw interface{} //unexported to prevent gob from freaking out about unregistered structs + AsJSON string + Representation string +} + +func WrapEntryValue(value interface{}) ReportEntryValue { + return ReportEntryValue{ + raw: value, + } +} + +func (rev ReportEntryValue) GetRawValue() interface{} { + return rev.raw +} + +func (rev ReportEntryValue) String() string { + if rev.raw == nil { + return "" + } + if colorableStringer, ok := rev.raw.(ColorableStringer); ok { + return colorableStringer.ColorableString() + } + + if stringer, ok := rev.raw.(fmt.Stringer); ok { + return stringer.String() + } + if rev.Representation != "" { + return rev.Representation + } + return fmt.Sprintf("%+v", rev.raw) +} + +func (rev ReportEntryValue) MarshalJSON() ([]byte, error) { + //All this to capture the representation at encoding-time, not creating time + //This way users can Report on pointers and get their final values at reporting-time + out := struct { + AsJSON string + Representation string + }{ + Representation: rev.String(), + } + asJSON, err := json.Marshal(rev.raw) + if err != nil { + return nil, err + } + out.AsJSON = string(asJSON) + + return json.Marshal(out) +} + +func (rev *ReportEntryValue) UnmarshalJSON(data []byte) error { + in := struct { + AsJSON string + Representation string + }{} + err := json.Unmarshal(data, &in) + if err != nil { + return err + } + rev.AsJSON = in.AsJSON + rev.Representation = in.Representation + return json.Unmarshal([]byte(in.AsJSON), &(rev.raw)) +} + +func (rev ReportEntryValue) GobEncode() ([]byte, error) { + return rev.MarshalJSON() +} + +func (rev *ReportEntryValue) GobDecode(data []byte) error { + return rev.UnmarshalJSON(data) +} + +// ReportEntry captures information attached to `SpecReport` via `AddReportEntry` +type ReportEntry struct { + // Visibility captures the visibility policy for this ReportEntry + Visibility ReportEntryVisibility + // Time captures the time the AddReportEntry was called + Time time.Time + // Location captures the location of the AddReportEntry call + Location CodeLocation + // Name captures the name of this report + Name string + // Value captures the (optional) object passed into AddReportEntry - this can be + // anything the user wants. The value passed to AddReportEntry is wrapped in a ReportEntryValue to make + // encoding/decoding the value easier. To access the raw value call entry.GetRawValue() + Value ReportEntryValue +} + +// ColorableStringer is an interface that ReportEntry values can satisfy. If they do then ColorableString() is used to generate their representation. 
+type ColorableStringer interface { + ColorableString() string +} + +// StringRepresentation() returns the string representation of the value associated with the ReportEntry -- +// if value is nil, empty string is returned +// if value is a `ColorableStringer` then `Value.ColorableString()` is returned +// if value is a `fmt.Stringer` then `Value.String()` is returned +// otherwise the value is formatted with "%+v" +func (entry ReportEntry) StringRepresentation() string { + return entry.Value.String() +} + +// GetRawValue returns the Value object that was passed to AddReportEntry +// If called in-process this will be the same object that was passed into AddReportEntry. +// If used from a rehydrated JSON file _or_ in a ReportAfterSuite when running in parallel this will be +// a JSON-decoded {}interface. If you want to reconstitute your original object you can decode the entry.Value.AsJSON +// field yourself. +func (entry ReportEntry) GetRawValue() interface{} { + return entry.Value.GetRawValue() +} + + + +type ReportEntries []ReportEntry + +func (re ReportEntries) HasVisibility(visibilities ...ReportEntryVisibility) bool { + for _, entry := range re { + if entry.Visibility.Is(visibilities...) { + return true + } + } + return false +} + +func (re ReportEntries) WithVisibility(visibilities ...ReportEntryVisibility) ReportEntries { + out := ReportEntries{} + + for _, entry := range re { + if entry.Visibility.Is(visibilities...) { + out = append(out, entry) + } + } + + return out +} + +// ReportEntryVisibility governs the visibility of ReportEntries in Ginkgo's console reporter +type ReportEntryVisibility uint + +const ( + // Always print out this ReportEntry + ReportEntryVisibilityAlways ReportEntryVisibility = iota + // Only print out this ReportEntry if the spec fails or if the test is run with -v + ReportEntryVisibilityFailureOrVerbose + // Never print out this ReportEntry (note that ReportEntrys are always encoded in machine readable reports (e.g. 
JSON, JUnit, etc.)) + ReportEntryVisibilityNever +) + +var revEnumSupport = NewEnumSupport(map[uint]string{ + uint(ReportEntryVisibilityAlways): "always", + uint(ReportEntryVisibilityFailureOrVerbose): "failure-or-verbose", + uint(ReportEntryVisibilityNever): "never", +}) + +func (rev ReportEntryVisibility) String() string { + return revEnumSupport.String(uint(rev)) +} +func (rev *ReportEntryVisibility) UnmarshalJSON(b []byte) error { + out, err := revEnumSupport.UnmarshJSON(b) + *rev = ReportEntryVisibility(out) + return err +} +func (rev ReportEntryVisibility) MarshalJSON() ([]byte, error) { + return revEnumSupport.MarshJSON(uint(rev)) +} + +func (v ReportEntryVisibility) Is(visibilities ...ReportEntryVisibility) bool { + for _, visibility := range visibilities { + if v == visibility { + return true + } + } + + return false +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/types.go b/vendor/github.com/onsi/ginkgo/v2/types/types.go new file mode 100644 index 00000000000..9fc4425fe7e --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/types.go @@ -0,0 +1,696 @@ +package types + +import ( + "encoding/json" + "strings" + "time" +) + +const GINKGO_FOCUS_EXIT_CODE = 197 +const GINKGO_TIME_FORMAT = "01/02/06 15:04:05.999" + +// Report captures information about a Ginkgo test run +type Report struct { + //SuitePath captures the absolute path to the test suite + SuitePath string + + //SuiteDescription captures the description string passed to the DSL's RunSpecs() function + SuiteDescription string + + //SuiteLabels captures any labels attached to the suite by the DSL's RunSpecs() function + SuiteLabels []string + + //SuiteSucceeded captures the success or failure status of the test run + //If true, the test run is considered successful. + //If false, the test run is considered unsuccessful + SuiteSucceeded bool + + //SuiteHasProgrammaticFocus captures whether the test suite has a test or set of tests that are programmatically focused + //(i.e an `FIt` or an `FDescribe` + SuiteHasProgrammaticFocus bool + + //SpecialSuiteFailureReasons may contain special failure reasons + //For example, a test suite might be considered "failed" even if none of the individual specs + //have a failure state. For example, if the user has configured --fail-on-pending the test suite + //will have failed if there are pending tests even though all non-pending tests may have passed. In such + //cases, Ginkgo populates SpecialSuiteFailureReasons with a clear message indicating the reason for the failure. + //SpecialSuiteFailureReasons is also populated if the test suite is interrupted by the user. + //Since multiple special failure reasons can occur, this field is a slice. + SpecialSuiteFailureReasons []string + + //PreRunStats contains a set of stats captured before the test run begins. This is primarily used + //by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs) + //and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters. 
+ PreRunStats PreRunStats + + //StartTime and EndTime capture the start and end time of the test run + StartTime time.Time + EndTime time.Time + + //RunTime captures the duration of the test run + RunTime time.Duration + + //SuiteConfig captures the Ginkgo configuration governing this test run + //SuiteConfig includes information necessary for reproducing an identical test run, + //such as the random seed and any filters applied during the test run + SuiteConfig SuiteConfig + + //SpecReports is a list of all SpecReports generated by this test run + SpecReports SpecReports +} + +//PreRunStats contains a set of stats captured before the test run begins. This is primarily used +//by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs) +//and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters. +type PreRunStats struct { + TotalSpecs int + SpecsThatWillRun int +} + +//Add is used by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes +//to form a complete final report. +func (report Report) Add(other Report) Report { + report.SuiteSucceeded = report.SuiteSucceeded && other.SuiteSucceeded + + if other.StartTime.Before(report.StartTime) { + report.StartTime = other.StartTime + } + + if other.EndTime.After(report.EndTime) { + report.EndTime = other.EndTime + } + + specialSuiteFailureReasons := []string{} + reasonsLookup := map[string]bool{} + for _, reasons := range [][]string{report.SpecialSuiteFailureReasons, other.SpecialSuiteFailureReasons} { + for _, reason := range reasons { + if !reasonsLookup[reason] { + reasonsLookup[reason] = true + specialSuiteFailureReasons = append(specialSuiteFailureReasons, reason) + } + } + } + report.SpecialSuiteFailureReasons = specialSuiteFailureReasons + report.RunTime = report.EndTime.Sub(report.StartTime) + + reports := make(SpecReports, len(report.SpecReports)+len(other.SpecReports)) + for i := range report.SpecReports { + reports[i] = report.SpecReports[i] + } + offset := len(report.SpecReports) + for i := range other.SpecReports { + reports[i+offset] = other.SpecReports[i] + } + + report.SpecReports = reports + return report +} + +// SpecReport captures information about a Ginkgo spec. +type SpecReport struct { + // ContainerHierarchyTexts is a slice containing the text strings of + // all Describe/Context/When containers in this spec's hierarchy. + ContainerHierarchyTexts []string + + // ContainerHierarchyLocations is a slice containing the CodeLocations of + // all Describe/Context/When containers in this spec's hierarchy. + ContainerHierarchyLocations []CodeLocation + + // ContainerHierarchyLabels is a slice containing the labels of + // all Describe/Context/When containers in this spec's hierarchy + ContainerHierarchyLabels [][]string + + // LeafNodeType, LeadNodeLocation, LeafNodeLabels and LeafNodeText capture the NodeType, CodeLocation, and text + // of the Ginkgo node being tested (typically an NodeTypeIt node, though this can also be + // one of the NodeTypesForSuiteLevelNodes node types) + LeafNodeType NodeType + LeafNodeLocation CodeLocation + LeafNodeLabels []string + LeafNodeText string + + // State captures whether the spec has passed, failed, etc. 
+ State SpecState + + // IsSerial captures whether the spec has the Serial decorator + IsSerial bool + + // IsInOrderedContainer captures whether the spec appears in an Ordered container + IsInOrderedContainer bool + + // StartTime and EndTime capture the start and end time of the spec + StartTime time.Time + EndTime time.Time + + // RunTime captures the duration of the spec + RunTime time.Duration + + // ParallelProcess captures the parallel process that this spec ran on + ParallelProcess int + + //Failure is populated if a spec has failed, panicked, been interrupted, or skipped by the user (e.g. calling Skip()) + //It includes detailed information about the Failure + Failure Failure + + // NumAttempts captures the number of times this Spec was run. + // Flakey specs can be retried with ginkgo --flake-attempts=N or the use of the FlakeAttempts decorator. + // Repeated specs can be retried with the use of the MustPassRepeatedly decorator + NumAttempts int + + // MaxFlakeAttempts captures whether the spec has been retried with ginkgo --flake-attempts=N or the use of the FlakeAttempts decorator. + MaxFlakeAttempts int + + // MaxMustPassRepeatedly captures whether the spec has the MustPassRepeatedly decorator + MaxMustPassRepeatedly int + + // CapturedGinkgoWriterOutput contains text printed to the GinkgoWriter + CapturedGinkgoWriterOutput string + + // CapturedStdOutErr contains text printed to stdout/stderr (when running in parallel) + // This is always empty when running in series or calling CurrentSpecReport() + // It is used internally by Ginkgo's reporter + CapturedStdOutErr string + + // ReportEntries contains any reports added via `AddReportEntry` + ReportEntries ReportEntries + + // ProgressReports contains any progress reports generated during this spec. These can either be manually triggered, or automatically generated by Ginkgo via the PollProgressAfter() decorator + ProgressReports []ProgressReport + + // AdditionalFailures contains any failures that occurred after the initial spec failure. These typically occur in cleanup nodes after the initial failure and are only emitted when running in verbose mode. 
+ AdditionalFailures []AdditionalFailure +} + +func (report SpecReport) MarshalJSON() ([]byte, error) { + //All this to avoid emitting an empty Failure struct in the JSON + out := struct { + ContainerHierarchyTexts []string + ContainerHierarchyLocations []CodeLocation + ContainerHierarchyLabels [][]string + LeafNodeType NodeType + LeafNodeLocation CodeLocation + LeafNodeLabels []string + LeafNodeText string + State SpecState + StartTime time.Time + EndTime time.Time + RunTime time.Duration + ParallelProcess int + Failure *Failure `json:",omitempty"` + NumAttempts int + MaxFlakeAttempts int + MaxMustPassRepeatedly int + CapturedGinkgoWriterOutput string `json:",omitempty"` + CapturedStdOutErr string `json:",omitempty"` + ReportEntries ReportEntries `json:",omitempty"` + ProgressReports []ProgressReport `json:",omitempty"` + AdditionalFailures []AdditionalFailure `json:",omitempty"` + }{ + ContainerHierarchyTexts: report.ContainerHierarchyTexts, + ContainerHierarchyLocations: report.ContainerHierarchyLocations, + ContainerHierarchyLabels: report.ContainerHierarchyLabels, + LeafNodeType: report.LeafNodeType, + LeafNodeLocation: report.LeafNodeLocation, + LeafNodeLabels: report.LeafNodeLabels, + LeafNodeText: report.LeafNodeText, + State: report.State, + StartTime: report.StartTime, + EndTime: report.EndTime, + RunTime: report.RunTime, + ParallelProcess: report.ParallelProcess, + Failure: nil, + ReportEntries: nil, + NumAttempts: report.NumAttempts, + MaxFlakeAttempts: report.MaxFlakeAttempts, + MaxMustPassRepeatedly: report.MaxMustPassRepeatedly, + CapturedGinkgoWriterOutput: report.CapturedGinkgoWriterOutput, + CapturedStdOutErr: report.CapturedStdOutErr, + } + + if !report.Failure.IsZero() { + out.Failure = &(report.Failure) + } + if len(report.ReportEntries) > 0 { + out.ReportEntries = report.ReportEntries + } + if len(report.ProgressReports) > 0 { + out.ProgressReports = report.ProgressReports + } + if len(report.AdditionalFailures) > 0 { + out.AdditionalFailures = report.AdditionalFailures + } + + return json.Marshal(out) +} + +// CombinedOutput returns a single string representation of both CapturedStdOutErr and CapturedGinkgoWriterOutput +// Note that both are empty when using CurrentSpecReport() so CurrentSpecReport().CombinedOutput() will always be empty. +// CombinedOutput() is used internally by Ginkgo's reporter. +func (report SpecReport) CombinedOutput() string { + if report.CapturedStdOutErr == "" { + return report.CapturedGinkgoWriterOutput + } + if report.CapturedGinkgoWriterOutput == "" { + return report.CapturedStdOutErr + } + return report.CapturedStdOutErr + "\n" + report.CapturedGinkgoWriterOutput +} + +//Failed returns true if report.State is one of the SpecStateFailureStates +// (SpecStateFailed, SpecStatePanicked, SpecStateinterrupted, SpecStateAborted) +func (report SpecReport) Failed() bool { + return report.State.Is(SpecStateFailureStates) +} + +//FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText +func (report SpecReport) FullText() string { + texts := []string{} + texts = append(texts, report.ContainerHierarchyTexts...) + if report.LeafNodeText != "" { + texts = append(texts, report.LeafNodeText) + } + return strings.Join(texts, " ") +} + +//Labels returns a deduped set of all the spec's Labels. 
+func (report SpecReport) Labels() []string { + out := []string{} + seen := map[string]bool{} + for _, labels := range report.ContainerHierarchyLabels { + for _, label := range labels { + if !seen[label] { + seen[label] = true + out = append(out, label) + } + } + } + for _, label := range report.LeafNodeLabels { + if !seen[label] { + seen[label] = true + out = append(out, label) + } + } + + return out +} + +//MatchesLabelFilter returns true if the spec satisfies the passed in label filter query +func (report SpecReport) MatchesLabelFilter(query string) (bool, error) { + filter, err := ParseLabelFilter(query) + if err != nil { + return false, err + } + return filter(report.Labels()), nil +} + +//FileName() returns the name of the file containing the spec +func (report SpecReport) FileName() string { + return report.LeafNodeLocation.FileName +} + +//LineNumber() returns the line number of the leaf node +func (report SpecReport) LineNumber() int { + return report.LeafNodeLocation.LineNumber +} + +//FailureMessage() returns the failure message (or empty string if the test hasn't failed) +func (report SpecReport) FailureMessage() string { + return report.Failure.Message +} + +//FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed) +func (report SpecReport) FailureLocation() CodeLocation { + return report.Failure.Location +} + +type SpecReports []SpecReport + +//WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes +func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports { + count := 0 + for i := range reports { + if reports[i].LeafNodeType.Is(nodeTypes) { + count++ + } + } + + out := make(SpecReports, count) + j := 0 + for i := range reports { + if reports[i].LeafNodeType.Is(nodeTypes) { + out[j] = reports[i] + j++ + } + } + return out +} + +//WithState returns the subset of SpecReports with State matching one of the requested SpecStates +func (reports SpecReports) WithState(states SpecState) SpecReports { + count := 0 + for i := range reports { + if reports[i].State.Is(states) { + count++ + } + } + + out, j := make(SpecReports, count), 0 + for i := range reports { + if reports[i].State.Is(states) { + out[j] = reports[i] + j++ + } + } + return out +} + +//CountWithState returns the number of SpecReports with State matching one of the requested SpecStates +func (reports SpecReports) CountWithState(states SpecState) int { + n := 0 + for i := range reports { + if reports[i].State.Is(states) { + n += 1 + } + } + return n +} + +//If the Spec passes, CountOfFlakedSpecs returns the number of SpecReports that failed after multiple attempts. +func (reports SpecReports) CountOfFlakedSpecs() int { + n := 0 + for i := range reports { + if reports[i].MaxFlakeAttempts > 1 && reports[i].State.Is(SpecStatePassed) && reports[i].NumAttempts > 1 { + n += 1 + } + } + return n +} + +//If the Spec fails, CountOfRepeatedSpecs returns the number of SpecReports that passed after multiple attempts +func (reports SpecReports) CountOfRepeatedSpecs() int { + n := 0 + for i := range reports { + if reports[i].MaxMustPassRepeatedly > 1 && reports[i].State.Is(SpecStateFailureStates) && reports[i].NumAttempts > 1 { + n += 1 + } + } + return n +} + +// Failure captures failure information for an individual test +type Failure struct { + // Message - the failure message passed into Fail(...). When using a matcher library + // like Gomega, this will contain the failure message generated by Gomega. 
+ // + // Message is also populated if the user has called Skip(...). + Message string + + // Location - the CodeLocation where the failure occurred + // This CodeLocation will include a fully-populated StackTrace + Location CodeLocation + + // ForwardedPanic - if the failure represents a captured panic (i.e. Summary.State == SpecStatePanicked) + // then ForwardedPanic will be populated with a string representation of the captured panic. + ForwardedPanic string `json:",omitempty"` + + // FailureNodeContext - one of three contexts describing the node in which the failure occurred: + // FailureNodeIsLeafNode means the failure occurred in the leaf node of the associated SpecReport. None of the other FailureNode fields will be populated + // FailureNodeAtTopLevel means the failure occurred in a non-leaf node that is defined at the top-level of the spec (i.e. not in a container). FailureNodeType and FailureNodeLocation will be populated. + // FailureNodeInContainer means the failure occurred in a non-leaf node that is defined within a container. FailureNodeType, FailureNodeLocation, and FailureNodeContainerIndex will be populated. + // + // FailureNodeType will contain the NodeType of the node in which the failure occurred. + // FailureNodeLocation will contain the CodeLocation of the node in which the failure occurred. + // If populated, FailureNodeContainerIndex will be the index into SpecReport.ContainerHierarchyTexts and SpecReport.ContainerHierarchyLocations that represents the parent container of the node in which the failure occurred. + FailureNodeContext FailureNodeContext + FailureNodeType NodeType + FailureNodeLocation CodeLocation + FailureNodeContainerIndex int + + //ProgressReport is populated if the spec was interrupted or timed out + ProgressReport ProgressReport +} + +func (f Failure) IsZero() bool { + return f.Message == "" && (f.Location == CodeLocation{}) +} + +// FailureNodeContext captures the location context for the node containing the failing line of code +type FailureNodeContext uint + +const ( + FailureNodeContextInvalid FailureNodeContext = iota + + FailureNodeIsLeafNode + FailureNodeAtTopLevel + FailureNodeInContainer +) + +var fncEnumSupport = NewEnumSupport(map[uint]string{ + uint(FailureNodeContextInvalid): "INVALID FAILURE NODE CONTEXT", + uint(FailureNodeIsLeafNode): "leaf-node", + uint(FailureNodeAtTopLevel): "top-level", + uint(FailureNodeInContainer): "in-container", +}) + +func (fnc FailureNodeContext) String() string { + return fncEnumSupport.String(uint(fnc)) +} +func (fnc *FailureNodeContext) UnmarshalJSON(b []byte) error { + out, err := fncEnumSupport.UnmarshJSON(b) + *fnc = FailureNodeContext(out) + return err +} +func (fnc FailureNodeContext) MarshalJSON() ([]byte, error) { + return fncEnumSupport.MarshJSON(uint(fnc)) +} + +// AdditionalFailure capturs any additional failures that occur after the initial failure of a psec +// these typically occur in clean up nodes after the spec has failed. 
+// We can't simply use Failure as we want to track the SpecState to know what kind of failure this is +type AdditionalFailure struct { + State SpecState + Failure Failure +} + +// SpecState captures the state of a spec +// To determine if a given `state` represents a failure state, use `state.Is(SpecStateFailureStates)` +type SpecState uint + +const ( + SpecStateInvalid SpecState = 0 + + SpecStatePending SpecState = 1 << iota + SpecStateSkipped + SpecStatePassed + SpecStateFailed + SpecStateAborted + SpecStatePanicked + SpecStateInterrupted + SpecStateTimedout +) + +var ssEnumSupport = NewEnumSupport(map[uint]string{ + uint(SpecStateInvalid): "INVALID SPEC STATE", + uint(SpecStatePending): "pending", + uint(SpecStateSkipped): "skipped", + uint(SpecStatePassed): "passed", + uint(SpecStateFailed): "failed", + uint(SpecStateAborted): "aborted", + uint(SpecStatePanicked): "panicked", + uint(SpecStateInterrupted): "interrupted", + uint(SpecStateTimedout): "timedout", +}) + +func (ss SpecState) String() string { + return ssEnumSupport.String(uint(ss)) +} +func (ss *SpecState) UnmarshalJSON(b []byte) error { + out, err := ssEnumSupport.UnmarshJSON(b) + *ss = SpecState(out) + return err +} +func (ss SpecState) MarshalJSON() ([]byte, error) { + return ssEnumSupport.MarshJSON(uint(ss)) +} + +var SpecStateFailureStates = SpecStateFailed | SpecStateTimedout | SpecStateAborted | SpecStatePanicked | SpecStateInterrupted + +func (ss SpecState) Is(states SpecState) bool { + return ss&states != 0 +} + +// ProgressReport captures the progress of the current spec. It is, effectively, a structured Ginkgo-aware stack trace +type ProgressReport struct { + Message string + ParallelProcess int + RunningInParallel bool + + Time time.Time + + ContainerHierarchyTexts []string + LeafNodeText string + LeafNodeLocation CodeLocation + SpecStartTime time.Time + + CurrentNodeType NodeType + CurrentNodeText string + CurrentNodeLocation CodeLocation + CurrentNodeStartTime time.Time + + CurrentStepText string + CurrentStepLocation CodeLocation + CurrentStepStartTime time.Time + + AdditionalReports []string + + CapturedGinkgoWriterOutput string `json:",omitempty"` + GinkgoWriterOffset int + + Goroutines []Goroutine +} + +func (pr ProgressReport) IsZero() bool { + return pr.CurrentNodeType == NodeTypeInvalid +} + +func (pr ProgressReport) SpecGoroutine() Goroutine { + for _, goroutine := range pr.Goroutines { + if goroutine.IsSpecGoroutine { + return goroutine + } + } + return Goroutine{} +} + +func (pr ProgressReport) HighlightedGoroutines() []Goroutine { + out := []Goroutine{} + for _, goroutine := range pr.Goroutines { + if goroutine.IsSpecGoroutine || !goroutine.HasHighlights() { + continue + } + out = append(out, goroutine) + } + return out +} + +func (pr ProgressReport) OtherGoroutines() []Goroutine { + out := []Goroutine{} + for _, goroutine := range pr.Goroutines { + if goroutine.IsSpecGoroutine || goroutine.HasHighlights() { + continue + } + out = append(out, goroutine) + } + return out +} + +func (pr ProgressReport) WithoutCapturedGinkgoWriterOutput() ProgressReport { + out := pr + out.CapturedGinkgoWriterOutput = "" + return out +} + +type Goroutine struct { + ID uint64 + State string + Stack []FunctionCall + IsSpecGoroutine bool +} + +func (g Goroutine) IsZero() bool { + return g.ID == 0 +} + +func (g Goroutine) HasHighlights() bool { + for _, fc := range g.Stack { + if fc.Highlight { + return true + } + } + + return false +} + +type FunctionCall struct { + Function string + Filename string + Line int + Highlight 
bool `json:",omitempty"` + Source []string `json:",omitempty"` + SourceHighlight int `json:",omitempty"` +} + +// NodeType captures the type of a given Ginkgo Node +type NodeType uint + +const ( + NodeTypeInvalid NodeType = 0 + + NodeTypeContainer NodeType = 1 << iota + NodeTypeIt + + NodeTypeBeforeEach + NodeTypeJustBeforeEach + NodeTypeAfterEach + NodeTypeJustAfterEach + + NodeTypeBeforeAll + NodeTypeAfterAll + + NodeTypeBeforeSuite + NodeTypeSynchronizedBeforeSuite + NodeTypeAfterSuite + NodeTypeSynchronizedAfterSuite + + NodeTypeReportBeforeEach + NodeTypeReportAfterEach + NodeTypeReportAfterSuite + + NodeTypeCleanupInvalid + NodeTypeCleanupAfterEach + NodeTypeCleanupAfterAll + NodeTypeCleanupAfterSuite +) + +var NodeTypesForContainerAndIt = NodeTypeContainer | NodeTypeIt +var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite +var NodeTypesAllowedDuringCleanupInterrupt = NodeTypeAfterEach | NodeTypeJustAfterEach | NodeTypeAfterAll | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeCleanupAfterEach | NodeTypeCleanupAfterAll | NodeTypeCleanupAfterSuite +var NodeTypesAllowedDuringReportInterrupt = NodeTypeReportBeforeEach | NodeTypeReportAfterEach | NodeTypeReportAfterSuite + +var ntEnumSupport = NewEnumSupport(map[uint]string{ + uint(NodeTypeInvalid): "INVALID NODE TYPE", + uint(NodeTypeContainer): "Container", + uint(NodeTypeIt): "It", + uint(NodeTypeBeforeEach): "BeforeEach", + uint(NodeTypeJustBeforeEach): "JustBeforeEach", + uint(NodeTypeAfterEach): "AfterEach", + uint(NodeTypeJustAfterEach): "JustAfterEach", + uint(NodeTypeBeforeAll): "BeforeAll", + uint(NodeTypeAfterAll): "AfterAll", + uint(NodeTypeBeforeSuite): "BeforeSuite", + uint(NodeTypeSynchronizedBeforeSuite): "SynchronizedBeforeSuite", + uint(NodeTypeAfterSuite): "AfterSuite", + uint(NodeTypeSynchronizedAfterSuite): "SynchronizedAfterSuite", + uint(NodeTypeReportBeforeEach): "ReportBeforeEach", + uint(NodeTypeReportAfterEach): "ReportAfterEach", + uint(NodeTypeReportAfterSuite): "ReportAfterSuite", + uint(NodeTypeCleanupInvalid): "DeferCleanup", + uint(NodeTypeCleanupAfterEach): "DeferCleanup (Each)", + uint(NodeTypeCleanupAfterAll): "DeferCleanup (All)", + uint(NodeTypeCleanupAfterSuite): "DeferCleanup (Suite)", +}) + +func (nt NodeType) String() string { + return ntEnumSupport.String(uint(nt)) +} +func (nt *NodeType) UnmarshalJSON(b []byte) error { + out, err := ntEnumSupport.UnmarshJSON(b) + *nt = NodeType(out) + return err +} +func (nt NodeType) MarshalJSON() ([]byte, error) { + return ntEnumSupport.MarshJSON(uint(nt)) +} + +func (nt NodeType) Is(nodeTypes NodeType) bool { + return nt&nodeTypes != 0 +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/version.go b/vendor/github.com/onsi/ginkgo/v2/types/version.go new file mode 100644 index 00000000000..7ba384a09a6 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/types/version.go @@ -0,0 +1,3 @@ +package types + +const VERSION = "2.4.0" diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/LICENSE b/vendor/github.com/quic-go/qtls-go1-19/LICENSE similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-16/LICENSE rename to vendor/github.com/quic-go/qtls-go1-19/LICENSE diff --git a/vendor/github.com/quic-go/qtls-go1-19/README.md b/vendor/github.com/quic-go/qtls-go1-19/README.md new file mode 100644 index 00000000000..bf41f1c5f1f --- /dev/null +++ 
b/vendor/github.com/quic-go/qtls-go1-19/README.md @@ -0,0 +1,6 @@ +# qtls + +[![Go Reference](https://pkg.go.dev/badge/github.com/quic-go/qtls-go1-19.svg)](https://pkg.go.dev/github.com/quic-go/qtls-go1-19) +[![.github/workflows/go-test.yml](https://github.com/quic-go/qtls-go1-19/actions/workflows/go-test.yml/badge.svg)](https://github.com/quic-go/qtls-go1-19/actions/workflows/go-test.yml) + +This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go). diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/alert.go b/vendor/github.com/quic-go/qtls-go1-19/alert.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-16/alert.go rename to vendor/github.com/quic-go/qtls-go1-19/alert.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/auth.go b/vendor/github.com/quic-go/qtls-go1-19/auth.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-19/auth.go rename to vendor/github.com/quic-go/qtls-go1-19/auth.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/cfkem.go b/vendor/github.com/quic-go/qtls-go1-19/cfkem.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-18/cfkem.go rename to vendor/github.com/quic-go/qtls-go1-19/cfkem.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/cipher_suites.go b/vendor/github.com/quic-go/qtls-go1-19/cipher_suites.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-19/cipher_suites.go rename to vendor/github.com/quic-go/qtls-go1-19/cipher_suites.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/common.go b/vendor/github.com/quic-go/qtls-go1-19/common.go similarity index 99% rename from vendor/github.com/marten-seemann/qtls-go1-19/common.go rename to vendor/github.com/quic-go/qtls-go1-19/common.go index 2484da0579d..63e391bf62a 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-19/common.go +++ b/vendor/github.com/quic-go/qtls-go1-19/common.go @@ -346,7 +346,7 @@ type clientSessionState struct { // SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which // are supported via this interface. // -//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/marten-seemann/qtls-go1-19 ClientSessionCache" +//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/quic-go/qtls-go1-19 ClientSessionCache" type ClientSessionCache = tls.ClientSessionCache // SignatureScheme is a tls.SignatureScheme @@ -1412,7 +1412,7 @@ func leafCertificate(c *Certificate) (*x509.Certificate, error) { } type handshakeMessage interface { - marshal() []byte + marshal() ([]byte, error) unmarshal([]byte) bool } diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/conn.go b/vendor/github.com/quic-go/qtls-go1-19/conn.go similarity index 97% rename from vendor/github.com/marten-seemann/qtls-go1-19/conn.go rename to vendor/github.com/quic-go/qtls-go1-19/conn.go index 5a17f7a14f3..19f24e95f2b 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-19/conn.go +++ b/vendor/github.com/quic-go/qtls-go1-19/conn.go @@ -1041,25 +1041,46 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { return n, nil } -// writeRecord writes a TLS record with the given type and payload to the -// connection and updates the record layer state. 
-func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { +// writeHandshakeRecord writes a handshake message to the connection and updates +// the record layer state. If transcript is non-nil the marshalled message is +// written to it. +func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { + data, err := msg.marshal() + if err != nil { + return 0, err + } + + c.out.Lock() + defer c.out.Unlock() + + if transcript != nil { + transcript.Write(data) + } + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - if typ == recordTypeChangeCipherSpec { - return len(data), nil - } return c.extraConfig.AlternativeRecordLayer.WriteRecord(data) } + return c.writeRecordLocked(recordTypeHandshake, data) +} + +// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and +// updates the record layer state. +func (c *Conn) writeChangeCipherRecord() error { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + return nil + } + c.out.Lock() defer c.out.Unlock() - - return c.writeRecordLocked(typ, data) + _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1}) + return err } // readHandshake reads the next handshake message from -// the record layer. -func (c *Conn) readHandshake() (any, error) { +// the record layer. If transcript is non-nil, the message +// is written to the passed transcriptHash. +func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { var data []byte if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { var err error @@ -1147,6 +1168,11 @@ func (c *Conn) readHandshake() (any, error) { if !m.unmarshal(data) { return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } + + if transcript != nil { + transcript.Write(data) + } + return m, nil } @@ -1222,7 +1248,7 @@ func (c *Conn) handleRenegotiation() error { return errors.New("tls: internal error: unexpected renegotiation") } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1272,7 +1298,7 @@ func (c *Conn) handlePostHandshakeMessage() error { return c.handleRenegotiation() } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1308,7 +1334,11 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { defer c.out.Unlock() msg := &keyUpdateMsg{} - _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + msgBytes, err := msg.marshal() + if err != nil { + return err + } + _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes) if err != nil { // Surface the error at the next write. 
c.out.setErrorLocked(err) diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/cpu.go b/vendor/github.com/quic-go/qtls-go1-19/cpu.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-17/cpu.go rename to vendor/github.com/quic-go/qtls-go1-19/cpu.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/cpu_other.go b/vendor/github.com/quic-go/qtls-go1-19/cpu_other.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-17/cpu_other.go rename to vendor/github.com/quic-go/qtls-go1-19/cpu_other.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client.go b/vendor/github.com/quic-go/qtls-go1-19/handshake_client.go similarity index 92% rename from vendor/github.com/marten-seemann/qtls-go1-19/handshake_client.go rename to vendor/github.com/quic-go/qtls-go1-19/handshake_client.go index c2ded225a68..4f52ab992eb 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client.go +++ b/vendor/github.com/quic-go/qtls-go1-19/handshake_client.go @@ -144,22 +144,10 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error) var secret clientKeySharePrivate if hello.supportedVersions[0] == VersionTLS13 { - var suites []uint16 - for _, suiteID := range configCipherSuites { - for _, suite := range cipherSuitesTLS13 { - if suite.id == suiteID { - suites = append(suites, suiteID) - } - } - } - if len(suites) > 0 { - hello.cipherSuites = suites + if hasAESGCMHardwareSupport { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) } else { - if hasAESGCMHardwareSupport { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) - } else { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) - } + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) 
} curveID := config.curvePreferences()[0] @@ -212,7 +200,10 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { } c.serverName = hello.serverName - cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) + if err != nil { + return err + } if cacheKey != "" && session != nil { var deletedTicket bool if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT { @@ -222,11 +213,14 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil { h := suite.hash.New() - h.Write(hello.marshal()) + helloBytes, err := hello.marshal() + if err != nil { + return err + } + h.Write(helloBytes) clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h) c.out.exportKey(Encryption0RTT, suite, clientEarlySecret) if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil { - c.sendAlert(alertInternalError) return err } } @@ -246,11 +240,12 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { } } - if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + if _, err := c.writeHandshakeRecord(hello, nil); err != nil { return err } - msg, err := c.readHandshake() + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -343,9 +338,9 @@ func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max ea } func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, - session *clientSessionState, earlySecret, binderKey []byte) { + session *clientSessionState, earlySecret, binderKey []byte, err error) { if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return "", nil, nil, nil + return "", nil, nil, nil, nil } hello.ticketSupported = true @@ -360,14 +355,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // renegotiation is primarily used to allow a client to send a client // certificate, which would be skipped if session resumption occurred. if c.handshakes != 0 { - return "", nil, nil, nil + return "", nil, nil, nil, nil } // Try to resume a previously negotiated TLS session, if available. cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) sess, ok := c.config.ClientSessionCache.Get(cacheKey) if !ok || sess == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } session = fromClientSessionState(sess) @@ -378,7 +373,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, maxEarlyData, appData, ok = c.decodeSessionState(session) if !ok { // delete it, if parsing failed c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } } @@ -391,7 +386,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } } if !versOk { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Check that the cached server certificate is not expired, and that it's @@ -400,16 +395,16 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, if !c.config.InsecureSkipVerify { if len(session.verifiedChains) == 0 { // The original connection had InsecureSkipVerify, while this doesn't. 
- return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } serverCert := session.serverCertificates[0] if c.config.time().After(serverCert.NotAfter) { // Expired certificate, delete the entry. c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } } @@ -417,7 +412,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // In TLS 1.2 the cipher suite must match the resumed session. Ensure we // are still offering it. if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } hello.sessionTicket = session.sessionTicket @@ -427,14 +422,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // Check that the session ticket is not expired. if c.config.time().After(session.useBy) { c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // In TLS 1.3 the KDF hash must match the resumed session. Ensure we // offer at least one cipher suite with that hash. cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) if cipherSuite == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } cipherSuiteOk := false for _, offeredID := range hello.cipherSuites { @@ -445,7 +440,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } } if !cipherSuiteOk { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. @@ -466,9 +461,15 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0 } transcript := cipherSuite.hash.New() - transcript.Write(hello.marshalWithoutBinders()) + helloBytes, err := hello.marshalWithoutBinders() + if err != nil { + return "", nil, nil, nil, err + } + transcript.Write(helloBytes) pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} - hello.updateBinders(pskBinders) + if err := hello.updateBinders(pskBinders); err != nil { + return "", nil, nil, nil, err + } if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { c.extraConfig.SetAppDataFromSessionState(appData) @@ -516,8 +517,12 @@ func (hs *clientHandshakeState) handshake() error { hs.finishedHash.discardHandshakeBuffer() } - hs.finishedHash.Write(hs.hello.marshal()) - hs.finishedHash.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil { + return err + } + if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil { + return err + } c.buffering = true c.didResume = isResume @@ -588,7 +593,7 @@ func (hs *clientHandshakeState) pickCipherSuite() error { func (hs *clientHandshakeState) doFullHandshake() error { c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -597,9 +602,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.finishedHash.Write(certMsg.marshal()) - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -617,11 +621,10 @@ func (hs *clientHandshakeState) doFullHandshake() error 
{ c.sendAlert(alertUnexpectedMessage) return errors.New("tls: received unexpected CertificateStatus message") } - hs.finishedHash.Write(cs.marshal()) c.ocspResponse = cs.response - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -650,14 +653,13 @@ func (hs *clientHandshakeState) doFullHandshake() error { skx, ok := msg.(*serverKeyExchangeMsg) if ok { - hs.finishedHash.Write(skx.marshal()) err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) if err != nil { c.sendAlert(alertUnexpectedMessage) return err } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -668,7 +670,6 @@ func (hs *clientHandshakeState) doFullHandshake() error { certReq, ok := msg.(*certificateRequestMsg) if ok { certRequested = true - hs.finishedHash.Write(certReq.marshal()) cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { @@ -676,7 +677,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -687,7 +688,6 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(shd, msg) } - hs.finishedHash.Write(shd.marshal()) // If the server requested a certificate then we have to send a // Certificate message, even if it's empty because we don't have a @@ -695,8 +695,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { if certRequested { certMsg = new(certificateMsg) certMsg.certificates = chainToSend.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { return err } } @@ -707,8 +706,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } if ckx != nil { - hs.finishedHash.Write(ckx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil { return err } } @@ -755,8 +753,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } - hs.finishedHash.Write(certVerify.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil { return err } } @@ -891,7 +888,10 @@ func (hs *clientHandshakeState) readFinished(out []byte) error { return err } - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. 
+ msg, err := c.readHandshake(nil) if err != nil { return err } @@ -907,7 +907,11 @@ func (hs *clientHandshakeState) readFinished(out []byte) error { c.sendAlert(alertHandshakeFailure) return errors.New("tls: server's Finished message was incorrect") } - hs.finishedHash.Write(serverFinished.marshal()) + + if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil { + return err + } + copy(out, verify) return nil } @@ -918,7 +922,7 @@ func (hs *clientHandshakeState) readSessionTicket() error { } c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -927,7 +931,6 @@ func (hs *clientHandshakeState) readSessionTicket() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(sessionTicketMsg, msg) } - hs.finishedHash.Write(sessionTicketMsg.marshal()) hs.session = &clientSessionState{ sessionTicket: sessionTicketMsg.ticket, @@ -947,14 +950,13 @@ func (hs *clientHandshakeState) readSessionTicket() error { func (hs *clientHandshakeState) sendFinished(out []byte) error { c := hs.c - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + if err := c.writeChangeCipherRecord(); err != nil { return err } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { return err } copy(out, finished.verifyData) diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client_tls13.go b/vendor/github.com/quic-go/qtls-go1-19/handshake_client_tls13.go similarity index 92% rename from vendor/github.com/marten-seemann/qtls-go1-19/handshake_client_tls13.go rename to vendor/github.com/quic-go/qtls-go1-19/handshake_client_tls13.go index 299f557d965..5c484a0e62d 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_client_tls13.go +++ b/vendor/github.com/quic-go/qtls-go1-19/handshake_client_tls13.go @@ -70,7 +70,10 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } hs.transcript = hs.suite.hash.New() - hs.transcript.Write(hs.hello.marshal()) + + if err := transcriptMsg(hs.hello, hs.transcript); err != nil { + return err + } if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { if err := hs.sendDummyChangeCipherSpec(); err != nil { @@ -81,7 +84,9 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } } - hs.transcript.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } c.buffering = true if err := hs.processServerHello(); err != nil { @@ -188,8 +193,7 @@ func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { } hs.sentDummyCCS = true - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err + return hs.c.writeChangeCipherRecord() } // processHelloRetryRequest handles the HRR in hs.serverHello, modifies and @@ -206,7 +210,9 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { hs.transcript.Reset() hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) hs.transcript.Write(chHash) - hs.transcript.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } // The only HelloRetryRequest extensions we support are key_share and // cookie, and clients must abort the handshake if the HRR would not result @@ 
-288,10 +294,18 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { transcript := hs.suite.hash.New() transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) transcript.Write(chHash) - transcript.Write(hs.serverHello.marshal()) - transcript.Write(hs.hello.marshalWithoutBinders()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } + helloBytes, err := hs.hello.marshalWithoutBinders() + if err != nil { + return err + } + transcript.Write(helloBytes) pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} - hs.hello.updateBinders(pskBinders) + if err := hs.hello.updateBinders(pskBinders); err != nil { + return err + } } else { // Server selected a cipher suite incompatible with the PSK. hs.hello.pskIdentities = nil @@ -303,13 +317,12 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.extraConfig.Rejected0RTT() } hs.hello.earlyData = false // disable 0-RTT - - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { return err } - msg, err := c.readHandshake() + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -409,6 +422,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { if !hs.usingPSK { earlySecret = hs.suite.extract(nil, nil) } + handshakeSecret := hs.suite.extract(sharedKey, hs.suite.deriveSecret(earlySecret, "derived", nil)) @@ -441,7 +455,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { func (hs *clientHandshakeStateTLS13) readServerParameters() error { c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } @@ -459,7 +473,6 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error { if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil { hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions) } - hs.transcript.Write(encryptedExtensions.marshal()) if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { c.sendAlert(alertUnsupportedExtension) @@ -495,18 +508,16 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return nil } - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } certReq, ok := msg.(*certificateRequestMsgTLS13) if ok { - hs.transcript.Write(certReq.marshal()) - hs.certReq = certReq - msg, err = c.readHandshake() + msg, err = c.readHandshake(hs.transcript) if err != nil { return err } @@ -521,7 +532,6 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { c.sendAlert(alertDecodeError) return errors.New("tls: received empty certificates message") } - hs.transcript.Write(certMsg.marshal()) c.scts = certMsg.certificate.SignedCertificateTimestamps c.ocspResponse = certMsg.certificate.OCSPStaple @@ -530,7 +540,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return err } - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. 
+ msg, err = c.readHandshake(nil) if err != nil { return err } @@ -561,7 +574,9 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return errors.New("tls: invalid signature by the server certificate: " + err.Error()) } - hs.transcript.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, hs.transcript); err != nil { + return err + } return nil } @@ -569,7 +584,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { func (hs *clientHandshakeStateTLS13) readServerFinished() error { c := hs.c - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -586,7 +604,9 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error { return errors.New("tls: invalid server finished hash") } - hs.transcript.Write(finished.marshal()) + if err := transcriptMsg(finished, hs.transcript); err != nil { + return err + } // Derive secrets that take context through the server Finished. @@ -636,8 +656,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { return err } @@ -674,8 +693,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { } certVerifyMsg.signature = sig - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { return err } @@ -689,8 +707,7 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error { verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), } - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { return err } diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_messages.go b/vendor/github.com/quic-go/qtls-go1-19/handshake_messages.go similarity index 76% rename from vendor/github.com/marten-seemann/qtls-go1-19/handshake_messages.go rename to vendor/github.com/quic-go/qtls-go1-19/handshake_messages.go index 07193c8efc4..c69fcefda25 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_messages.go +++ b/vendor/github.com/quic-go/qtls-go1-19/handshake_messages.go @@ -5,6 +5,7 @@ package qtls import ( + "errors" "fmt" "strings" @@ -95,9 +96,187 @@ type clientHelloMsg struct { additionalExtensions []Extension } -func (m *clientHelloMsg) marshal() []byte { +func (m *clientHelloMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil + } + + var exts cryptobyte.Builder + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + exts.AddUint16(extensionServerName) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8(0) // name_type = host_name + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(m.serverName)) + }) + }) + }) + 
} + if m.ocspStapling { + // RFC 4366, Section 3.6 + exts.AddUint16(extensionStatusRequest) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8(1) // status_type = ocsp + exts.AddUint16(0) // empty responder_id_list + exts.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + exts.AddUint16(extensionSupportedCurves) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + exts.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + exts.AddUint16(extensionSupportedPoints) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + exts.AddUint16(extensionSessionTicket) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + exts.AddUint16(extensionSignatureAlgorithms) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + exts.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + exts.AddUint16(extensionSignatureAlgorithmsCert) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + exts.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + exts.AddUint16(extensionRenegotiationInfo) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + exts.AddUint16(extensionALPN) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + exts.AddUint16(extensionSCT) + exts.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + exts.AddUint16(extensionSupportedVersions) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + exts.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + exts.AddUint16(extensionCookie) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, ks := range m.keyShares { + exts.AddUint16(uint16(ks.group)) + 
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + exts.AddUint16(extensionEarlyData) + exts.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + exts.AddUint16(extensionPSKModes) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.pskModes) + }) + }) + } + for _, ext := range m.additionalExtensions { + exts.AddUint16(ext.Type) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(ext.Data) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + exts.AddUint16(extensionPreSharedKey) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(psk.label) + }) + exts.AddUint32(psk.obfuscatedTicketAge) + } + }) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(binder) + }) + } + }) + }) + } + extBytes, err := exts.Bytes() + if err != nil { + return nil, err } var b cryptobyte.Builder @@ -117,225 +296,53 @@ func (m *clientHelloMsg) marshal() []byte { b.AddBytes(m.compressionMethods) }) - // If extensions aren't present, omit them. - var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.serverName) > 0 { - // RFC 6066, Section 3 - b.AddUint16(extensionServerName) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // name_type = host_name - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.serverName)) - }) - }) - }) - } - if m.ocspStapling { - // RFC 4366, Section 3.6 - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(1) // status_type = ocsp - b.AddUint16(0) // empty responder_id_list - b.AddUint16(0) // empty request_extensions - }) - } - if len(m.supportedCurves) > 0 { - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - b.AddUint16(extensionSupportedCurves) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, curve := range m.supportedCurves { - b.AddUint16(uint16(curve)) - } - }) - }) - } - if len(m.supportedPoints) > 0 { - // RFC 4492, Section 5.1.2 - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - if m.ticketSupported { - // RFC 5077, Section 3.2 - b.AddUint16(extensionSessionTicket) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionTicket) - }) - } - if len(m.supportedSignatureAlgorithms) > 0 { - // RFC 5246, Section 7.4.1.4.1 - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - // RFC 
8446, Section 4.2.3 - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if m.secureRenegotiationSupported { - // RFC 5746, Section 3.2 - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocols) > 0 { - // RFC 7301, Section 3.1 - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, proto := range m.alpnProtocols { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(proto)) - }) - } - }) - }) - } - if m.scts { - // RFC 6962, Section 3.3.1 - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedVersions) > 0 { - // RFC 8446, Section 4.2.1 - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - for _, vers := range m.supportedVersions { - b.AddUint16(vers) - } - }) - }) - } - if len(m.cookie) > 0 { - // RFC 8446, Section 4.2.2 - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if len(m.keyShares) > 0 { - // RFC 8446, Section 4.2.8 - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ks := range m.keyShares { - b.AddUint16(uint16(ks.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ks.data) - }) - } - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - b.AddUint16(extensionEarlyData) - b.AddUint16(0) // empty extension_data - } - if len(m.pskModes) > 0 { - // RFC 8446, Section 4.2.9 - b.AddUint16(extensionPSKModes) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.pskModes) - }) - }) - } - for _, ext := range m.additionalExtensions { - b.AddUint16(ext.Type) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ext.Data) - }) - } - if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension - // RFC 8446, Section 4.2.11 - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, psk := range m.pskIdentities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(psk.label) - }) - b.AddUint32(psk.obfuscatedTicketAge) - } - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions + if len(extBytes) > 0 { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(extBytes) + }) } }) - m.raw = b.BytesOrPanic() - return m.raw + m.raw, err = b.Bytes() + return m.raw, err } // marshalWithoutBinders returns the ClientHello through the // PreSharedKeyExtension.identities field, according to 
RFC 8446, Section // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. -func (m *clientHelloMsg) marshalWithoutBinders() []byte { +func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) { bindersLen := 2 // uint16 length prefix for _, binder := range m.pskBinders { bindersLen += 1 // uint8 length prefix bindersLen += len(binder) } - fullMessage := m.marshal() - return fullMessage[:len(fullMessage)-bindersLen] + fullMessage, err := m.marshal() + if err != nil { + return nil, err + } + return fullMessage[:len(fullMessage)-bindersLen], nil } // updateBinders updates the m.pskBinders field, if necessary updating the // cached marshaled representation. The supplied binders must have the same // length as the current m.pskBinders. -func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { if len(pskBinders) != len(m.pskBinders) { - panic("tls: internal error: pskBinders length mismatch") + return errors.New("tls: internal error: pskBinders length mismatch") } for i := range m.pskBinders { if len(pskBinders[i]) != len(m.pskBinders[i]) { - panic("tls: internal error: pskBinders length mismatch") + return errors.New("tls: internal error: pskBinders length mismatch") } } m.pskBinders = pskBinders if m.raw != nil { - lenWithoutBinders := len(m.marshalWithoutBinders()) + helloBytes, err := m.marshalWithoutBinders() + if err != nil { + return err + } + lenWithoutBinders := len(helloBytes) b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { for _, binder := range m.pskBinders { @@ -345,9 +352,11 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { } }) if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { - panic("tls: internal error: failed to update binders") + return errors.New("tls: internal error: failed to update binders") } } + + return nil } func (m *clientHelloMsg) unmarshal(data []byte) bool { @@ -625,9 +634,98 @@ type serverHelloMsg struct { selectedGroup CurveID } -func (m *serverHelloMsg) marshal() []byte { +func (m *serverHelloMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil + } + + var exts cryptobyte.Builder + if m.ocspStapling { + exts.AddUint16(extensionStatusRequest) + exts.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + exts.AddUint16(extensionSessionTicket) + exts.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + exts.AddUint16(extensionRenegotiationInfo) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + exts.AddUint16(extensionALPN) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + exts.AddUint16(extensionSCT) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sct := range m.scts { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 0 { + exts.AddUint16(extensionSupportedVersions) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + 
exts.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(uint16(m.serverShare.group)) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + exts.AddUint16(extensionPreSharedKey) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + exts.AddUint16(extensionCookie) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + exts.AddUint16(extensionSupportedPoints) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.supportedPoints) + }) + }) + } + + extBytes, err := exts.Bytes() + if err != nil { + return nil, err } var b cryptobyte.Builder @@ -641,104 +739,15 @@ func (m *serverHelloMsg) marshal() []byte { b.AddUint16(m.cipherSuite) b.AddUint8(m.compressionMethod) - // If extensions aren't present, omit them. - var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.ticketSupported { - b.AddUint16(extensionSessionTicket) - b.AddUint16(0) // empty extension_data - } - if m.secureRenegotiationSupported { - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if len(m.scts) > 0 { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range m.scts { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - if m.supportedVersion != 0 { - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.supportedVersion) - }) - } - if m.serverShare.group != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.serverShare.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.serverShare.data) - }) - }) - } - if m.selectedIdentityPresent { - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.selectedIdentity) - }) - } - - if len(m.cookie) > 0 { - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if m.selectedGroup != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { 
- b.AddUint16(uint16(m.selectedGroup)) - }) - } - if len(m.supportedPoints) > 0 { - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions + if len(extBytes) > 0 { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(extBytes) + }) } }) - m.raw = b.BytesOrPanic() - return m.raw + m.raw, err = b.Bytes() + return m.raw, err } func (m *serverHelloMsg) unmarshal(data []byte) bool { @@ -865,9 +874,9 @@ type encryptedExtensionsMsg struct { additionalExtensions []Extension } -func (m *encryptedExtensionsMsg) marshal() []byte { +func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -898,8 +907,9 @@ func (m *encryptedExtensionsMsg) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { @@ -949,10 +959,10 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { type endOfEarlyDataMsg struct{} -func (m *endOfEarlyDataMsg) marshal() []byte { +func (m *endOfEarlyDataMsg) marshal() ([]byte, error) { x := make([]byte, 4) x[0] = typeEndOfEarlyData - return x + return x, nil } func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { @@ -964,9 +974,9 @@ type keyUpdateMsg struct { updateRequested bool } -func (m *keyUpdateMsg) marshal() []byte { +func (m *keyUpdateMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -979,8 +989,9 @@ func (m *keyUpdateMsg) marshal() []byte { } }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *keyUpdateMsg) unmarshal(data []byte) bool { @@ -1012,9 +1023,9 @@ type newSessionTicketMsgTLS13 struct { maxEarlyData uint32 } -func (m *newSessionTicketMsgTLS13) marshal() []byte { +func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1039,8 +1050,9 @@ func (m *newSessionTicketMsgTLS13) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { @@ -1093,9 +1105,9 @@ type certificateRequestMsgTLS13 struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsgTLS13) marshal() []byte { +func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1154,8 +1166,9 @@ func (m *certificateRequestMsgTLS13) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { @@ -1239,9 +1252,9 @@ type certificateMsg struct { certificates [][]byte } -func (m *certificateMsg) marshal() (x []byte) { +func (m *certificateMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var i int @@ -1250,7 +1263,7 @@ func (m *certificateMsg) marshal() (x []byte) { } length := 3 + 3*len(m.certificates) + i - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeCertificate x[1] = 
uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1271,7 +1284,7 @@ func (m *certificateMsg) marshal() (x []byte) { } m.raw = x - return + return m.raw, nil } func (m *certificateMsg) unmarshal(data []byte) bool { @@ -1318,9 +1331,9 @@ type certificateMsgTLS13 struct { scts bool } -func (m *certificateMsgTLS13) marshal() []byte { +func (m *certificateMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1338,8 +1351,9 @@ func (m *certificateMsgTLS13) marshal() []byte { marshalCertificate(b, certificate) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { @@ -1462,9 +1476,9 @@ type serverKeyExchangeMsg struct { key []byte } -func (m *serverKeyExchangeMsg) marshal() []byte { +func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } length := len(m.key) x := make([]byte, length+4) @@ -1475,7 +1489,7 @@ func (m *serverKeyExchangeMsg) marshal() []byte { copy(x[4:], m.key) m.raw = x - return x + return x, nil } func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { @@ -1492,9 +1506,9 @@ type certificateStatusMsg struct { response []byte } -func (m *certificateStatusMsg) marshal() []byte { +func (m *certificateStatusMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1506,8 +1520,9 @@ func (m *certificateStatusMsg) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateStatusMsg) unmarshal(data []byte) bool { @@ -1526,10 +1541,10 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { type serverHelloDoneMsg struct{} -func (m *serverHelloDoneMsg) marshal() []byte { +func (m *serverHelloDoneMsg) marshal() ([]byte, error) { x := make([]byte, 4) x[0] = typeServerHelloDone - return x + return x, nil } func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { @@ -1541,9 +1556,9 @@ type clientKeyExchangeMsg struct { ciphertext []byte } -func (m *clientKeyExchangeMsg) marshal() []byte { +func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } length := len(m.ciphertext) x := make([]byte, length+4) @@ -1554,7 +1569,7 @@ func (m *clientKeyExchangeMsg) marshal() []byte { copy(x[4:], m.ciphertext) m.raw = x - return x + return x, nil } func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { @@ -1575,9 +1590,9 @@ type finishedMsg struct { verifyData []byte } -func (m *finishedMsg) marshal() []byte { +func (m *finishedMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1586,8 +1601,9 @@ func (m *finishedMsg) marshal() []byte { b.AddBytes(m.verifyData) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *finishedMsg) unmarshal(data []byte) bool { @@ -1609,9 +1625,9 @@ type certificateRequestMsg struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsg) marshal() (x []byte) { +func (m *certificateRequestMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } // See RFC 4346, Section 7.4.4. 
@@ -1626,7 +1642,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { length += 2 + 2*len(m.supportedSignatureAlgorithms) } - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeCertificateRequest x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1661,7 +1677,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { } m.raw = x - return + return m.raw, nil } func (m *certificateRequestMsg) unmarshal(data []byte) bool { @@ -1747,9 +1763,9 @@ type certificateVerifyMsg struct { signature []byte } -func (m *certificateVerifyMsg) marshal() (x []byte) { +func (m *certificateVerifyMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1763,8 +1779,9 @@ func (m *certificateVerifyMsg) marshal() (x []byte) { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateVerifyMsg) unmarshal(data []byte) bool { @@ -1787,15 +1804,15 @@ type newSessionTicketMsg struct { ticket []byte } -func (m *newSessionTicketMsg) marshal() (x []byte) { +func (m *newSessionTicketMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } // See RFC 5077, Section 3.3. ticketLen := len(m.ticket) length := 2 + 4 + ticketLen - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeNewSessionTicket x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1806,7 +1823,7 @@ func (m *newSessionTicketMsg) marshal() (x []byte) { m.raw = x - return + return m.raw, nil } func (m *newSessionTicketMsg) unmarshal(data []byte) bool { @@ -1834,10 +1851,25 @@ func (m *newSessionTicketMsg) unmarshal(data []byte) bool { type helloRequestMsg struct { } -func (*helloRequestMsg) marshal() []byte { - return []byte{typeHelloRequest, 0, 0, 0} +func (*helloRequestMsg) marshal() ([]byte, error) { + return []byte{typeHelloRequest, 0, 0, 0}, nil } func (*helloRequestMsg) unmarshal(data []byte) bool { return len(data) == 4 } + +type transcriptHash interface { + Write([]byte) (int, error) +} + +// transcriptMsg is a helper used to marshal and hash messages which typically +// are not written to the wire, and as such aren't hashed during Conn.writeRecord. +func transcriptMsg(msg handshakeMessage, h transcriptHash) error { + data, err := msg.marshal() + if err != nil { + return err + } + h.Write(data) + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server.go b/vendor/github.com/quic-go/qtls-go1-19/handshake_server.go similarity index 92% rename from vendor/github.com/marten-seemann/qtls-go1-19/handshake_server.go rename to vendor/github.com/quic-go/qtls-go1-19/handshake_server.go index b363d53fef5..f0c0aa01981 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server.go +++ b/vendor/github.com/quic-go/qtls-go1-19/handshake_server.go @@ -138,7 +138,9 @@ func (hs *serverHandshakeState) handshake() error { // readClientHello reads a ClientHello message and selects the protocol version. func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { - msg, err := c.readHandshake() + // clientHelloMsg is included in the transcript, but we haven't initialized + // it yet. The respective handshake functions will record it themselves. 
+ msg, err := c.readHandshake(nil) if err != nil { return nil, err } @@ -494,9 +496,10 @@ func (hs *serverHandshakeState) doResumeHandshake() error { hs.hello.ticketSupported = hs.sessionState.usedOldKey hs.finishedHash = newFinishedHash(c.vers, hs.suite) hs.finishedHash.discardHandshakeBuffer() - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { return err } @@ -534,24 +537,23 @@ func (hs *serverHandshakeState) doFullHandshake() error { // certificates won't be used. hs.finishedHash.discardHandshakeBuffer() } - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { return err } certMsg := new(certificateMsg) certMsg.certificates = hs.cert.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { return err } if hs.hello.ocspStapling { certStatus := new(certificateStatusMsg) certStatus.response = hs.cert.OCSPStaple - hs.finishedHash.Write(certStatus.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil { return err } } @@ -563,8 +565,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { return err } if skx != nil { - hs.finishedHash.Write(skx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { return err } } @@ -590,15 +591,13 @@ func (hs *serverHandshakeState) doFullHandshake() error { if c.config.ClientCAs != nil { certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } - hs.finishedHash.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil { return err } } helloDone := new(serverHelloDoneMsg) - hs.finishedHash.Write(helloDone.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil { return err } @@ -608,7 +607,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { var pub crypto.PublicKey // public key for client auth, if any - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -621,7 +620,6 @@ func (hs *serverHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.finishedHash.Write(certMsg.marshal()) if err := c.processCertsFromClient(Certificate{ Certificate: certMsg.certificates, @@ -632,7 +630,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { pub = c.peerCertificates[0].PublicKey } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return 
err } @@ -650,7 +648,6 @@ func (hs *serverHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(ckx, msg) } - hs.finishedHash.Write(ckx.marshal()) preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) if err != nil { @@ -670,7 +667,10 @@ func (hs *serverHandshakeState) doFullHandshake() error { // to the client's certificate. This allows us to verify that the client is in // possession of the private key of the certificate. if len(c.peerCertificates) > 0 { - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -705,7 +705,9 @@ func (hs *serverHandshakeState) doFullHandshake() error { return errors.New("tls: invalid signature by the client certificate: " + err.Error()) } - hs.finishedHash.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil { + return err + } } hs.finishedHash.discardHandshakeBuffer() @@ -745,7 +747,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error { return err } - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -762,7 +767,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error { return errors.New("tls: client's Finished message is incorrect") } - hs.finishedHash.Write(clientFinished.marshal()) + if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil { + return err + } + copy(out, verify) return nil } @@ -796,14 +804,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error { masterSecret: hs.masterSecret, certificates: certsFromClient, } - var err error - m.ticket, err = c.encryptTicket(state.marshal()) + stateBytes, err := state.marshal() + if err != nil { + return err + } + m.ticket, err = c.encryptTicket(stateBytes) if err != nil { return err } - hs.finishedHash.Write(m.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil { return err } @@ -813,14 +823,13 @@ func (hs *serverHandshakeState) sendSessionTicket() error { func (hs *serverHandshakeState) sendFinished(out []byte) error { c := hs.c - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + if err := c.writeChangeCipherRecord(); err != nil { return err } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { return err } diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server_tls13.go b/vendor/github.com/quic-go/qtls-go1-19/handshake_server_tls13.go similarity index 92% rename from vendor/github.com/marten-seemann/qtls-go1-19/handshake_server_tls13.go rename to vendor/github.com/quic-go/qtls-go1-19/handshake_server_tls13.go index 9480fd68812..9a2b46c14e9 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-19/handshake_server_tls13.go +++ 
b/vendor/github.com/quic-go/qtls-go1-19/handshake_server_tls13.go @@ -156,27 +156,14 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { hs.hello.sessionId = hs.clientHello.sessionId hs.hello.compressionMethod = compressionNone - if hs.suite == nil { - var preferenceList []uint16 - for _, suiteID := range c.config.CipherSuites { - for _, suite := range cipherSuitesTLS13 { - if suite.id == suiteID { - preferenceList = append(preferenceList, suiteID) - break - } - } - } - if len(preferenceList) == 0 { - preferenceList = defaultCipherSuitesTLS13 - if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceList = defaultCipherSuitesTLS13NoAES - } - } - for _, suiteID := range preferenceList { - hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) - if hs.suite != nil { - break - } + preferenceList := defaultCipherSuitesTLS13 + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = defaultCipherSuitesTLS13NoAES + } + for _, suiteID := range preferenceList { + hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) + if hs.suite != nil { + break } } if hs.suite == nil { @@ -353,7 +340,12 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { c.sendAlert(alertInternalError) return errors.New("tls: internal error: failed to clone hash") } - transcript.Write(hs.clientHello.marshalWithoutBinders()) + clientHelloBytes, err := hs.clientHello.marshalWithoutBinders() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + transcript.Write(clientHelloBytes) pskBinder := hs.suite.finishedHash(binderKey, transcript) if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { c.sendAlert(alertDecryptError) @@ -366,7 +358,12 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { } h := cloneHash(hs.transcript, hs.suite.hash) - h.Write(hs.clientHello.marshal()) + clientHelloWithBindersBytes, err := hs.clientHello.marshal() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + h.Write(clientHelloWithBindersBytes) if hs.encryptedExtensions.earlyData { clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h) c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret) @@ -455,8 +452,7 @@ func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { } hs.sentDummyCCS = true - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err + return hs.c.writeChangeCipherRecord() } func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { @@ -466,7 +462,9 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) // The first ClientHello gets double-hashed into the transcript upon a // HelloRetryRequest. See RFC 8446, Section 4.4.1. 
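Every hunk in these vendored handshake files applies the same mechanical change: messages can no longer be hashed with transcript.Write(msg.marshal()), because marshal() now returns ([]byte, error), so transcript updates go through writeHandshakeRecord or the transcriptMsg helper. transcriptMsg itself is not included in this excerpt; a minimal sketch of what it plausibly does, inferred from the call sites above and written as it would sit inside the qtls package (the body and the transcriptHash definition are assumptions, not copied from the diff):

// transcriptHash is satisfied by hash.Hash and by the TLS 1.2 finishedHash.
type transcriptHash interface {
	Write(b []byte) (int, error)
}

// transcriptMsg marshals a handshake message and feeds it to the handshake
// transcript, surfacing the (previously panicking) marshal error to the caller.
func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
	data, err := msg.marshal() // marshal now reports cryptobyte errors instead of panicking
	if err != nil {
		return err
	}
	h.Write(data) // writes to a hash.Hash never fail
	return nil
}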
- hs.transcript.Write(hs.clientHello.marshal()) + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { + return err + } chHash := hs.transcript.Sum(nil) hs.transcript.Reset() hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) @@ -482,8 +480,7 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) selectedGroup: selectedGroup, } - hs.transcript.Write(helloRetryRequest.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil { return err } @@ -491,7 +488,8 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) return err } - msg, err := c.readHandshake() + // clientHelloMsg is not included in the transcript. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -587,9 +585,10 @@ func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { func (hs *serverHandshakeStateTLS13) sendServerParameters() error { c := hs.c - hs.transcript.Write(hs.clientHello.marshal()) - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { return err } @@ -632,8 +631,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions) } - hs.transcript.Write(hs.encryptedExtensions.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.encryptedExtensions.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil { return err } @@ -662,8 +660,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } - hs.transcript.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil { return err } } @@ -674,8 +671,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { return err } @@ -706,8 +702,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { } certVerifyMsg.signature = sig - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { return err } @@ -721,8 +716,7 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error { verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), } - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { return err } @@ -784,7 +778,9 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { finishedMsg := 
&finishedMsg{ verifyData: hs.clientFinished, } - hs.transcript.Write(finishedMsg.marshal()) + if err := transcriptMsg(finishedMsg, hs.transcript); err != nil { + return err + } if !hs.shouldSendSessionTickets() { return nil @@ -805,7 +801,7 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { return err } - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + if _, err := c.writeHandshakeRecord(m, nil); err != nil { return err } @@ -830,7 +826,7 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { // If we requested a client certificate, then the client must send a // certificate message. If it's empty, no CertificateVerify is sent. - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } @@ -840,7 +836,6 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.transcript.Write(certMsg.marshal()) if err := c.processCertsFromClient(certMsg.certificate); err != nil { return err @@ -854,7 +849,10 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { } if len(certMsg.certificate.Certificate) != 0 { - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -885,7 +883,9 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { return errors.New("tls: invalid signature by the client certificate: " + err.Error()) } - hs.transcript.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, hs.transcript); err != nil { + return err + } } // If we waited until the client certificates to send session tickets, we @@ -900,7 +900,8 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { func (hs *serverHandshakeStateTLS13) readClientFinished() error { c := hs.c - msg, err := c.readHandshake() + // finishedMsg is not included in the transcript. + msg, err := c.readHandshake(nil) if err != nil { return err } diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/key_agreement.go b/vendor/github.com/quic-go/qtls-go1-19/key_agreement.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-18/key_agreement.go rename to vendor/github.com/quic-go/qtls-go1-19/key_agreement.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/key_schedule.go b/vendor/github.com/quic-go/qtls-go1-19/key_schedule.go similarity index 86% rename from vendor/github.com/marten-seemann/qtls-go1-17/key_schedule.go rename to vendor/github.com/quic-go/qtls-go1-19/key_schedule.go index da13904a6e8..708bdc7c3e2 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-17/key_schedule.go +++ b/vendor/github.com/quic-go/qtls-go1-19/key_schedule.go @@ -8,6 +8,7 @@ import ( "crypto/elliptic" "crypto/hmac" "errors" + "fmt" "hash" "io" "math/big" @@ -42,8 +43,24 @@ func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []by hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(context) }) + hkdfLabelBytes, err := hkdfLabel.Bytes() + if err != nil { + // Rather than calling BytesOrPanic, we explicitly handle this error, in + // order to provide a reasonable error message. 
It should be basically + // impossible for this to panic, and routing errors back through the + // tree rooted in this function is quite painful. The labels are fixed + // size, and the context is either a fixed-length computed hash, or + // parsed from a field which has the same length limitation. As such, an + // error here is likely to only be caused during development. + // + // NOTE: another reasonable approach here might be to return a + // randomized slice if we encounter an error, which would break the + // connection, but avoid panicking. This would perhaps be safer but + // significantly more confusing to users. + panic(fmt.Errorf("failed to construct HKDF label: %s", err)) + } out := make([]byte, length) - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out) if err != nil || n != length { panic("tls: HKDF-Expand-Label invocation failed unexpectedly") } diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/notboring.go b/vendor/github.com/quic-go/qtls-go1-19/notboring.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-19/notboring.go rename to vendor/github.com/quic-go/qtls-go1-19/notboring.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/prf.go b/vendor/github.com/quic-go/qtls-go1-19/prf.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-16/prf.go rename to vendor/github.com/quic-go/qtls-go1-19/prf.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/ticket.go b/vendor/github.com/quic-go/qtls-go1-19/ticket.go similarity index 96% rename from vendor/github.com/marten-seemann/qtls-go1-18/ticket.go rename to vendor/github.com/quic-go/qtls-go1-19/ticket.go index 81e8a52eac6..fe1c7a88437 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/ticket.go +++ b/vendor/github.com/quic-go/qtls-go1-19/ticket.go @@ -34,7 +34,7 @@ type sessionState struct { usedOldKey bool } -func (m *sessionState) marshal() []byte { +func (m *sessionState) marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(m.vers) b.AddUint16(m.cipherSuite) @@ -49,7 +49,7 @@ func (m *sessionState) marshal() []byte { }) } }) - return b.BytesOrPanic() + return b.Bytes() } func (m *sessionState) unmarshal(data []byte) bool { @@ -94,7 +94,7 @@ type sessionStateTLS13 struct { appData []byte } -func (m *sessionStateTLS13) marshal() []byte { +func (m *sessionStateTLS13) marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(VersionTLS13) b.AddUint8(2) // revision @@ -111,7 +111,7 @@ func (m *sessionStateTLS13) marshal() []byte { b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(m.appData) }) - return b.BytesOrPanic() + return b.Bytes() } func (m *sessionStateTLS13) unmarshal(data []byte) bool { @@ -227,8 +227,11 @@ func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, e if c.extraConfig != nil { state.maxEarlyData = c.extraConfig.MaxEarlyData } - var err error - m.label, err = c.encryptTicket(state.marshal()) + stateBytes, err := state.marshal() + if err != nil { + return nil, err + } + m.label, err = c.encryptTicket(stateBytes) if err != nil { return nil, err } @@ -270,5 +273,5 @@ func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) { if err != nil { return nil, err } - return m.marshal(), nil + return m.marshal() } diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/tls.go b/vendor/github.com/quic-go/qtls-go1-19/tls.go similarity index 100% rename from 
vendor/github.com/marten-seemann/qtls-go1-17/tls.go rename to vendor/github.com/quic-go/qtls-go1-19/tls.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-16/unsafe.go b/vendor/github.com/quic-go/qtls-go1-19/unsafe.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-16/unsafe.go rename to vendor/github.com/quic-go/qtls-go1-19/unsafe.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/LICENSE b/vendor/github.com/quic-go/qtls-go1-20/LICENSE similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-17/LICENSE rename to vendor/github.com/quic-go/qtls-go1-20/LICENSE diff --git a/vendor/github.com/quic-go/qtls-go1-20/README.md b/vendor/github.com/quic-go/qtls-go1-20/README.md new file mode 100644 index 00000000000..2beaa2f2364 --- /dev/null +++ b/vendor/github.com/quic-go/qtls-go1-20/README.md @@ -0,0 +1,6 @@ +# qtls + +[![Go Reference](https://pkg.go.dev/badge/github.com/quic-go/qtls-go1-20.svg)](https://pkg.go.dev/github.com/quic-go/qtls-go1-20) +[![.github/workflows/go-test.yml](https://github.com/quic-go/qtls-go1-20/actions/workflows/go-test.yml/badge.svg)](https://github.com/quic-go/qtls-go1-20/actions/workflows/go-test.yml) + +This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/quic-go/quic-go). diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/alert.go b/vendor/github.com/quic-go/qtls-go1-20/alert.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-17/alert.go rename to vendor/github.com/quic-go/qtls-go1-20/alert.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/auth.go b/vendor/github.com/quic-go/qtls-go1-20/auth.go similarity index 98% rename from vendor/github.com/marten-seemann/qtls-go1-18/auth.go rename to vendor/github.com/quic-go/qtls-go1-20/auth.go index 1ef675fd37c..effc9aced81 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/auth.go +++ b/vendor/github.com/quic-go/qtls-go1-20/auth.go @@ -169,6 +169,7 @@ var rsaSignatureSchemes = []struct { // and optionally filtered by its explicit SupportedSignatureAlgorithms. // // This function must be kept in sync with supportedSignatureAlgorithms. +// FIPS filtering is applied in the caller, selectSignatureScheme. func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { priv, ok := cert.PrivateKey.(crypto.Signer) if !ok { @@ -241,6 +242,9 @@ func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureSche // Pick signature scheme in the peer's preference order, as our // preference order is not configurable. for _, preferredAlg := range peerAlgs { + if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) { + continue + } if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) { return preferredAlg, nil } diff --git a/vendor/github.com/quic-go/qtls-go1-20/cache.go b/vendor/github.com/quic-go/qtls-go1-20/cache.go new file mode 100644 index 00000000000..99e0c5fb862 --- /dev/null +++ b/vendor/github.com/quic-go/qtls-go1-20/cache.go @@ -0,0 +1,95 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package qtls + +import ( + "crypto/x509" + "runtime" + "sync" + "sync/atomic" +) + +type cacheEntry struct { + refs atomic.Int64 + cert *x509.Certificate +} + +// certCache implements an intern table for reference counted x509.Certificates, +// implemented in a similar fashion to BoringSSL's CRYPTO_BUFFER_POOL. This +// allows for a single x509.Certificate to be kept in memory and referenced from +// multiple Conns. Returned references should not be mutated by callers. Certificates +// are still safe to use after they are removed from the cache. +// +// Certificates are returned wrapped in a activeCert struct that should be held by +// the caller. When references to the activeCert are freed, the number of references +// to the certificate in the cache is decremented. Once the number of references +// reaches zero, the entry is evicted from the cache. +// +// The main difference between this implementation and CRYPTO_BUFFER_POOL is that +// CRYPTO_BUFFER_POOL is a more generic structure which supports blobs of data, +// rather than specific structures. Since we only care about x509.Certificates, +// certCache is implemented as a specific cache, rather than a generic one. +// +// See https://boringssl.googlesource.com/boringssl/+/master/include/openssl/pool.h +// and https://boringssl.googlesource.com/boringssl/+/master/crypto/pool/pool.c +// for the BoringSSL reference. +type certCache struct { + sync.Map +} + +var clientCertCache = new(certCache) + +// activeCert is a handle to a certificate held in the cache. Once there are +// no alive activeCerts for a given certificate, the certificate is removed +// from the cache by a finalizer. +type activeCert struct { + cert *x509.Certificate +} + +// active increments the number of references to the entry, wraps the +// certificate in the entry in a activeCert, and sets the finalizer. +// +// Note that there is a race between active and the finalizer set on the +// returned activeCert, triggered if active is called after the ref count is +// decremented such that refs may be > 0 when evict is called. We consider this +// safe, since the caller holding an activeCert for an entry that is no longer +// in the cache is fine, with the only side effect being the memory overhead of +// there being more than one distinct reference to a certificate alive at once. +func (cc *certCache) active(e *cacheEntry) *activeCert { + e.refs.Add(1) + a := &activeCert{e.cert} + runtime.SetFinalizer(a, func(_ *activeCert) { + if e.refs.Add(-1) == 0 { + cc.evict(e) + } + }) + return a +} + +// evict removes a cacheEntry from the cache. +func (cc *certCache) evict(e *cacheEntry) { + cc.Delete(string(e.cert.Raw)) +} + +// newCert returns a x509.Certificate parsed from der. If there is already a copy +// of the certificate in the cache, a reference to the existing certificate will +// be returned. Otherwise, a fresh certificate will be added to the cache, and +// the reference returned. The returned reference should not be mutated. 
+func (cc *certCache) newCert(der []byte) (*activeCert, error) { + if entry, ok := cc.Load(string(der)); ok { + return cc.active(entry.(*cacheEntry)), nil + } + + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + + entry := &cacheEntry{cert: cert} + if entry, loaded := cc.LoadOrStore(string(der), entry); loaded { + return cc.active(entry.(*cacheEntry)), nil + } + return cc.active(entry), nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/cfkem.go b/vendor/github.com/quic-go/qtls-go1-20/cfkem.go similarity index 94% rename from vendor/github.com/marten-seemann/qtls-go1-19/cfkem.go rename to vendor/github.com/quic-go/qtls-go1-20/cfkem.go index dc2f2302286..b5e849224dc 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-19/cfkem.go +++ b/vendor/github.com/quic-go/qtls-go1-20/cfkem.go @@ -1,4 +1,4 @@ -// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code +// Copyright 2023 Cloudflare, Inc. All rights reserved. Use of this source code // is governed by a BSD-style license that can be found in the LICENSE file. // // Glue to add Circl's (post-quantum) hybrid KEMs. @@ -23,6 +23,7 @@ import ( "github.com/cloudflare/circl/kem" "github.com/cloudflare/circl/kem/hybrid" + "crypto/ecdh" "crypto/tls" "fmt" "io" @@ -30,7 +31,7 @@ import ( "time" ) -// Either ecdheParameters or kem.PrivateKey +// Either *ecdh.PrivateKey or kem.PrivateKey type clientKeySharePrivate interface{} var ( @@ -59,8 +60,12 @@ func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID { panic("cfkem: internal error: don't know CurveID for this KEM") } return ret - case ecdheParameters: - return v.CurveID() + case *ecdh.PrivateKey: + ret, ok := curveIDForCurve(v.Curve()) + if !ok { + panic("cfkem: internal error: unknown curve") + } + return ret default: panic("cfkem: internal error: unknown clientKeySharePrivate") } diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/cipher_suites.go b/vendor/github.com/quic-go/qtls-go1-20/cipher_suites.go similarity index 91% rename from vendor/github.com/marten-seemann/qtls-go1-18/cipher_suites.go rename to vendor/github.com/quic-go/qtls-go1-20/cipher_suites.go index e0be5147402..43d21315735 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/cipher_suites.go +++ b/vendor/github.com/quic-go/qtls-go1-20/cipher_suites.go @@ -226,57 +226,56 @@ var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map. // // - Anything else comes before RC4 // -// RC4 has practically exploitable biases. See https://www.rc4nomore.com. +// RC4 has practically exploitable biases. See https://www.rc4nomore.com. // // - Anything else comes before CBC_SHA256 // -// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13 -// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and -// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. +// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13 +// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and +// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. // // - Anything else comes before 3DES // -// 3DES has 64-bit blocks, which makes it fundamentally susceptible to -// birthday attacks. See https://sweet32.info. +// 3DES has 64-bit blocks, which makes it fundamentally susceptible to +// birthday attacks. See https://sweet32.info. 
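The cache.go file added above implements the reference-counted intern table described in its doc comment. As a compressed illustration of how the connection code later in this patch (verifyServerCertificate) consumes it from inside the qtls package: callers keep both the parsed certificates and their activeCert handles, so the finalizers only release cache entries once the certificates are no longer in use. The function name below is illustrative, not part of the diff:

func internPeerCerts(rawCerts [][]byte) ([]*activeCert, []*x509.Certificate, error) {
	handles := make([]*activeCert, len(rawCerts))
	certs := make([]*x509.Certificate, len(rawCerts))
	for i, der := range rawCerts {
		h, err := clientCertCache.newCert(der) // shared, reference-counted parse
		if err != nil {
			return nil, nil, err
		}
		handles[i] = h    // keep the handle: its finalizer decrements the ref count
		certs[i] = h.cert // safe to read, must not be mutated
	}
	return handles, certs, nil
}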
// // - ECDHE comes before anything else // -// Once we got the broken stuff out of the way, the most important -// property a cipher suite can have is forward secrecy. We don't -// implement FFDHE, so that means ECDHE. +// Once we got the broken stuff out of the way, the most important +// property a cipher suite can have is forward secrecy. We don't +// implement FFDHE, so that means ECDHE. // // - AEADs come before CBC ciphers // -// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites -// are fundamentally fragile, and suffered from an endless sequence of -// padding oracle attacks. See https://eprint.iacr.org/2015/1129, -// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and -// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/. +// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites +// are fundamentally fragile, and suffered from an endless sequence of +// padding oracle attacks. See https://eprint.iacr.org/2015/1129, +// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and +// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/. // // - AES comes before ChaCha20 // -// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster -// than ChaCha20Poly1305. +// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster +// than ChaCha20Poly1305. // -// When AES hardware is not available, AES-128-GCM is one or more of: much -// slower, way more complex, and less safe (because not constant time) -// than ChaCha20Poly1305. +// When AES hardware is not available, AES-128-GCM is one or more of: much +// slower, way more complex, and less safe (because not constant time) +// than ChaCha20Poly1305. // -// We use this list if we think both peers have AES hardware, and -// cipherSuitesPreferenceOrderNoAES otherwise. +// We use this list if we think both peers have AES hardware, and +// cipherSuitesPreferenceOrderNoAES otherwise. // // - AES-128 comes before AES-256 // -// The only potential advantages of AES-256 are better multi-target -// margins, and hypothetical post-quantum properties. Neither apply to -// TLS, and AES-256 is slower due to its four extra rounds (which don't -// contribute to the advantages above). +// The only potential advantages of AES-256 are better multi-target +// margins, and hypothetical post-quantum properties. Neither apply to +// TLS, and AES-256 is slower due to its four extra rounds (which don't +// contribute to the advantages above). // // - ECDSA comes before RSA // -// The relative order of ECDSA and RSA cipher suites doesn't matter, -// as they depend on the certificate. Pick one to get a stable order. -// +// The relative order of ECDSA and RSA cipher suites doesn't matter, +// as they depend on the certificate. Pick one to get a stable order. var cipherSuitesPreferenceOrder = []uint16{ // AEADs w/ ECDHE TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, @@ -419,7 +418,9 @@ func cipherAES(key, iv []byte, isRead bool) any { // macSHA1 returns a SHA-1 based constant time MAC. func macSHA1(key []byte) hash.Hash { - return hmac.New(newConstantTimeHash(sha1.New), key) + h := sha1.New + h = newConstantTimeHash(h) + return hmac.New(h, key) } // macSHA256 returns a SHA-256 based MAC. 
This is only supported in TLS 1.2 and @@ -464,7 +465,7 @@ func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([ return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) } -// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce // before each call. type xorNonceAEAD struct { nonceMask [aeadNonceLength]byte @@ -507,7 +508,8 @@ func aeadAESGCM(key, noncePrefix []byte) aead { if err != nil { panic(err) } - aead, err := cipher.NewGCM(aes) + var aead cipher.AEAD + aead, err = cipher.NewGCM(aes) if err != nil { panic(err) } diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/common.go b/vendor/github.com/quic-go/qtls-go1-20/common.go similarity index 97% rename from vendor/github.com/marten-seemann/qtls-go1-18/common.go rename to vendor/github.com/quic-go/qtls-go1-20/common.go index d4334f9add2..074dd9dcef2 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/common.go +++ b/vendor/github.com/quic-go/qtls-go1-20/common.go @@ -181,11 +181,11 @@ const ( // hash function associated with the Ed25519 signature scheme. var directSigning crypto.Hash = 0 -// supportedSignatureAlgorithms contains the signature and hash algorithms that +// defaultSupportedSignatureAlgorithms contains the signature and hash algorithms that // the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+ // CertificateRequest. The two fields are merged to match with TLS 1.3. // Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. -var supportedSignatureAlgorithms = []SignatureScheme{ +var defaultSupportedSignatureAlgorithms = []SignatureScheme{ PSSWithSHA256, ECDSAWithP256AndSHA256, Ed25519, @@ -258,6 +258,8 @@ type connectionState struct { // On the client side, it can't be empty. On the server side, it can be // empty if Config.ClientAuth is not RequireAnyClientCert or // RequireAndVerifyClientCert. + // + // PeerCertificates and its contents should not be modified. PeerCertificates []*x509.Certificate // VerifiedChains is a list of one or more chains where the first element is @@ -267,6 +269,8 @@ type connectionState struct { // On the client side, it's set if Config.InsecureSkipVerify is false. On // the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven // (and the peer provided a certificate) or RequireAndVerifyClientCert. + // + // VerifiedChains and its contents should not be modified. VerifiedChains [][]*x509.Certificate // SignedCertificateTimestamps is a list of SCTs provided by the peer @@ -346,7 +350,7 @@ type clientSessionState struct { // SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which // are supported via this interface. // -//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/marten-seemann/qtls-go1-17 ClientSessionCache" +//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/quic-go/qtls-go1-20 ClientSessionCache" type ClientSessionCache = tls.ClientSessionCache // SignatureScheme is a tls.SignatureScheme @@ -544,6 +548,8 @@ type config struct { // If GetCertificate is nil or returns nil, then the certificate is // retrieved from NameToCertificate. If NameToCertificate is nil, the // best element of Certificates will be used. + // + // Once a Certificate is returned it should not be modified. 
GetCertificate func(*ClientHelloInfo) (*Certificate, error) // GetClientCertificate, if not nil, is called when a server requests a @@ -559,6 +565,8 @@ type config struct { // // GetClientCertificate may be called multiple times for the same // connection if renegotiation occurs or if TLS 1.3 is in use. + // + // Once a Certificate is returned it should not be modified. GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) // GetConfigForClient, if not nil, is called after a ClientHello is @@ -587,6 +595,8 @@ type config struct { // setting InsecureSkipVerify, or (for a server) when ClientAuth is // RequestClientCert or RequireAnyClientCert, then this callback will // be considered but the verifiedChains argument will always be nil. + // + // verifiedChains and its contents should not be modified. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error // VerifyConnection, if not nil, is called after normal certificate @@ -715,7 +725,7 @@ type config struct { // mutex protects sessionTicketKeys and autoSessionTicketKeys. mutex sync.RWMutex - // sessionTicketKeys contains zero or more ticket keys. If set, it means the + // sessionTicketKeys contains zero or more ticket keys. If set, it means // the keys were set with SessionTicketKey or SetSessionTicketKeys. The // first key is used for new tickets and any subsequent keys can be used to // decrypt old tickets. The slice contents are not protected by the mutex @@ -1038,6 +1048,9 @@ func (c *config) time() time.Time { } func (c *config) cipherSuites() []uint16 { + if needFIPS() { + return fipsCipherSuites(c) + } if c.CipherSuites != nil { return c.CipherSuites } @@ -1051,10 +1064,6 @@ var supportedVersions = []uint16{ VersionTLS10, } -// debugEnableTLS10 enables TLS 1.0. See issue 45428. -// We don't care about TLS1.0 in qtls. Always disable it. -var debugEnableTLS10 = false - // roleClient and roleServer are meant to call supportedVersions and parents // with more readability at the callsite. 
const roleClient = true @@ -1063,7 +1072,10 @@ const roleServer = false func (c *config) supportedVersions(isClient bool) []uint16 { versions := make([]uint16, 0, len(supportedVersions)) for _, v := range supportedVersions { - if (c == nil || c.MinVersion == 0) && !debugEnableTLS10 && + if needFIPS() && (v < fipsMinVersion(c) || v > fipsMaxVersion(c)) { + continue + } + if (c == nil || c.MinVersion == 0) && isClient && v < VersionTLS12 { continue } @@ -1103,6 +1115,9 @@ func supportedVersionsFromMax(maxVersion uint16) []uint16 { var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} func (c *config) curvePreferences() []CurveID { + if needFIPS() { + return fipsCurvePreferences(c) + } if c == nil || len(c.CurvePreferences) == 0 { return defaultCurvePreferences } @@ -1381,7 +1396,7 @@ func (c *config) writeKeyLog(label string, clientRandom, secret []byte) error { return nil } - logLine := []byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)) + logLine := fmt.Appendf(nil, "%s %x %x\n", label, clientRandom, secret) writerMutex.Lock() _, err := c.KeyLogWriter.Write(logLine) @@ -1407,7 +1422,7 @@ func leafCertificate(c *Certificate) (*x509.Certificate, error) { } type handshakeMessage interface { - marshal() []byte + marshal() ([]byte, error) unmarshal([]byte) bool } @@ -1506,3 +1521,18 @@ func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlg } return false } + +// CertificateVerificationError is returned when certificate verification fails during the handshake. +type CertificateVerificationError struct { + // UnverifiedCertificates and its contents should not be modified. + UnverifiedCertificates []*x509.Certificate + Err error +} + +func (e *CertificateVerificationError) Error() string { + return fmt.Sprintf("tls: failed to verify certificate: %s", e.Err) +} + +func (e *CertificateVerificationError) Unwrap() error { + return e.Err +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/conn.go b/vendor/github.com/quic-go/qtls-go1-20/conn.go similarity index 95% rename from vendor/github.com/marten-seemann/qtls-go1-18/conn.go rename to vendor/github.com/quic-go/qtls-go1-20/conn.go index 2b8c7307e94..656c83c71d5 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/conn.go +++ b/vendor/github.com/quic-go/qtls-go1-20/conn.go @@ -30,11 +30,10 @@ type Conn struct { isClient bool handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake - // handshakeStatus is 1 if the connection is currently transferring + // isHandshakeComplete is true if the connection is currently transferring // application data (i.e. is not currently processing a handshake). - // handshakeStatus == 1 implies handshakeErr == nil. - // This field is only to be accessed with sync/atomic. - handshakeStatus uint32 + // isHandshakeComplete is true implies handshakeErr == nil. + isHandshakeComplete atomic.Bool // constant after handshake; protected by handshakeMutex handshakeMutex sync.Mutex handshakeErr error // error resulting from handshake @@ -52,6 +51,9 @@ type Conn struct { ocspResponse []byte // stapled OCSP response scts [][]byte // signed certificate timestamps from server peerCertificates []*x509.Certificate + // activeCertHandles contains the cache handles to certificates in + // peerCertificates that are used to track active references. + activeCertHandles []*activeCert // verifiedChains contains the certificate chains that we built, as // opposed to the ones presented by the server. 
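The new exported CertificateVerificationError above carries the rejected chain alongside the underlying verification error, and Unwrap makes it compatible with the errors package. A caller outside the package (for example, code terminating QUIC connections) could presumably recover it with errors.As; the surrounding function and package are illustrative only:

package example

import (
	"errors"
	"log"

	qtls "github.com/quic-go/qtls-go1-20"
)

func logRejectedChain(handshakeErr error) {
	var certErr *qtls.CertificateVerificationError
	if errors.As(handshakeErr, &certErr) {
		for _, cert := range certErr.UnverifiedCertificates {
			log.Printf("rejected certificate: %s", cert.Subject)
		}
	}
}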
verifiedChains [][]*x509.Certificate @@ -117,10 +119,9 @@ type Conn struct { // handshake, nor deliver application data. Protected by in.Mutex. retryCount int - // activeCall is an atomic int32; the low bit is whether Close has - // been called. the rest of the bits are the number of goroutines - // in Conn.Write. - activeCall int32 + // activeCall indicates whether Close has been call in the low bit. + // the rest of the bits are the number of goroutines in Conn.Write. + activeCall atomic.Int32 used0RTT bool @@ -621,12 +622,14 @@ func (c *Conn) readChangeCipherSpec() error { // readRecordOrCCS reads one or more TLS records from the connection and // updates the record layer state. Some invariants: -// * c.in must be locked -// * c.input must be empty +// - c.in must be locked +// - c.input must be empty +// // During the handshake one and only one of the following will happen: // - c.hand grows // - c.in.changeCipherSpec is called // - an error is returned +// // After the handshake one and only one of the following will happen: // - c.hand grows // - c.input is set @@ -635,7 +638,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { if c.in.err != nil { return c.in.err } - handshakeComplete := c.handshakeComplete() + handshakeComplete := c.isHandshakeComplete.Load() // This function modifies c.rawInput, which owns the c.input memory. if c.input.Len() != 0 { @@ -792,7 +795,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error { return nil } -// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like +// retryReadRecord recurs into readRecordOrCCS to drop a non-advancing record, like // a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { c.retryCount++ @@ -1039,25 +1042,46 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { return n, nil } -// writeRecord writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { +// writeHandshakeRecord writes a handshake message to the connection and updates +// the record layer state. If transcript is non-nil the marshalled message is +// written to it. +func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { + data, err := msg.marshal() + if err != nil { + return 0, err + } + + c.out.Lock() + defer c.out.Unlock() + + if transcript != nil { + transcript.Write(data) + } + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { - if typ == recordTypeChangeCipherSpec { - return len(data), nil - } return c.extraConfig.AlternativeRecordLayer.WriteRecord(data) } + return c.writeRecordLocked(recordTypeHandshake, data) +} + +// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and +// updates the record layer state. +func (c *Conn) writeChangeCipherRecord() error { + if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { + return nil + } + c.out.Lock() defer c.out.Unlock() - - return c.writeRecordLocked(typ, data) + _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1}) + return err } // readHandshake reads the next handshake message from -// the record layer. -func (c *Conn) readHandshake() (any, error) { +// the record layer. If transcript is non-nil, the message +// is written to the passed transcriptHash. 
+func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { var data []byte if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil { var err error @@ -1145,6 +1169,11 @@ func (c *Conn) readHandshake() (any, error) { if !m.unmarshal(data) { return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } + + if transcript != nil { + transcript.Write(data) + } + return m, nil } @@ -1161,15 +1190,15 @@ var ( func (c *Conn) Write(b []byte) (int, error) { // interlock with Close below for { - x := atomic.LoadInt32(&c.activeCall) + x := c.activeCall.Load() if x&1 != 0 { return 0, net.ErrClosed } - if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { + if c.activeCall.CompareAndSwap(x, x+2) { break } } - defer atomic.AddInt32(&c.activeCall, -2) + defer c.activeCall.Add(-2) if err := c.Handshake(); err != nil { return 0, err @@ -1182,7 +1211,7 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, err } - if !c.handshakeComplete() { + if !c.isHandshakeComplete.Load() { return 0, alertInternalError } @@ -1220,7 +1249,7 @@ func (c *Conn) handleRenegotiation() error { return errors.New("tls: internal error: unexpected renegotiation") } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1252,7 +1281,7 @@ func (c *Conn) handleRenegotiation() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - atomic.StoreUint32(&c.handshakeStatus, 0) + c.isHandshakeComplete.Store(false) if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { c.handshakes++ } @@ -1270,7 +1299,7 @@ func (c *Conn) handlePostHandshakeMessage() error { return c.handleRenegotiation() } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1306,7 +1335,11 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { defer c.out.Unlock() msg := &keyUpdateMsg{} - _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + msgBytes, err := msg.marshal() + if err != nil { + return err + } + _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes) if err != nil { // Surface the error at the next write. c.out.setErrorLocked(err) @@ -1374,11 +1407,11 @@ func (c *Conn) Close() error { // Interlock with Conn.Write above. var x int32 for { - x = atomic.LoadInt32(&c.activeCall) + x = c.activeCall.Load() if x&1 != 0 { return net.ErrClosed } - if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) { + if c.activeCall.CompareAndSwap(x, x|1) { break } } @@ -1393,7 +1426,7 @@ func (c *Conn) Close() error { } var alertErr error - if c.handshakeComplete() { + if c.isHandshakeComplete.Load() { if err := c.closeNotify(); err != nil { alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err) } @@ -1411,7 +1444,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com // called once the handshake has completed and does not call CloseWrite on the // underlying connection. Most callers should just use Close. func (c *Conn) CloseWrite() error { - if !c.handshakeComplete() { + if !c.isHandshakeComplete.Load() { return errEarlyCloseWrite } @@ -1465,7 +1498,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { // Fast sync/atomic-based exit if there is no handshake in flight and the // last one succeeded without an error. Avoids the expensive context setup // and mutex for most Read and Write calls. 
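conn.go above also migrates Conn's handshakeStatus and activeCall fields from sync/atomic calls on raw integers to atomic.Bool and atomic.Int32. The Close/Write interlock encoded in activeCall (low bit: Close has been called; remaining bits: goroutines currently inside Write, counted in steps of two) is shown in isolation in the following sketch; the type and method names are invented for illustration:

package sketch

import (
	"net"
	"sync/atomic"
)

type writeCloseInterlock struct {
	activeCall atomic.Int32 // bit 0: closed; higher bits: active writers
}

func (l *writeCloseInterlock) beginWrite() error {
	for {
		x := l.activeCall.Load()
		if x&1 != 0 {
			return net.ErrClosed // Close has already been called
		}
		if l.activeCall.CompareAndSwap(x, x+2) {
			return nil // registered one writer
		}
	}
}

func (l *writeCloseInterlock) endWrite() { l.activeCall.Add(-2) }

func (l *writeCloseInterlock) close() error {
	for {
		x := l.activeCall.Load()
		if x&1 != 0 {
			return net.ErrClosed // already closed
		}
		if l.activeCall.CompareAndSwap(x, x|1) {
			return nil // from now on beginWrite refuses new writers
		}
	}
}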
- if c.handshakeComplete() { + if c.isHandshakeComplete.Load() { return nil } @@ -1508,7 +1541,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { if err := c.handshakeErr; err != nil { return err } - if c.handshakeComplete() { + if c.isHandshakeComplete.Load() { return nil } @@ -1524,10 +1557,10 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { c.flush() } - if c.handshakeErr == nil && !c.handshakeComplete() { + if c.handshakeErr == nil && !c.isHandshakeComplete.Load() { c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") } - if c.handshakeErr != nil && c.handshakeComplete() { + if c.handshakeErr != nil && c.isHandshakeComplete.Load() { panic("tls: internal error: handshake returned an error but is marked successful") } @@ -1550,7 +1583,7 @@ func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT { func (c *Conn) connectionStateLocked() ConnectionState { var state connectionState - state.HandshakeComplete = c.handshakeComplete() + state.HandshakeComplete = c.isHandshakeComplete.Load() state.Version = c.vers state.NegotiatedProtocol = c.clientProtocol state.DidResume = c.didResume @@ -1603,7 +1636,7 @@ func (c *Conn) VerifyHostname(host string) error { if !c.isClient { return errors.New("tls: VerifyHostname called on TLS server connection") } - if !c.handshakeComplete() { + if !c.isHandshakeComplete.Load() { return errors.New("tls: handshake has not yet been performed") } if len(c.verifiedChains) == 0 { @@ -1611,7 +1644,3 @@ func (c *Conn) VerifyHostname(host string) error { } return c.peerCertificates[0].VerifyHostname(host) } - -func (c *Conn) handshakeComplete() bool { - return atomic.LoadUint32(&c.handshakeStatus) == 1 -} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/cpu.go b/vendor/github.com/quic-go/qtls-go1-20/cpu.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-18/cpu.go rename to vendor/github.com/quic-go/qtls-go1-20/cpu.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/cpu_other.go b/vendor/github.com/quic-go/qtls-go1-20/cpu_other.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-18/cpu_other.go rename to vendor/github.com/quic-go/qtls-go1-20/cpu_other.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_client.go similarity index 90% rename from vendor/github.com/marten-seemann/qtls-go1-18/handshake_client.go rename to vendor/github.com/quic-go/qtls-go1-20/handshake_client.go index 35a1afede55..60a93e80255 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client.go +++ b/vendor/github.com/quic-go/qtls-go1-20/handshake_client.go @@ -19,7 +19,6 @@ import ( "io" "net" "strings" - "sync/atomic" "time" "golang.org/x/crypto/cryptobyte" @@ -38,6 +37,8 @@ type clientHandshakeState struct { session *clientSessionState } +var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme + func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error) { config := c.config if len(config.ServerName) == 0 && !config.InsecureSkipVerify { @@ -134,27 +135,18 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error) } if hello.vers >= VersionTLS12 { - hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms + hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + } + if testingOnlyForceClientHelloSignatureAlgorithms != nil { + hello.supportedSignatureAlgorithms = 
testingOnlyForceClientHelloSignatureAlgorithms } var secret clientKeySharePrivate if hello.supportedVersions[0] == VersionTLS13 { - var suites []uint16 - for _, suiteID := range configCipherSuites { - for _, suite := range cipherSuitesTLS13 { - if suite.id == suiteID { - suites = append(suites, suiteID) - } - } - } - if len(suites) > 0 { - hello.cipherSuites = suites + if hasAESGCMHardwareSupport { + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) } else { - if hasAESGCMHardwareSupport { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) - } else { - hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) - } + hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) } curveID := config.curvePreferences()[0] @@ -172,15 +164,15 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error) hello.keyShares = []keyShare{{group: curveID, data: packedPk}} secret = sk } else { - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + if _, ok := curveForCurveID(curveID); !ok { return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") } - params, err := generateECDHEParameters(config.rand(), curveID) + key, err := generateECDHEKey(config.rand(), curveID) if err != nil { return nil, nil, err } - hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} - secret = params + hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} + secret = key } } @@ -201,13 +193,16 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { // need to be reset. c.didResume = false - hello, ecdheParams, err := c.makeClientHello() + hello, ecdheKey, err := c.makeClientHello() if err != nil { return err } c.serverName = hello.serverName - cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) + if err != nil { + return err + } if cacheKey != "" && session != nil { var deletedTicket bool if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT { @@ -217,11 +212,14 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil { h := suite.hash.New() - h.Write(hello.marshal()) + helloBytes, err := hello.marshal() + if err != nil { + return err + } + h.Write(helloBytes) clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h) c.out.exportKey(Encryption0RTT, suite, clientEarlySecret) if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil { - c.sendAlert(alertInternalError) return err } } @@ -241,11 +239,12 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { } } - if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + if _, err := c.writeHandshakeRecord(hello, nil); err != nil { return err } - msg, err := c.readHandshake() + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -278,7 +277,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { ctx: ctx, serverHello: serverHello, hello: hello, - keySharePrivate: ecdheParams, + keySharePrivate: ecdheKey, session: session, earlySecret: earlySecret, binderKey: binderKey, @@ -338,9 +337,9 @@ func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max ea } func (c *Conn) 
loadSession(hello *clientHelloMsg) (cacheKey string, - session *clientSessionState, earlySecret, binderKey []byte) { + session *clientSessionState, earlySecret, binderKey []byte, err error) { if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return "", nil, nil, nil + return "", nil, nil, nil, nil } hello.ticketSupported = true @@ -355,14 +354,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // renegotiation is primarily used to allow a client to send a client // certificate, which would be skipped if session resumption occurred. if c.handshakes != 0 { - return "", nil, nil, nil + return "", nil, nil, nil, nil } // Try to resume a previously negotiated TLS session, if available. cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) sess, ok := c.config.ClientSessionCache.Get(cacheKey) if !ok || sess == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } session = fromClientSessionState(sess) @@ -373,7 +372,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, maxEarlyData, appData, ok = c.decodeSessionState(session) if !ok { // delete it, if parsing failed c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } } @@ -386,7 +385,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } } if !versOk { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Check that the cached server certificate is not expired, and that it's @@ -395,16 +394,16 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, if !c.config.InsecureSkipVerify { if len(session.verifiedChains) == 0 { // The original connection had InsecureSkipVerify, while this doesn't. - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } serverCert := session.serverCertificates[0] if c.config.time().After(serverCert.NotAfter) { // Expired certificate, delete the entry. c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } } @@ -412,7 +411,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // In TLS 1.2 the cipher suite must match the resumed session. Ensure we // are still offering it. if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } hello.sessionTicket = session.sessionTicket @@ -422,14 +421,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // Check that the session ticket is not expired. if c.config.time().After(session.useBy) { c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // In TLS 1.3 the KDF hash must match the resumed session. Ensure we // offer at least one cipher suite with that hash. cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) if cipherSuite == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } cipherSuiteOk := false for _, offeredID := range hello.cipherSuites { @@ -440,7 +439,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } } if !cipherSuiteOk { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. 
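handshake_client.go above replaces the package-private ecdheParameters with the standard library's crypto/ecdh: key shares are now *ecdh.PrivateKey values produced by generateECDHEKey (not shown in this excerpt) and serialized with PublicKey().Bytes(). For reference, the equivalent plain crypto/ecdh usage for an X25519 key share; the function name is illustrative:

package sketch

import (
	"crypto/ecdh"
	"crypto/rand"
)

// x25519KeyShare produces a private key and the key_share bytes a ClientHello
// would carry for X25519, using only the standard library API the diff adopts.
func x25519KeyShare() (*ecdh.PrivateKey, []byte, error) {
	priv, err := ecdh.X25519().GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, err
	}
	return priv, priv.PublicKey().Bytes(), nil
}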
@@ -461,9 +460,15 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0 } transcript := cipherSuite.hash.New() - transcript.Write(hello.marshalWithoutBinders()) + helloBytes, err := hello.marshalWithoutBinders() + if err != nil { + return "", nil, nil, nil, err + } + transcript.Write(helloBytes) pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} - hello.updateBinders(pskBinders) + if err := hello.updateBinders(pskBinders); err != nil { + return "", nil, nil, nil, err + } if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil { c.extraConfig.SetAppDataFromSessionState(appData) @@ -511,8 +516,12 @@ func (hs *clientHandshakeState) handshake() error { hs.finishedHash.discardHandshakeBuffer() } - hs.finishedHash.Write(hs.hello.marshal()) - hs.finishedHash.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil { + return err + } + if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil { + return err + } c.buffering = true c.didResume = isResume @@ -565,7 +574,7 @@ func (hs *clientHandshakeState) handshake() error { } c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) - atomic.StoreUint32(&c.handshakeStatus, 1) + c.isHandshakeComplete.Store(true) return nil } @@ -583,7 +592,7 @@ func (hs *clientHandshakeState) pickCipherSuite() error { func (hs *clientHandshakeState) doFullHandshake() error { c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -592,9 +601,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.finishedHash.Write(certMsg.marshal()) - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -612,11 +620,10 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return errors.New("tls: received unexpected CertificateStatus message") } - hs.finishedHash.Write(cs.marshal()) c.ocspResponse = cs.response - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -645,14 +652,13 @@ func (hs *clientHandshakeState) doFullHandshake() error { skx, ok := msg.(*serverKeyExchangeMsg) if ok { - hs.finishedHash.Write(skx.marshal()) err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) if err != nil { c.sendAlert(alertUnexpectedMessage) return err } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -663,7 +669,6 @@ func (hs *clientHandshakeState) doFullHandshake() error { certReq, ok := msg.(*certificateRequestMsg) if ok { certRequested = true - hs.finishedHash.Write(certReq.marshal()) cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { @@ -671,7 +676,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -682,7 +687,6 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(shd, msg) } - hs.finishedHash.Write(shd.marshal()) // If the server requested a 
certificate then we have to send a // Certificate message, even if it's empty because we don't have a @@ -690,8 +694,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { if certRequested { certMsg = new(certificateMsg) certMsg.certificates = chainToSend.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { return err } } @@ -702,8 +705,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } if ckx != nil { - hs.finishedHash.Write(ckx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil { return err } } @@ -739,7 +741,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { } } - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) signOpts := crypto.SignerOpts(sigHash) if sigType == signatureRSAPSS { signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} @@ -750,8 +752,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } - hs.finishedHash.Write(certVerify.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil { return err } } @@ -886,7 +887,10 @@ func (hs *clientHandshakeState) readFinished(out []byte) error { return err } - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. 
+ msg, err := c.readHandshake(nil) if err != nil { return err } @@ -902,7 +906,11 @@ func (hs *clientHandshakeState) readFinished(out []byte) error { c.sendAlert(alertHandshakeFailure) return errors.New("tls: server's Finished message was incorrect") } - hs.finishedHash.Write(serverFinished.marshal()) + + if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil { + return err + } + copy(out, verify) return nil } @@ -913,7 +921,7 @@ func (hs *clientHandshakeState) readSessionTicket() error { } c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -922,7 +930,6 @@ func (hs *clientHandshakeState) readSessionTicket() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(sessionTicketMsg, msg) } - hs.finishedHash.Write(sessionTicketMsg.marshal()) hs.session = &clientSessionState{ sessionTicket: sessionTicketMsg.ticket, @@ -942,14 +949,13 @@ func (hs *clientHandshakeState) readSessionTicket() error { func (hs *clientHandshakeState) sendFinished(out []byte) error { c := hs.c - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + if err := c.writeChangeCipherRecord(); err != nil { return err } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { return err } copy(out, finished.verifyData) @@ -959,14 +965,16 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error { // verifyServerCertificate parses and verifies the provided chain, setting // c.verifiedChains and c.peerCertificates or sending the appropriate alert. 
func (c *Conn) verifyServerCertificate(certificates [][]byte) error { + activeHandles := make([]*activeCert, len(certificates)) certs := make([]*x509.Certificate, len(certificates)) for i, asn1Data := range certificates { - cert, err := x509.ParseCertificate(asn1Data) + cert, err := clientCertCache.newCert(asn1Data) if err != nil { c.sendAlert(alertBadCertificate) return errors.New("tls: failed to parse certificate from server: " + err.Error()) } - certs[i] = cert + activeHandles[i] = cert + certs[i] = cert.cert } if !c.config.InsecureSkipVerify { @@ -976,6 +984,7 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { DNSName: c.config.ServerName, Intermediates: x509.NewCertPool(), } + for _, cert := range certs[1:] { opts.Intermediates.AddCert(cert) } @@ -983,7 +992,7 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { c.verifiedChains, err = certs[0].Verify(opts) if err != nil { c.sendAlert(alertBadCertificate) - return err + return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err} } } @@ -995,6 +1004,7 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) } + c.activeCertHandles = activeHandles c.peerCertificates = certs if c.config.VerifyPeerCertificate != nil { diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client_tls13.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_client_tls13.go similarity index 89% rename from vendor/github.com/marten-seemann/qtls-go1-18/handshake_client_tls13.go rename to vendor/github.com/quic-go/qtls-go1-20/handshake_client_tls13.go index 7ff7d4c4072..8b7f0170082 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_client_tls13.go +++ b/vendor/github.com/quic-go/qtls-go1-20/handshake_client_tls13.go @@ -8,13 +8,13 @@ import ( "bytes" "context" "crypto" + "crypto/ecdh" "crypto/hmac" "crypto/rsa" "encoding/binary" "errors" "fmt" "hash" - "sync/atomic" "time" circlKem "github.com/cloudflare/circl/kem" @@ -22,11 +22,10 @@ import ( ) type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg keySharePrivate clientKeySharePrivate session *clientSessionState @@ -42,13 +41,17 @@ type clientHandshakeStateTLS13 struct { trafficSecret []byte // client_application_traffic_secret_0 } -// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and, +// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheKey, and, // optionally, hs.session, hs.earlySecret and hs.binderKey to be set. func (hs *clientHandshakeStateTLS13) handshake() error { c := hs.c startTime := time.Now() + if needFIPS() { + return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") + } + // The server must not select TLS 1.3 in a renegotiation. See RFC 8446, // sections 4.1.2 and 4.1.3. 
if c.handshakes > 0 { @@ -66,7 +69,10 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } hs.transcript = hs.suite.hash.New() - hs.transcript.Write(hs.hello.marshal()) + + if err := transcriptMsg(hs.hello, hs.transcript); err != nil { + return err + } if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { if err := hs.sendDummyChangeCipherSpec(); err != nil { @@ -77,7 +83,9 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } } - hs.transcript.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } c.buffering = true if err := hs.processServerHello(); err != nil { @@ -116,7 +124,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { kex: hs.serverHello.serverShare.group, }) - atomic.StoreUint32(&c.handshakeStatus, 1) + c.isHandshakeComplete.Store(true) c.updateConnectionState() return nil } @@ -184,8 +192,7 @@ func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { } hs.sentDummyCCS = true - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err + return hs.c.writeChangeCipherRecord() } // processHelloRetryRequest handles the HRR in hs.serverHello, modifies and @@ -202,7 +209,9 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { hs.transcript.Reset() hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) hs.transcript.Write(chHash) - hs.transcript.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } // The only HelloRetryRequest extensions we support are key_share and // cookie, and clients must abort the handshake if the HRR would not result @@ -256,17 +265,17 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { hs.keySharePrivate = sk hs.hello.keyShares = []keyShare{{group: curveID, data: packedPk}} } else { - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + if _, ok := curveForCurveID(curveID); !ok { c.sendAlert(alertInternalError) return errors.New("tls: CurvePreferences includes unsupported curve") } - params, err := generateECDHEParameters(c.config.rand(), curveID) + key, err := generateECDHEKey(c.config.rand(), curveID) if err != nil { c.sendAlert(alertInternalError) return err } - hs.keySharePrivate = params - hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + hs.keySharePrivate = key + hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } } @@ -284,10 +293,18 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { transcript := hs.suite.hash.New() transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) transcript.Write(chHash) - transcript.Write(hs.serverHello.marshal()) - transcript.Write(hs.hello.marshalWithoutBinders()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } + helloBytes, err := hs.hello.marshalWithoutBinders() + if err != nil { + return err + } + transcript.Write(helloBytes) pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} - hs.hello.updateBinders(pskBinders) + if err := hs.hello.updateBinders(pskBinders); err != nil { + return err + } } else { // Server selected a cipher suite incompatible with the PSK. 
hs.hello.pskIdentities = nil @@ -299,13 +316,12 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.extraConfig.Rejected0RTT() } hs.hello.earlyData = false // disable 0-RTT - - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { return err } - msg, err := c.readHandshake() + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -385,10 +401,13 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { c := hs.c var sharedKey []byte - if params, ok := hs.keySharePrivate.(ecdheParameters); ok { - sharedKey = params.SharedKey(hs.serverHello.serverShare.data) + var err error + if key, ok := hs.keySharePrivate.(*ecdh.PrivateKey); ok { + peerKey, err := key.Curve().NewPublicKey(hs.serverHello.serverShare.data) + if err == nil { + sharedKey, _ = key.ECDH(peerKey) + } } else if sk, ok := hs.keySharePrivate.(circlKem.PrivateKey); ok { - var err error sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data) if err != nil { c.sendAlert(alertIllegalParameter) @@ -405,6 +424,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { if !hs.usingPSK { earlySecret = hs.suite.extract(nil, nil) } + handshakeSecret := hs.suite.extract(sharedKey, hs.suite.deriveSecret(earlySecret, "derived", nil)) @@ -417,7 +437,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret) c.in.setTrafficSecret(hs.suite, serverSecret) - err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) + err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) if err != nil { c.sendAlert(alertInternalError) return err @@ -437,7 +457,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { func (hs *clientHandshakeStateTLS13) readServerParameters() error { c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } @@ -455,7 +475,6 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error { if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil { hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions) } - hs.transcript.Write(encryptedExtensions.marshal()) if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { c.sendAlert(alertUnsupportedExtension) @@ -491,18 +510,16 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return nil } - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } certReq, ok := msg.(*certificateRequestMsgTLS13) if ok { - hs.transcript.Write(certReq.marshal()) - hs.certReq = certReq - msg, err = c.readHandshake() + msg, err = c.readHandshake(hs.transcript) if err != nil { return err } @@ -517,7 +534,6 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { c.sendAlert(alertDecodeError) return errors.New("tls: received empty certificates message") } - hs.transcript.Write(certMsg.marshal()) c.scts = certMsg.certificate.SignedCertificateTimestamps c.ocspResponse = certMsg.certificate.OCSPStaple @@ -526,7 +542,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return err } - msg, err = c.readHandshake() + // 
certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -538,7 +557,7 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { } // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { c.sendAlert(alertIllegalParameter) return errors.New("tls: certificate used with invalid signature algorithm") } @@ -557,7 +576,9 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return errors.New("tls: invalid signature by the server certificate: " + err.Error()) } - hs.transcript.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, hs.transcript); err != nil { + return err + } return nil } @@ -565,7 +586,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { func (hs *clientHandshakeStateTLS13) readServerFinished() error { c := hs.c - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -582,7 +606,9 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error { return errors.New("tls: invalid server finished hash") } - hs.transcript.Write(finished.marshal()) + if err := transcriptMsg(finished, hs.transcript); err != nil { + return err + } // Derive secrets that take context through the server Finished. @@ -632,8 +658,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { return err } @@ -670,8 +695,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { } certVerifyMsg.signature = sig - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { return err } @@ -685,8 +709,7 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error { verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), } - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { return err } diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_messages.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_messages.go similarity index 75% rename from vendor/github.com/marten-seemann/qtls-go1-18/handshake_messages.go rename to vendor/github.com/quic-go/qtls-go1-20/handshake_messages.go index 5f87d4b81b9..c69fcefda25 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_messages.go +++ b/vendor/github.com/quic-go/qtls-go1-20/handshake_messages.go @@ -5,6 +5,7 @@ package qtls import ( + "errors" "fmt" "strings" @@ -95,9 +96,187 @@ type clientHelloMsg struct { 
additionalExtensions []Extension } -func (m *clientHelloMsg) marshal() []byte { +func (m *clientHelloMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil + } + + var exts cryptobyte.Builder + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + exts.AddUint16(extensionServerName) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8(0) // name_type = host_name + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + exts.AddUint16(extensionStatusRequest) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8(1) // status_type = ocsp + exts.AddUint16(0) // empty responder_id_list + exts.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + exts.AddUint16(extensionSupportedCurves) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + exts.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + exts.AddUint16(extensionSupportedPoints) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + exts.AddUint16(extensionSessionTicket) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + exts.AddUint16(extensionSignatureAlgorithms) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + exts.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + exts.AddUint16(extensionSignatureAlgorithmsCert) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + exts.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + exts.AddUint16(extensionRenegotiationInfo) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + exts.AddUint16(extensionALPN) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + exts.AddUint16(extensionSCT) + exts.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + exts.AddUint16(extensionSupportedVersions) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, vers := range m.supportedVersions 
{ + exts.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + exts.AddUint16(extensionCookie) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, ks := range m.keyShares { + exts.AddUint16(uint16(ks.group)) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + exts.AddUint16(extensionEarlyData) + exts.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + exts.AddUint16(extensionPSKModes) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.pskModes) + }) + }) + } + for _, ext := range m.additionalExtensions { + exts.AddUint16(ext.Type) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(ext.Data) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + exts.AddUint16(extensionPreSharedKey) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(psk.label) + }) + exts.AddUint32(psk.obfuscatedTicketAge) + } + }) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(binder) + }) + } + }) + }) + } + extBytes, err := exts.Bytes() + if err != nil { + return nil, err } var b cryptobyte.Builder @@ -117,225 +296,53 @@ func (m *clientHelloMsg) marshal() []byte { b.AddBytes(m.compressionMethods) }) - // If extensions aren't present, omit them. 
- var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.serverName) > 0 { - // RFC 6066, Section 3 - b.AddUint16(extensionServerName) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // name_type = host_name - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.serverName)) - }) - }) - }) - } - if m.ocspStapling { - // RFC 4366, Section 3.6 - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(1) // status_type = ocsp - b.AddUint16(0) // empty responder_id_list - b.AddUint16(0) // empty request_extensions - }) - } - if len(m.supportedCurves) > 0 { - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - b.AddUint16(extensionSupportedCurves) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, curve := range m.supportedCurves { - b.AddUint16(uint16(curve)) - } - }) - }) - } - if len(m.supportedPoints) > 0 { - // RFC 4492, Section 5.1.2 - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - if m.ticketSupported { - // RFC 5077, Section 3.2 - b.AddUint16(extensionSessionTicket) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionTicket) - }) - } - if len(m.supportedSignatureAlgorithms) > 0 { - // RFC 5246, Section 7.4.1.4.1 - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - // RFC 8446, Section 4.2.3 - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if m.secureRenegotiationSupported { - // RFC 5746, Section 3.2 - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocols) > 0 { - // RFC 7301, Section 3.1 - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, proto := range m.alpnProtocols { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(proto)) - }) - } - }) - }) - } - if m.scts { - // RFC 6962, Section 3.3.1 - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedVersions) > 0 { - // RFC 8446, Section 4.2.1 - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - for _, vers := range m.supportedVersions { - b.AddUint16(vers) - } - }) - }) - } - if len(m.cookie) > 0 { - // RFC 8446, Section 4.2.2 - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if 
len(m.keyShares) > 0 { - // RFC 8446, Section 4.2.8 - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ks := range m.keyShares { - b.AddUint16(uint16(ks.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ks.data) - }) - } - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - b.AddUint16(extensionEarlyData) - b.AddUint16(0) // empty extension_data - } - if len(m.pskModes) > 0 { - // RFC 8446, Section 4.2.9 - b.AddUint16(extensionPSKModes) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.pskModes) - }) - }) - } - for _, ext := range m.additionalExtensions { - b.AddUint16(ext.Type) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ext.Data) - }) - } - if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension - // RFC 8446, Section 4.2.11 - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, psk := range m.pskIdentities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(psk.label) - }) - b.AddUint32(psk.obfuscatedTicketAge) - } - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions + if len(extBytes) > 0 { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(extBytes) + }) } }) - m.raw = b.BytesOrPanic() - return m.raw + m.raw, err = b.Bytes() + return m.raw, err } // marshalWithoutBinders returns the ClientHello through the // PreSharedKeyExtension.identities field, according to RFC 8446, Section // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. -func (m *clientHelloMsg) marshalWithoutBinders() []byte { +func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) { bindersLen := 2 // uint16 length prefix for _, binder := range m.pskBinders { bindersLen += 1 // uint8 length prefix bindersLen += len(binder) } - fullMessage := m.marshal() - return fullMessage[:len(fullMessage)-bindersLen] + fullMessage, err := m.marshal() + if err != nil { + return nil, err + } + return fullMessage[:len(fullMessage)-bindersLen], nil } // updateBinders updates the m.pskBinders field, if necessary updating the // cached marshaled representation. The supplied binders must have the same // length as the current m.pskBinders. 
-func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { if len(pskBinders) != len(m.pskBinders) { - panic("tls: internal error: pskBinders length mismatch") + return errors.New("tls: internal error: pskBinders length mismatch") } for i := range m.pskBinders { if len(pskBinders[i]) != len(m.pskBinders[i]) { - panic("tls: internal error: pskBinders length mismatch") + return errors.New("tls: internal error: pskBinders length mismatch") } } m.pskBinders = pskBinders if m.raw != nil { - lenWithoutBinders := len(m.marshalWithoutBinders()) + helloBytes, err := m.marshalWithoutBinders() + if err != nil { + return err + } + lenWithoutBinders := len(helloBytes) b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { for _, binder := range m.pskBinders { @@ -345,9 +352,11 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { } }) if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { - panic("tls: internal error: failed to update binders") + return errors.New("tls: internal error: failed to update binders") } } + + return nil } func (m *clientHelloMsg) unmarshal(data []byte) bool { @@ -391,15 +400,21 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } + seenExts := make(map[uint16]bool) for !extensions.Empty() { - var ext uint16 + var extension uint16 var extData cryptobyte.String - if !extensions.ReadUint16(&ext) || + if !extensions.ReadUint16(&extension) || !extensions.ReadUint16LengthPrefixed(&extData) { return false } - switch ext { + if seenExts[extension] { + return false + } + seenExts[extension] = true + + switch extension { case extensionServerName: // RFC 6066, Section 3 var nameList cryptobyte.String @@ -583,7 +598,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.pskBinders = append(m.pskBinders, binder) } default: - m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData}) + m.additionalExtensions = append(m.additionalExtensions, Extension{Type: extension, Data: extData}) continue } @@ -619,9 +634,98 @@ type serverHelloMsg struct { selectedGroup CurveID } -func (m *serverHelloMsg) marshal() []byte { +func (m *serverHelloMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil + } + + var exts cryptobyte.Builder + if m.ocspStapling { + exts.AddUint16(extensionStatusRequest) + exts.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + exts.AddUint16(extensionSessionTicket) + exts.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + exts.AddUint16(extensionRenegotiationInfo) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + exts.AddUint16(extensionALPN) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + exts.AddUint16(extensionSCT) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sct := range m.scts { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 
0 { + exts.AddUint16(extensionSupportedVersions) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(uint16(m.serverShare.group)) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + exts.AddUint16(extensionPreSharedKey) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + exts.AddUint16(extensionCookie) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + exts.AddUint16(extensionSupportedPoints) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.supportedPoints) + }) + }) + } + + extBytes, err := exts.Bytes() + if err != nil { + return nil, err } var b cryptobyte.Builder @@ -635,104 +739,15 @@ func (m *serverHelloMsg) marshal() []byte { b.AddUint16(m.cipherSuite) b.AddUint8(m.compressionMethod) - // If extensions aren't present, omit them. - var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.ticketSupported { - b.AddUint16(extensionSessionTicket) - b.AddUint16(0) // empty extension_data - } - if m.secureRenegotiationSupported { - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if len(m.scts) > 0 { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range m.scts { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - if m.supportedVersion != 0 { - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.supportedVersion) - }) - } - if m.serverShare.group != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.serverShare.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.serverShare.data) - }) - }) - } - if m.selectedIdentityPresent { - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.selectedIdentity) - }) - } - - if len(m.cookie) > 0 { - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - 
if m.selectedGroup != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.selectedGroup)) - }) - } - if len(m.supportedPoints) > 0 { - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions + if len(extBytes) > 0 { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(extBytes) + }) } }) - m.raw = b.BytesOrPanic() - return m.raw + m.raw, err = b.Bytes() + return m.raw, err } func (m *serverHelloMsg) unmarshal(data []byte) bool { @@ -757,6 +772,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return false } + seenExts := make(map[uint16]bool) for !extensions.Empty() { var extension uint16 var extData cryptobyte.String @@ -765,6 +781,11 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return false } + if seenExts[extension] { + return false + } + seenExts[extension] = true + switch extension { case extensionStatusRequest: m.ocspStapling = true @@ -853,9 +874,9 @@ type encryptedExtensionsMsg struct { additionalExtensions []Extension } -func (m *encryptedExtensionsMsg) marshal() []byte { +func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -886,8 +907,9 @@ func (m *encryptedExtensionsMsg) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { @@ -937,10 +959,10 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { type endOfEarlyDataMsg struct{} -func (m *endOfEarlyDataMsg) marshal() []byte { +func (m *endOfEarlyDataMsg) marshal() ([]byte, error) { x := make([]byte, 4) x[0] = typeEndOfEarlyData - return x + return x, nil } func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { @@ -952,9 +974,9 @@ type keyUpdateMsg struct { updateRequested bool } -func (m *keyUpdateMsg) marshal() []byte { +func (m *keyUpdateMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -967,8 +989,9 @@ func (m *keyUpdateMsg) marshal() []byte { } }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *keyUpdateMsg) unmarshal(data []byte) bool { @@ -1000,9 +1023,9 @@ type newSessionTicketMsgTLS13 struct { maxEarlyData uint32 } -func (m *newSessionTicketMsgTLS13) marshal() []byte { +func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1027,8 +1050,9 @@ func (m *newSessionTicketMsgTLS13) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { @@ -1081,9 +1105,9 @@ type certificateRequestMsgTLS13 struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsgTLS13) marshal() []byte { +func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1142,8 +1166,9 @@ func (m *certificateRequestMsgTLS13) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, 
err = b.Bytes() + return m.raw, err } func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { @@ -1227,9 +1252,9 @@ type certificateMsg struct { certificates [][]byte } -func (m *certificateMsg) marshal() (x []byte) { +func (m *certificateMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var i int @@ -1238,7 +1263,7 @@ func (m *certificateMsg) marshal() (x []byte) { } length := 3 + 3*len(m.certificates) + i - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeCertificate x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1259,7 +1284,7 @@ func (m *certificateMsg) marshal() (x []byte) { } m.raw = x - return + return m.raw, nil } func (m *certificateMsg) unmarshal(data []byte) bool { @@ -1306,9 +1331,9 @@ type certificateMsgTLS13 struct { scts bool } -func (m *certificateMsgTLS13) marshal() []byte { +func (m *certificateMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1326,8 +1351,9 @@ func (m *certificateMsgTLS13) marshal() []byte { marshalCertificate(b, certificate) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { @@ -1450,9 +1476,9 @@ type serverKeyExchangeMsg struct { key []byte } -func (m *serverKeyExchangeMsg) marshal() []byte { +func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } length := len(m.key) x := make([]byte, length+4) @@ -1463,7 +1489,7 @@ func (m *serverKeyExchangeMsg) marshal() []byte { copy(x[4:], m.key) m.raw = x - return x + return x, nil } func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { @@ -1480,9 +1506,9 @@ type certificateStatusMsg struct { response []byte } -func (m *certificateStatusMsg) marshal() []byte { +func (m *certificateStatusMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1494,8 +1520,9 @@ func (m *certificateStatusMsg) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateStatusMsg) unmarshal(data []byte) bool { @@ -1514,10 +1541,10 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { type serverHelloDoneMsg struct{} -func (m *serverHelloDoneMsg) marshal() []byte { +func (m *serverHelloDoneMsg) marshal() ([]byte, error) { x := make([]byte, 4) x[0] = typeServerHelloDone - return x + return x, nil } func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { @@ -1529,9 +1556,9 @@ type clientKeyExchangeMsg struct { ciphertext []byte } -func (m *clientKeyExchangeMsg) marshal() []byte { +func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } length := len(m.ciphertext) x := make([]byte, length+4) @@ -1542,7 +1569,7 @@ func (m *clientKeyExchangeMsg) marshal() []byte { copy(x[4:], m.ciphertext) m.raw = x - return x + return x, nil } func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { @@ -1563,9 +1590,9 @@ type finishedMsg struct { verifyData []byte } -func (m *finishedMsg) marshal() []byte { +func (m *finishedMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1574,8 +1601,9 @@ func (m *finishedMsg) marshal() []byte { b.AddBytes(m.verifyData) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err 
= b.Bytes() + return m.raw, err } func (m *finishedMsg) unmarshal(data []byte) bool { @@ -1597,9 +1625,9 @@ type certificateRequestMsg struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsg) marshal() (x []byte) { +func (m *certificateRequestMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } // See RFC 4346, Section 7.4.4. @@ -1614,7 +1642,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { length += 2 + 2*len(m.supportedSignatureAlgorithms) } - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeCertificateRequest x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1649,7 +1677,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { } m.raw = x - return + return m.raw, nil } func (m *certificateRequestMsg) unmarshal(data []byte) bool { @@ -1735,9 +1763,9 @@ type certificateVerifyMsg struct { signature []byte } -func (m *certificateVerifyMsg) marshal() (x []byte) { +func (m *certificateVerifyMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1751,8 +1779,9 @@ func (m *certificateVerifyMsg) marshal() (x []byte) { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateVerifyMsg) unmarshal(data []byte) bool { @@ -1775,15 +1804,15 @@ type newSessionTicketMsg struct { ticket []byte } -func (m *newSessionTicketMsg) marshal() (x []byte) { +func (m *newSessionTicketMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } // See RFC 5077, Section 3.3. ticketLen := len(m.ticket) length := 2 + 4 + ticketLen - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeNewSessionTicket x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1794,7 +1823,7 @@ func (m *newSessionTicketMsg) marshal() (x []byte) { m.raw = x - return + return m.raw, nil } func (m *newSessionTicketMsg) unmarshal(data []byte) bool { @@ -1822,10 +1851,25 @@ func (m *newSessionTicketMsg) unmarshal(data []byte) bool { type helloRequestMsg struct { } -func (*helloRequestMsg) marshal() []byte { - return []byte{typeHelloRequest, 0, 0, 0} +func (*helloRequestMsg) marshal() ([]byte, error) { + return []byte{typeHelloRequest, 0, 0, 0}, nil } func (*helloRequestMsg) unmarshal(data []byte) bool { return len(data) == 4 } + +type transcriptHash interface { + Write([]byte) (int, error) +} + +// transcriptMsg is a helper used to marshal and hash messages which typically +// are not written to the wire, and as such aren't hashed during Conn.writeRecord. 
+func transcriptMsg(msg handshakeMessage, h transcriptHash) error { + data, err := msg.marshal() + if err != nil { + return err + } + h.Write(data) + return nil +} diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_server.go similarity index 92% rename from vendor/github.com/marten-seemann/qtls-go1-18/handshake_server.go rename to vendor/github.com/quic-go/qtls-go1-20/handshake_server.go index a6519d7f580..3443e85a15e 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server.go +++ b/vendor/github.com/quic-go/qtls-go1-20/handshake_server.go @@ -16,7 +16,6 @@ import ( "fmt" "hash" "io" - "sync/atomic" "time" ) @@ -130,7 +129,7 @@ func (hs *serverHandshakeState) handshake() error { } c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) - atomic.StoreUint32(&c.handshakeStatus, 1) + c.isHandshakeComplete.Store(true) c.updateConnectionState() return nil @@ -138,7 +137,9 @@ func (hs *serverHandshakeState) handshake() error { // readClientHello reads a ClientHello message and selects the protocol version. func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { - msg, err := c.readHandshake() + // clientHelloMsg is included in the transcript, but we haven't initialized + // it yet. The respective handshake functions will record it themselves. + msg, err := c.readHandshake(nil) if err != nil { return nil, err } @@ -494,9 +495,10 @@ func (hs *serverHandshakeState) doResumeHandshake() error { hs.hello.ticketSupported = hs.sessionState.usedOldKey hs.finishedHash = newFinishedHash(c.vers, hs.suite) hs.finishedHash.discardHandshakeBuffer() - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { return err } @@ -534,24 +536,23 @@ func (hs *serverHandshakeState) doFullHandshake() error { // certificates won't be used. 
hs.finishedHash.discardHandshakeBuffer() } - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { return err } certMsg := new(certificateMsg) certMsg.certificates = hs.cert.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { return err } if hs.hello.ocspStapling { certStatus := new(certificateStatusMsg) certStatus.response = hs.cert.OCSPStaple - hs.finishedHash.Write(certStatus.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil { return err } } @@ -563,8 +564,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { return err } if skx != nil { - hs.finishedHash.Write(skx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { return err } } @@ -579,7 +579,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { } if c.vers >= VersionTLS12 { certReq.hasSignatureAlgorithm = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() } // An empty list of certificateAuthorities signals to @@ -590,15 +590,13 @@ func (hs *serverHandshakeState) doFullHandshake() error { if c.config.ClientCAs != nil { certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } - hs.finishedHash.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil { return err } } helloDone := new(serverHelloDoneMsg) - hs.finishedHash.Write(helloDone.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil { return err } @@ -608,7 +606,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { var pub crypto.PublicKey // public key for client auth, if any - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -621,7 +619,6 @@ func (hs *serverHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.finishedHash.Write(certMsg.marshal()) if err := c.processCertsFromClient(Certificate{ Certificate: certMsg.certificates, @@ -632,7 +629,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { pub = c.peerCertificates[0].PublicKey } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -650,7 +647,6 @@ func (hs *serverHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(ckx, msg) } - hs.finishedHash.Write(ckx.marshal()) preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) if err != nil { @@ -670,7 +666,10 @@ func (hs *serverHandshakeState) doFullHandshake() error { // to the client's certificate. 
This allows us to verify that the client is in // possession of the private key of the certificate. if len(c.peerCertificates) > 0 { - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -699,13 +698,15 @@ func (hs *serverHandshakeState) doFullHandshake() error { } } - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret) + signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { c.sendAlert(alertDecryptError) return errors.New("tls: invalid signature by the client certificate: " + err.Error()) } - hs.finishedHash.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil { + return err + } } hs.finishedHash.discardHandshakeBuffer() @@ -745,7 +746,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error { return err } - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -762,7 +766,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error { return errors.New("tls: client's Finished message is incorrect") } - hs.finishedHash.Write(clientFinished.marshal()) + if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil { + return err + } + copy(out, verify) return nil } @@ -796,14 +803,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error { masterSecret: hs.masterSecret, certificates: certsFromClient, } - var err error - m.ticket, err = c.encryptTicket(state.marshal()) + stateBytes, err := state.marshal() + if err != nil { + return err + } + m.ticket, err = c.encryptTicket(stateBytes) if err != nil { return err } - hs.finishedHash.Write(m.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil { return err } @@ -813,14 +822,13 @@ func (hs *serverHandshakeState) sendSessionTicket() error { func (hs *serverHandshakeState) sendFinished(out []byte) error { c := hs.c - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + if err := c.writeChangeCipherRecord(); err != nil { return err } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { return err } @@ -863,7 +871,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { chains, err := certs[0].Verify(opts) if err != nil { c.sendAlert(alertBadCertificate) - return errors.New("tls: failed to verify client certificate: " + err.Error()) + return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err} } c.verifiedChains = chains diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server_tls13.go b/vendor/github.com/quic-go/qtls-go1-20/handshake_server_tls13.go similarity index 90% rename from vendor/github.com/marten-seemann/qtls-go1-18/handshake_server_tls13.go 
rename to vendor/github.com/quic-go/qtls-go1-20/handshake_server_tls13.go index fa89c97f594..4eb0c73d4f2 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/handshake_server_tls13.go +++ b/vendor/github.com/quic-go/qtls-go1-20/handshake_server_tls13.go @@ -14,7 +14,6 @@ import ( "fmt" "hash" "io" - "sync/atomic" "time" ) @@ -49,6 +48,10 @@ func (hs *serverHandshakeStateTLS13) handshake() error { startTime := time.Now() + if needFIPS() { + return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode") + } + // For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2. if err := hs.processClientHello(); err != nil { return err @@ -90,7 +93,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error { kex: hs.hello.serverShare.group, }) - atomic.StoreUint32(&c.handshakeStatus, 1) + c.isHandshakeComplete.Store(true) c.updateConnectionState() return nil } @@ -152,27 +155,14 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { hs.hello.sessionId = hs.clientHello.sessionId hs.hello.compressionMethod = compressionNone - if hs.suite == nil { - var preferenceList []uint16 - for _, suiteID := range c.config.CipherSuites { - for _, suite := range cipherSuitesTLS13 { - if suite.id == suiteID { - preferenceList = append(preferenceList, suiteID) - break - } - } - } - if len(preferenceList) == 0 { - preferenceList = defaultCipherSuitesTLS13 - if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { - preferenceList = defaultCipherSuitesTLS13NoAES - } - } - for _, suiteID := range preferenceList { - hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) - if hs.suite != nil { - break - } + preferenceList := defaultCipherSuitesTLS13 + if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) { + preferenceList = defaultCipherSuitesTLS13NoAES + } + for _, suiteID := range preferenceList { + hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID) + if hs.suite != nil { + break } } if hs.suite == nil { @@ -217,7 +207,7 @@ GroupSelection: clientKeyShare = &hs.clientHello.keyShares[0] } - if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && curveIdToCirclScheme(selectedGroup) == nil && !ok { + if _, ok := curveForCurveID(selectedGroup); curveIdToCirclScheme(selectedGroup) == nil && !ok { c.sendAlert(alertInternalError) return errors.New("tls: CurvePreferences includes unsupported curve") } @@ -230,13 +220,16 @@ GroupSelection: hs.hello.serverShare = keyShare{group: selectedGroup, data: ct} hs.sharedKey = ss } else { - params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + key, err := generateECDHEKey(c.config.rand(), selectedGroup) if err != nil { c.sendAlert(alertInternalError) return err } - hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} - hs.sharedKey = params.SharedKey(clientKeyShare.data) + hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()} + peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data) + if err == nil { + hs.sharedKey, _ = key.ECDH(peerKey) + } } if hs.sharedKey == nil { c.sendAlert(alertIllegalParameter) @@ -349,7 +342,12 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { c.sendAlert(alertInternalError) return errors.New("tls: internal error: failed to clone hash") } - transcript.Write(hs.clientHello.marshalWithoutBinders()) + clientHelloBytes, err := hs.clientHello.marshalWithoutBinders() + if err != nil { + c.sendAlert(alertInternalError) + return err + } 
+ transcript.Write(clientHelloBytes) pskBinder := hs.suite.finishedHash(binderKey, transcript) if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { c.sendAlert(alertDecryptError) @@ -362,7 +360,12 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { } h := cloneHash(hs.transcript, hs.suite.hash) - h.Write(hs.clientHello.marshal()) + clientHelloWithBindersBytes, err := hs.clientHello.marshal() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + h.Write(clientHelloWithBindersBytes) if hs.encryptedExtensions.earlyData { clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h) c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret) @@ -451,8 +454,7 @@ func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { } hs.sentDummyCCS = true - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err + return hs.c.writeChangeCipherRecord() } func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { @@ -462,7 +464,9 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) // The first ClientHello gets double-hashed into the transcript upon a // HelloRetryRequest. See RFC 8446, Section 4.4.1. - hs.transcript.Write(hs.clientHello.marshal()) + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { + return err + } chHash := hs.transcript.Sum(nil) hs.transcript.Reset() hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) @@ -478,8 +482,7 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) selectedGroup: selectedGroup, } - hs.transcript.Write(helloRetryRequest.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil { return err } @@ -487,7 +490,8 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) return err } - msg, err := c.readHandshake() + // clientHelloMsg is not included in the transcript. 
+ msg, err := c.readHandshake(nil) if err != nil { return err } @@ -583,9 +587,10 @@ func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { func (hs *serverHandshakeStateTLS13) sendServerParameters() error { c := hs.c - hs.transcript.Write(hs.clientHello.marshal()) - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { return err } @@ -628,8 +633,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions) } - hs.transcript.Write(hs.encryptedExtensions.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.encryptedExtensions.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil { return err } @@ -653,13 +657,12 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certReq := new(certificateRequestMsgTLS13) certReq.ocspStapling = true certReq.scts = true - certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() if c.config.ClientCAs != nil { certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } - hs.transcript.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil { return err } } @@ -670,8 +673,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { return err } @@ -702,8 +704,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { } certVerifyMsg.signature = sig - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { return err } @@ -717,8 +718,7 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error { verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), } - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { return err } @@ -780,7 +780,9 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { finishedMsg := &finishedMsg{ verifyData: hs.clientFinished, } - hs.transcript.Write(finishedMsg.marshal()) + if err := transcriptMsg(finishedMsg, hs.transcript); err != nil { + return err + } if !hs.shouldSendSessionTickets() { return nil @@ -801,7 +803,7 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { return err } - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + if _, err := c.writeHandshakeRecord(m, nil); err != nil { return err } @@ -826,7 +828,7 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { // If we 
requested a client certificate, then the client must send a // certificate message. If it's empty, no CertificateVerify is sent. - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } @@ -836,7 +838,6 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.transcript.Write(certMsg.marshal()) if err := c.processCertsFromClient(certMsg.certificate); err != nil { return err @@ -850,7 +851,10 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { } if len(certMsg.certificate.Certificate) != 0 { - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -862,7 +866,7 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { } // See RFC 8446, Section 4.4.3. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) { c.sendAlert(alertIllegalParameter) return errors.New("tls: client certificate used with invalid signature algorithm") } @@ -881,7 +885,9 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { return errors.New("tls: invalid signature by the client certificate: " + err.Error()) } - hs.transcript.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, hs.transcript); err != nil { + return err + } } // If we waited until the client certificates to send session tickets, we @@ -896,7 +902,8 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { func (hs *serverHandshakeStateTLS13) readClientFinished() error { c := hs.c - msg, err := c.readHandshake() + // finishedMsg is not included in the transcript. + msg, err := c.readHandshake(nil) if err != nil { return err } diff --git a/vendor/github.com/marten-seemann/qtls-go1-19/key_agreement.go b/vendor/github.com/quic-go/qtls-go1-20/key_agreement.go similarity index 94% rename from vendor/github.com/marten-seemann/qtls-go1-19/key_agreement.go rename to vendor/github.com/quic-go/qtls-go1-20/key_agreement.go index 8a6fd69233a..a2b78b56485 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-19/key_agreement.go +++ b/vendor/github.com/quic-go/qtls-go1-20/key_agreement.go @@ -6,6 +6,7 @@ package qtls import ( "crypto" + "crypto/ecdh" "crypto/md5" "crypto/rsa" "crypto/sha1" @@ -157,7 +158,7 @@ func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint1 type ecdheKeyAgreement struct { version uint16 isRSA bool - params ecdheParameters + key *ecdh.PrivateKey // ckx and preMasterSecret are generated in processServerKeyExchange // and returned in generateClientKeyExchange. 
@@ -177,18 +178,18 @@ func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Cer if curveID == 0 { return nil, errors.New("tls: no supported elliptic curves offered") } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + if _, ok := curveForCurveID(curveID); !ok { return nil, errors.New("tls: CurvePreferences includes unsupported curve") } - params, err := generateECDHEParameters(config.rand(), curveID) + key, err := generateECDHEKey(config.rand(), curveID) if err != nil { return nil, err } - ka.params = params + ka.key = key // See RFC 4492, Section 5.4. - ecdhePublic := params.PublicKey() + ecdhePublic := key.PublicKey().Bytes() serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) serverECDHEParams[0] = 3 // named curve serverECDHEParams[1] = byte(curveID >> 8) @@ -259,8 +260,12 @@ func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Cert return nil, errClientKeyExchange } - preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:]) - if preMasterSecret == nil { + peerKey, err := ka.key.Curve().NewPublicKey(ckx.ciphertext[1:]) + if err != nil { + return nil, errClientKeyExchange + } + preMasterSecret, err := ka.key.ECDH(peerKey) + if err != nil { return nil, errClientKeyExchange } @@ -288,22 +293,26 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHell return errServerKeyExchange } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + if _, ok := curveForCurveID(curveID); !ok { return errors.New("tls: server selected unsupported curve") } - params, err := generateECDHEParameters(config.rand(), curveID) + key, err := generateECDHEKey(config.rand(), curveID) if err != nil { return err } - ka.params = params + ka.key = key - ka.preMasterSecret = params.SharedKey(publicKey) - if ka.preMasterSecret == nil { + peerKey, err := key.Curve().NewPublicKey(publicKey) + if err != nil { + return errServerKeyExchange + } + ka.preMasterSecret, err = key.ECDH(peerKey) + if err != nil { return errServerKeyExchange } - ourPublicKey := params.PublicKey() + ourPublicKey := key.PublicKey().Bytes() ka.ckx = new(clientKeyExchangeMsg) ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey)) ka.ckx.ciphertext[0] = byte(len(ourPublicKey)) diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/key_schedule.go b/vendor/github.com/quic-go/qtls-go1-20/key_schedule.go similarity index 63% rename from vendor/github.com/marten-seemann/qtls-go1-18/key_schedule.go rename to vendor/github.com/quic-go/qtls-go1-20/key_schedule.go index da13904a6e8..c410a3e8277 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-18/key_schedule.go +++ b/vendor/github.com/quic-go/qtls-go1-20/key_schedule.go @@ -5,15 +5,14 @@ package qtls import ( - "crypto/elliptic" + "crypto/ecdh" "crypto/hmac" "errors" + "fmt" "hash" "io" - "math/big" "golang.org/x/crypto/cryptobyte" - "golang.org/x/crypto/curve25519" "golang.org/x/crypto/hkdf" ) @@ -42,8 +41,24 @@ func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []by hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(context) }) + hkdfLabelBytes, err := hkdfLabel.Bytes() + if err != nil { + // Rather than calling BytesOrPanic, we explicitly handle this error, in + // order to provide a reasonable error message. It should be basically + // impossible for this to panic, and routing errors back through the + // tree rooted in this function is quite painful. 
The labels are fixed + // size, and the context is either a fixed-length computed hash, or + // parsed from a field which has the same length limitation. As such, an + // error here is likely to only be caused during development. + // + // NOTE: another reasonable approach here might be to return a + // randomized slice if we encounter an error, which would break the + // connection, but avoid panicking. This would perhaps be safer but + // significantly more confusing to users. + panic(fmt.Errorf("failed to construct HKDF label: %s", err)) + } out := make([]byte, length) - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out) if err != nil || n != length { panic("tls: HKDF-Expand-Label invocation failed unexpectedly") } @@ -101,99 +116,43 @@ func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript } } -// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519, +// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman // according to RFC 8446, Section 4.2.8.2. -type ecdheParameters interface { - CurveID() CurveID - PublicKey() []byte - SharedKey(peerPublicKey []byte) []byte -} - -func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) { - if curveID == X25519 { - privateKey := make([]byte, curve25519.ScalarSize) - if _, err := io.ReadFull(rand, privateKey); err != nil { - return nil, err - } - publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint) - if err != nil { - return nil, err - } - return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil - } - +func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) { curve, ok := curveForCurveID(curveID) if !ok { return nil, errors.New("tls: internal error: unsupported curve") } - p := &nistParameters{curveID: curveID} - var err error - p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand) - if err != nil { - return nil, err - } - return p, nil + return curve.GenerateKey(rand) } -func curveForCurveID(id CurveID) (elliptic.Curve, bool) { +func curveForCurveID(id CurveID) (ecdh.Curve, bool) { switch id { + case X25519: + return ecdh.X25519(), true case CurveP256: - return elliptic.P256(), true + return ecdh.P256(), true case CurveP384: - return elliptic.P384(), true + return ecdh.P384(), true case CurveP521: - return elliptic.P521(), true + return ecdh.P521(), true default: return nil, false } } -type nistParameters struct { - privateKey []byte - x, y *big.Int // public key - curveID CurveID -} - -func (p *nistParameters) CurveID() CurveID { - return p.curveID -} - -func (p *nistParameters) PublicKey() []byte { - curve, _ := curveForCurveID(p.curveID) - return elliptic.Marshal(curve, p.x, p.y) -} - -func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte { - curve, _ := curveForCurveID(p.curveID) - // Unmarshal also checks whether the given point is on the curve. 
- x, y := elliptic.Unmarshal(curve, peerPublicKey) - if x == nil { - return nil - } - - xShared, _ := curve.ScalarMult(x, y, p.privateKey) - sharedKey := make([]byte, (curve.Params().BitSize+7)/8) - return xShared.FillBytes(sharedKey) -} - -type x25519Parameters struct { - privateKey []byte - publicKey []byte -} - -func (p *x25519Parameters) CurveID() CurveID { - return X25519 -} - -func (p *x25519Parameters) PublicKey() []byte { - return p.publicKey[:] -} - -func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte { - sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey) - if err != nil { - return nil +func curveIDForCurve(curve ecdh.Curve) (CurveID, bool) { + switch curve { + case ecdh.X25519(): + return X25519, true + case ecdh.P256(): + return CurveP256, true + case ecdh.P384(): + return CurveP384, true + case ecdh.P521(): + return CurveP521, true + default: + return 0, false } - return sharedKey } diff --git a/vendor/github.com/quic-go/qtls-go1-20/notboring.go b/vendor/github.com/quic-go/qtls-go1-20/notboring.go new file mode 100644 index 00000000000..f292e4f028e --- /dev/null +++ b/vendor/github.com/quic-go/qtls-go1-20/notboring.go @@ -0,0 +1,18 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package qtls + +func needFIPS() bool { return false } + +func supportedSignatureAlgorithms() []SignatureScheme { + return defaultSupportedSignatureAlgorithms +} + +func fipsMinVersion(c *config) uint16 { panic("fipsMinVersion") } +func fipsMaxVersion(c *config) uint16 { panic("fipsMaxVersion") } +func fipsCurvePreferences(c *config) []CurveID { panic("fipsCurvePreferences") } +func fipsCipherSuites(c *config) []uint16 { panic("fipsCipherSuites") } + +var fipsSupportedSignatureAlgorithms []SignatureScheme diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/prf.go b/vendor/github.com/quic-go/qtls-go1-20/prf.go similarity index 99% rename from vendor/github.com/marten-seemann/qtls-go1-17/prf.go rename to vendor/github.com/quic-go/qtls-go1-20/prf.go index 9eb0221a0c0..1471289182b 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-17/prf.go +++ b/vendor/github.com/quic-go/qtls-go1-20/prf.go @@ -215,7 +215,7 @@ func (h finishedHash) serverSum(masterSecret []byte) []byte { // hashForClientCertificate returns the handshake messages so far, pre-hashed if // necessary, suitable for signing by a TLS client certificate. 
-func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte { +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte { if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") } diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/ticket.go b/vendor/github.com/quic-go/qtls-go1-20/ticket.go similarity index 95% rename from vendor/github.com/marten-seemann/qtls-go1-17/ticket.go rename to vendor/github.com/quic-go/qtls-go1-20/ticket.go index 81e8a52eac6..1b9289c2fe2 100644 --- a/vendor/github.com/marten-seemann/qtls-go1-17/ticket.go +++ b/vendor/github.com/quic-go/qtls-go1-20/ticket.go @@ -34,7 +34,7 @@ type sessionState struct { usedOldKey bool } -func (m *sessionState) marshal() []byte { +func (m *sessionState) marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(m.vers) b.AddUint16(m.cipherSuite) @@ -49,7 +49,7 @@ func (m *sessionState) marshal() []byte { }) } }) - return b.BytesOrPanic() + return b.Bytes() } func (m *sessionState) unmarshal(data []byte) bool { @@ -94,7 +94,7 @@ type sessionStateTLS13 struct { appData []byte } -func (m *sessionStateTLS13) marshal() []byte { +func (m *sessionStateTLS13) marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(VersionTLS13) b.AddUint8(2) // revision @@ -111,7 +111,7 @@ func (m *sessionStateTLS13) marshal() []byte { b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(m.appData) }) - return b.BytesOrPanic() + return b.Bytes() } func (m *sessionStateTLS13) unmarshal(data []byte) bool { @@ -227,8 +227,11 @@ func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, e if c.extraConfig != nil { state.maxEarlyData = c.extraConfig.MaxEarlyData } - var err error - m.label, err = c.encryptTicket(state.marshal()) + stateBytes, err := state.marshal() + if err != nil { + return nil, err + } + m.label, err = c.encryptTicket(stateBytes) if err != nil { return nil, err } @@ -259,7 +262,7 @@ func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, e // The ticket may be nil if config.SessionTicketsDisabled is set, // or if the client isn't able to receive session tickets. 
func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) { - if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { + if c.isClient || !c.isHandshakeComplete.Load() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil { return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.") } if c.config.SessionTicketsDisabled { @@ -270,5 +273,5 @@ func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) { if err != nil { return nil, err } - return m.marshal(), nil + return m.marshal() } diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/tls.go b/vendor/github.com/quic-go/qtls-go1-20/tls.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-18/tls.go rename to vendor/github.com/quic-go/qtls-go1-20/tls.go diff --git a/vendor/github.com/marten-seemann/qtls-go1-17/unsafe.go b/vendor/github.com/quic-go/qtls-go1-20/unsafe.go similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-17/unsafe.go rename to vendor/github.com/quic-go/qtls-go1-20/unsafe.go diff --git a/vendor/github.com/lucas-clemente/quic-go/.gitignore b/vendor/github.com/quic-go/quic-go/.gitignore similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/.gitignore rename to vendor/github.com/quic-go/quic-go/.gitignore diff --git a/vendor/github.com/lucas-clemente/quic-go/.golangci.yml b/vendor/github.com/quic-go/quic-go/.golangci.yml similarity index 73% rename from vendor/github.com/lucas-clemente/quic-go/.golangci.yml rename to vendor/github.com/quic-go/quic-go/.golangci.yml index 2589c053892..7820be8c975 100644 --- a/vendor/github.com/lucas-clemente/quic-go/.golangci.yml +++ b/vendor/github.com/quic-go/quic-go/.golangci.yml @@ -1,14 +1,15 @@ run: - skip-files: - - internal/qtls/structs_equal_test.go - linters-settings: depguard: type: blacklist packages: - github.com/marten-seemann/qtls + - github.com/quic-go/qtls-go1-19 + - github.com/quic-go/qtls-go1-20 packages-with-error-message: - github.com/marten-seemann/qtls: "importing qtls only allowed in internal/qtls" + - github.com/quic-go/qtls-go1-19: "importing qtls only allowed in internal/qtls" + - github.com/quic-go/qtls-go1-20: "importing qtls only allowed in internal/qtls" misspell: ignore-words: - ect @@ -17,7 +18,6 @@ linters: disable-all: true enable: - asciicheck - - deadcode - depguard - exhaustive - exportloopref @@ -30,11 +30,9 @@ linters: - prealloc - staticcheck - stylecheck - - structcheck - unconvert - unparam - unused - - varcheck - vet issues: diff --git a/vendor/github.com/lucas-clemente/quic-go/Changelog.md b/vendor/github.com/quic-go/quic-go/Changelog.md similarity index 97% rename from vendor/github.com/lucas-clemente/quic-go/Changelog.md rename to vendor/github.com/quic-go/quic-go/Changelog.md index c1c3323273c..82df5fb2457 100644 --- a/vendor/github.com/lucas-clemente/quic-go/Changelog.md +++ b/vendor/github.com/quic-go/quic-go/Changelog.md @@ -101,8 +101,8 @@ - Add a `quic.Config` option to configure keep-alive - Rename the STK to Cookie - Implement `net.Conn`-style deadlines for streams -- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details. 
-- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details. +- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/quic-go/quic-go) for details. +- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/quic-go/quic-go/wiki/Logging) for more details. - Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper` - Changed `h2quic.Server.Serve()` to accept a `net.PacketConn` - Drop support for Go 1.7 and 1.8. diff --git a/vendor/github.com/lucas-clemente/quic-go/LICENSE b/vendor/github.com/quic-go/quic-go/LICENSE similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/LICENSE rename to vendor/github.com/quic-go/quic-go/LICENSE diff --git a/vendor/github.com/quic-go/quic-go/README.md b/vendor/github.com/quic-go/quic-go/README.md new file mode 100644 index 00000000000..e518f1e4304 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/README.md @@ -0,0 +1,64 @@ +# A QUIC implementation in pure Go + + + +[![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go) +[![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/) + +quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go, including the Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) and Datagram Packetization Layer Path MTU + Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)). It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)). + +In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem. + +## Guides + +*We currently support Go 1.19.x and Go 1.20.x* + +Running tests: + + go test ./... + +### QUIC without HTTP/3 + +Take a look at [this echo example](example/echo/echo.go). + +## Usage + +### As a server + +See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go: + +```go +http.Handle("/", http.FileServer(http.Dir(wwwDir))) +http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil) +``` + +### As a client + +See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`. 
+ +```go +http.Client{ + Transport: &http3.RoundTripper{}, +} +``` + +## Projects using quic-go + +| Project | Description | Stars | +|-----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------| +| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) | +| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | +| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | +| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | +| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | +| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) | +| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | +| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | +| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | +| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | +| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) | + +## Contributing + +We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment. 
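[Editor's aside, not part of the vendored README above: a minimal, hypothetical client sketch showing the `http3.RoundTripper` from the README wired into an `http.Client` and used for a single request. The target URL is a placeholder and error handling is deliberately simple; this is an illustration under those assumptions, not code carried by this patch.]

```go
package main

import (
	"fmt"
	"io"
	"net/http"

	"github.com/quic-go/quic-go/http3"
)

func main() {
	// Use the HTTP/3 round tripper as the client transport, as shown in the README.
	client := &http.Client{
		Transport: &http3.RoundTripper{},
	}

	// Placeholder URL: the server must speak HTTP/3 (e.g. one started with
	// http3.ListenAndServeQUIC as in the README's server snippet).
	resp, err := client.Get("https://example.com/")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Status, len(body), "bytes")
}
```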
diff --git a/vendor/github.com/quic-go/quic-go/SECURITY.md b/vendor/github.com/quic-go/quic-go/SECURITY.md new file mode 100644 index 00000000000..c24c08f8630 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/SECURITY.md @@ -0,0 +1,19 @@ +# Security Policy + +quic-go still in development. This means that there may be problems in our protocols, +or there may be mistakes in our implementations. +We take security vulnerabilities very seriously. If you discover a security issue, +please bring it to our attention right away! + +## Reporting a Vulnerability + +If you find a vulnerability that may affect live deployments -- for example, by exposing +a remote execution exploit -- please [**report privately**](https://github.com/quic-go/quic-go/security/advisories/new). +Please **DO NOT file a public issue**. + +If the issue is an implementation weakness that cannot be immediately exploited or +something not yet deployed, just discuss it openly. + +## Reporting a non security bug + +For non-security bugs, please simply file a GitHub [issue](https://github.com/quic-go/quic-go/issues/new). diff --git a/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go b/vendor/github.com/quic-go/quic-go/buffer_pool.go similarity index 96% rename from vendor/github.com/lucas-clemente/quic-go/buffer_pool.go rename to vendor/github.com/quic-go/quic-go/buffer_pool.go index c0b7067da59..f6745b0803d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/buffer_pool.go +++ b/vendor/github.com/quic-go/quic-go/buffer_pool.go @@ -3,7 +3,7 @@ package quic import ( "sync" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) type packetBuffer struct { diff --git a/vendor/github.com/quic-go/quic-go/client.go b/vendor/github.com/quic-go/quic-go/client.go new file mode 100644 index 00000000000..3d0bca3dfb6 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/client.go @@ -0,0 +1,252 @@ +package quic + +import ( + "context" + "crypto/tls" + "errors" + "net" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +type client struct { + sendConn sendConn + + use0RTT bool + + packetHandlers packetHandlerManager + onClose func() + + tlsConf *tls.Config + config *Config + + connIDGenerator ConnectionIDGenerator + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID + + initialPacketNumber protocol.PacketNumber + hasNegotiatedVersion bool + version protocol.VersionNumber + + handshakeChan chan struct{} + + conn quicConn + + tracer logging.ConnectionTracer + tracingID uint64 + logger utils.Logger +} + +// make it possible to mock connection ID for initial generation in the tests +var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial + +// DialAddr establishes a new QUIC connection to a server. +// It uses a new UDP connection and closes this connection when the QUIC connection is closed. +func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + dl, err := setupTransport(udpConn, tlsConf, true) + if err != nil { + return nil, err + } + return dl.Dial(ctx, udpAddr, tlsConf, conf) +} + +// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. 
+// It uses a new UDP connection and closes this connection when the QUIC connection is closed. +func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + dl, err := setupTransport(udpConn, tlsConf, true) + if err != nil { + return nil, err + } + conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. +// See DialEarly for details. +func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + dl, err := setupTransport(c, tlsConf, false) + if err != nil { + return nil, err + } + conn, err := dl.DialEarly(ctx, addr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +// Dial establishes a new QUIC connection to a server using a net.PacketConn. If +// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn +// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP +// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write +// packets. +// The tls.Config must define an application protocol (using NextProtos). +func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + dl, err := setupTransport(c, tlsConf, false) + if err != nil { + return nil, err + } + conn, err := dl.Dial(ctx, addr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + return &Transport{ + Conn: c, + createdConn: createdPacketConn, + isSingleUse: true, + }, nil +} + +func dial( + ctx context.Context, + conn sendConn, + connIDGenerator ConnectionIDGenerator, + packetHandlers packetHandlerManager, + tlsConf *tls.Config, + config *Config, + onClose func(), + use0RTT bool, +) (quicConn, error) { + c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) + if err != nil { + return nil, err + } + c.packetHandlers = packetHandlers + + c.tracingID = nextConnTracingID() + if c.config.Tracer != nil { + c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) + } + if c.tracer != nil { + c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) + } + if err := c.dial(ctx); err != nil { + return nil, err + } + return c.conn, nil +} + +func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { + if tlsConf == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = tlsConf.Clone() + } + + srcConnID, err := connIDGenerator.GenerateConnectionID() + if err != nil { + return nil, err + } + destConnID, err := generateConnectionIDForInitial() + if err != nil { + return nil, err + } + c := &client{ + connIDGenerator: connIDGenerator, + srcConnID: srcConnID, + destConnID: destConnID, + sendConn: sendConn, 
+ use0RTT: use0RTT, + onClose: onClose, + tlsConf: tlsConf, + config: config, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + logger: utils.DefaultLogger.WithPrefix("client"), + } + return c, nil +} + +func (c *client) dial(ctx context.Context) error { + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + + c.conn = newClientConnection( + c.sendConn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.connIDGenerator, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.use0RTT, + c.hasNegotiatedVersion, + c.tracer, + c.tracingID, + c.logger, + c.version, + ) + c.packetHandlers.Add(c.srcConnID, c.conn) + + errorChan := make(chan error, 1) + recreateChan := make(chan errCloseForRecreating) + go func() { + err := c.conn.run() + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + recreateChan <- *recreateErr + return + } + if c.onClose != nil { + c.onClose() + } + errorChan <- err // returns as soon as the connection is closed + }() + + // only set when we're using 0-RTT + // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. + var earlyConnChan <-chan struct{} + if c.use0RTT { + earlyConnChan = c.conn.earlyConnReady() + } + + select { + case <-ctx.Done(): + c.conn.shutdown() + return ctx.Err() + case err := <-errorChan: + return err + case recreateErr := <-recreateChan: + c.initialPacketNumber = recreateErr.nextPacketNumber + c.version = recreateErr.nextVersion + c.hasNegotiatedVersion = true + return c.dial(ctx) + case <-earlyConnChan: + // ready to send 0-RTT data + return nil + case <-c.conn.HandshakeComplete(): + // handshake successfully completed + return nil + } +} diff --git a/vendor/github.com/quic-go/quic-go/closed_conn.go b/vendor/github.com/quic-go/quic-go/closed_conn.go new file mode 100644 index 00000000000..73904b84688 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/closed_conn.go @@ -0,0 +1,64 @@ +package quic + +import ( + "math/bits" + "net" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" +) + +// A closedLocalConn is a connection that we closed locally. +// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, +// with an exponential backoff. +type closedLocalConn struct { + counter uint32 + perspective protocol.Perspective + logger utils.Logger + + sendPacket func(net.Addr, *packetInfo) +} + +var _ packetHandler = &closedLocalConn{} + +// newClosedLocalConn creates a new closedLocalConn and runs it. +func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { + return &closedLocalConn{ + sendPacket: sendPacket, + perspective: pers, + logger: logger, + } +} + +func (c *closedLocalConn) handlePacket(p *receivedPacket) { + c.counter++ + // exponential backoff + // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving + if bits.OnesCount32(c.counter) != 1 { + return + } + c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. 
Retransmitting.", c.counter) + c.sendPacket(p.remoteAddr, p.info) +} + +func (c *closedLocalConn) shutdown() {} +func (c *closedLocalConn) destroy(error) {} +func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective } + +// A closedRemoteConn is a connection that was closed remotely. +// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. +// We can just ignore those packets. +type closedRemoteConn struct { + perspective protocol.Perspective +} + +var _ packetHandler = &closedRemoteConn{} + +func newClosedRemoteConn(pers protocol.Perspective) packetHandler { + return &closedRemoteConn{perspective: pers} +} + +func (s *closedRemoteConn) handlePacket(*receivedPacket) {} +func (s *closedRemoteConn) shutdown() {} +func (s *closedRemoteConn) destroy(error) {} +func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } diff --git a/vendor/github.com/lucas-clemente/quic-go/codecov.yml b/vendor/github.com/quic-go/quic-go/codecov.yml similarity index 95% rename from vendor/github.com/lucas-clemente/quic-go/codecov.yml rename to vendor/github.com/quic-go/quic-go/codecov.yml index ee9cfd3b38d..074d9832523 100644 --- a/vendor/github.com/lucas-clemente/quic-go/codecov.yml +++ b/vendor/github.com/quic-go/quic-go/codecov.yml @@ -12,6 +12,7 @@ coverage: - internal/utils/newconnectionid_linkedlist.go - internal/utils/packetinterval_linkedlist.go - internal/utils/linkedlist/linkedlist.go + - logging/null_tracer.go - fuzzing/ - metrics/ status: diff --git a/vendor/github.com/lucas-clemente/quic-go/config.go b/vendor/github.com/quic-go/quic-go/config.go similarity index 73% rename from vendor/github.com/lucas-clemente/quic-go/config.go rename to vendor/github.com/quic-go/quic-go/config.go index 884cff0062d..1808e9ef48b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/config.go +++ b/vendor/github.com/quic-go/quic-go/config.go @@ -2,11 +2,12 @@ package quic import ( "errors" + "fmt" + "net" "time" - "github.com/lucas-clemente/quic-go/internal/utils" - - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) // Clone clones a Config @@ -16,7 +17,7 @@ func (c *Config) Clone() *Config { } func (c *Config) handshakeTimeout() time.Duration { - return utils.MaxDuration(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout) + return utils.Max(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout) } func validateConfig(config *Config) error { @@ -29,32 +30,34 @@ func validateConfig(config *Config) error { if config.MaxIncomingUniStreams > 1<<60 { return errors.New("invalid value for Config.MaxIncomingUniStreams") } + // check that all QUIC versions are actually supported + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return fmt.Errorf("invalid QUIC version: %s", v) + } + } return nil } // populateServerConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateServerConfig(config *Config) *Config { - config = populateConfig(config, protocol.DefaultConnectionIDLength) - if config.AcceptToken == nil { - config.AcceptToken = defaultAcceptToken + config = populateConfig(config) + if config.MaxTokenAge == 0 { + config.MaxTokenAge = protocol.TokenValidity } - return config -} - -// populateClientConfig populates fields in the quic.Config with their default values, if none are set -// it may be called with nil 
-func populateClientConfig(config *Config, createdPacketConn bool) *Config { - var defaultConnIdLen = protocol.DefaultConnectionIDLength - if createdPacketConn { - defaultConnIdLen = 0 + if config.MaxRetryTokenAge == 0 { + config.MaxRetryTokenAge = protocol.RetryTokenValidity + } + if config.RequireAddressValidation == nil { + config.RequireAddressValidation = func(net.Addr) bool { return false } } - - config = populateConfig(config, defaultConnIdLen) return config } -func populateConfig(config *Config, defaultConnIDLen int) *Config { +// populateConfig populates fields in the quic.Config with their default values, if none are set +// it may be called with nil +func populateConfig(config *Config) *Config { if config == nil { config = &Config{} } @@ -62,9 +65,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { if len(versions) == 0 { versions = protocol.SupportedVersions } - if config.ConnectionIDLength == 0 { - config.ConnectionIDLength = defaultConnIDLen - } handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout if config.HandshakeIdleTimeout != 0 { handshakeIdleTimeout = config.HandshakeIdleTimeout @@ -105,16 +105,15 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { if maxDatagrameFrameSize == 0 { maxDatagrameFrameSize = int64(protocol.DefaultMaxDatagramFrameSize) } - connIDGenerator := config.ConnectionIDGenerator - if connIDGenerator == nil { - connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: config.ConnectionIDLength} - } return &Config{ + GetConfigForClient: config.GetConfigForClient, Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, - AcceptToken: config.AcceptToken, + MaxTokenAge: config.MaxTokenAge, + MaxRetryTokenAge: config.MaxRetryTokenAge, + RequireAddressValidation: config.RequireAddressValidation, KeepAlivePeriod: config.KeepAlivePeriod, InitialStreamReceiveWindow: initialStreamReceiveWindow, MaxStreamReceiveWindow: maxStreamReceiveWindow, @@ -123,14 +122,12 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, - ConnectionIDLength: config.ConnectionIDLength, - ConnectionIDGenerator: connIDGenerator, - StatelessResetKey: config.StatelessResetKey, TokenStore: config.TokenStore, EnableDatagrams: config.EnableDatagrams, MaxDatagramFrameSize: maxDatagrameFrameSize, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, + Allow0RTT: config.Allow0RTT, Tracer: config.Tracer, } } diff --git a/vendor/github.com/lucas-clemente/quic-go/conn_id_generator.go b/vendor/github.com/quic-go/quic-go/conn_id_generator.go similarity index 77% rename from vendor/github.com/lucas-clemente/quic-go/conn_id_generator.go rename to vendor/github.com/quic-go/quic-go/conn_id_generator.go index 15a95bab5fb..2d28dc619c3 100644 --- a/vendor/github.com/lucas-clemente/quic-go/conn_id_generator.go +++ b/vendor/github.com/quic-go/quic-go/conn_id_generator.go @@ -3,10 +3,10 @@ package quic import ( "fmt" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + 
"github.com/quic-go/quic-go/internal/wire" ) type connIDGenerator struct { @@ -14,29 +14,26 @@ type connIDGenerator struct { highestSeq uint64 activeSrcConnIDs map[uint64]protocol.ConnectionID - initialClientDestConnID protocol.ConnectionID + initialClientDestConnID *protocol.ConnectionID // nil for the client addConnectionID func(protocol.ConnectionID) getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken removeConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func(protocol.ConnectionID, packetHandler) + replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte) queueControlFrame func(wire.Frame) - - version protocol.VersionNumber } func newConnIDGenerator( initialConnectionID protocol.ConnectionID, - initialClientDestConnID protocol.ConnectionID, // nil for the client + initialClientDestConnID *protocol.ConnectionID, // nil for the client addConnectionID func(protocol.ConnectionID), getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), - replaceWithClosed func(protocol.ConnectionID, packetHandler), + replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte), queueControlFrame func(wire.Frame), generator ConnectionIDGenerator, - version protocol.VersionNumber, ) *connIDGenerator { m := &connIDGenerator{ generator: generator, @@ -47,7 +44,6 @@ func newConnIDGenerator( retireConnectionID: retireConnectionID, replaceWithClosed: replaceWithClosed, queueControlFrame: queueControlFrame, - version: version, } m.activeSrcConnIDs[0] = initialConnectionID m.initialClientDestConnID = initialClientDestConnID @@ -64,7 +60,7 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { // transport parameter. // We currently don't send the preferred_address transport parameter, // so we can issue (limit - 1) connection IDs. 
- for i := uint64(len(m.activeSrcConnIDs)); i < utils.MinUint64(limit, protocol.MaxIssuedConnectionIDs); i++ { + for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ { if err := m.issueNewConnID(); err != nil { return err } @@ -84,7 +80,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect if !ok { return nil } - if connID.Equal(sentWithDestConnID) { + if connID == sentWithDestConnID { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), @@ -117,25 +113,27 @@ func (m *connIDGenerator) issueNewConnID() error { func (m *connIDGenerator) SetHandshakeComplete() { if m.initialClientDestConnID != nil { - m.retireConnectionID(m.initialClientDestConnID) + m.retireConnectionID(*m.initialClientDestConnID) m.initialClientDestConnID = nil } } func (m *connIDGenerator) RemoveAll() { if m.initialClientDestConnID != nil { - m.removeConnectionID(m.initialClientDestConnID) + m.removeConnectionID(*m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { m.removeConnectionID(connID) } } -func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { +func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) { + connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) if m.initialClientDestConnID != nil { - m.replaceWithClosed(m.initialClientDestConnID, handler) + connIDs = append(connIDs, *m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { - m.replaceWithClosed(connID, handler) + connIDs = append(connIDs, connID) } + m.replaceWithClosed(connIDs, pers, connClose) } diff --git a/vendor/github.com/lucas-clemente/quic-go/conn_id_manager.go b/vendor/github.com/quic-go/quic-go/conn_id_manager.go similarity index 90% rename from vendor/github.com/lucas-clemente/quic-go/conn_id_manager.go rename to vendor/github.com/quic-go/quic-go/conn_id_manager.go index e1b025a985f..ba65aec0438 100644 --- a/vendor/github.com/lucas-clemente/quic-go/conn_id_manager.go +++ b/vendor/github.com/quic-go/quic-go/conn_id_manager.go @@ -3,14 +3,21 @@ package quic import ( "fmt" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + list "github.com/quic-go/quic-go/internal/utils/linkedlist" + "github.com/quic-go/quic-go/internal/wire" ) +type newConnID struct { + SequenceNumber uint64 + ConnectionID protocol.ConnectionID + StatelessResetToken protocol.StatelessResetToken +} + type connIDManager struct { - queue utils.NewConnectionIDList + queue list.List[newConnID] handshakeComplete bool activeSequenceNumber uint64 @@ -71,7 +78,7 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { // Retire elements in the queue. // Doesn't retire the active connection ID. 
if f.RetirePriorTo > h.highestRetired { - var next *utils.NewConnectionIDElement + var next *list.Element[newConnID] for el := h.queue.Front(); el != nil; el = next { if el.Value.SequenceNumber >= f.RetirePriorTo { break @@ -104,7 +111,7 @@ func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error { func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error { // insert a new element at the end if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq { - h.queue.PushBack(utils.NewConnectionID{ + h.queue.PushBack(newConnID{ SequenceNumber: seq, ConnectionID: connID, StatelessResetToken: resetToken, @@ -114,7 +121,7 @@ func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID // insert a new element somewhere in the middle for el := h.queue.Front(); el != nil; el = el.Next() { if el.Value.SequenceNumber == seq { - if !el.Value.ConnectionID.Equal(connID) { + if el.Value.ConnectionID != connID { return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq) } if el.Value.StatelessResetToken != resetToken { @@ -123,7 +130,7 @@ func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID break } if el.Value.SequenceNumber > seq { - h.queue.InsertBefore(utils.NewConnectionID{ + h.queue.InsertBefore(newConnID{ SequenceNumber: seq, ConnectionID: connID, StatelessResetToken: resetToken, @@ -138,7 +145,7 @@ func (h *connIDManager) updateConnectionID() { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: h.activeSequenceNumber, }) - h.highestRetired = utils.MaxUint64(h.highestRetired, h.activeSequenceNumber) + h.highestRetired = utils.Max(h.highestRetired, h.activeSequenceNumber) if h.activeStatelessResetToken != nil { h.removeStatelessResetToken(*h.activeStatelessResetToken) } diff --git a/vendor/github.com/lucas-clemente/quic-go/connection.go b/vendor/github.com/quic-go/quic-go/connection.go similarity index 75% rename from vendor/github.com/lucas-clemente/quic-go/connection.go rename to vendor/github.com/quic-go/quic-go/connection.go index fa9bd1546fc..ae431804951 100644 --- a/vendor/github.com/lucas-clemente/quic-go/connection.go +++ b/vendor/github.com/quic-go/quic-go/connection.go @@ -13,19 +13,20 @@ import ( "sync/atomic" "time" - "github.com/lucas-clemente/quic-go/internal/ackhandler" - "github.com/lucas-clemente/quic-go/internal/flowcontrol" - "github.com/lucas-clemente/quic-go/internal/handshake" - "github.com/lucas-clemente/quic-go/internal/logutils" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/logging" + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/flowcontrol" + "github.com/quic-go/quic-go/internal/handshake" + "github.com/quic-go/quic-go/internal/logutils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" ) type unpacker interface { - Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) + UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) + UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, 
protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) } type streamGetter interface { @@ -95,7 +96,7 @@ type connRunner interface { GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) - ReplaceWithClosed(protocol.ConnectionID, packetHandler) + ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte) AddResetToken(protocol.StatelessResetToken, packetHandler) RemoveResetToken(protocol.StatelessResetToken) } @@ -209,7 +210,7 @@ type connection struct { peerParams *wire.TransportParameters - timer *utils.Timer + timer connectionTimer // keepAlivePingSent stores whether a keep alive PING is in flight. // It is reset as soon as we receive a packet from the peer. keepAlivePingSent bool @@ -217,16 +218,18 @@ type connection struct { datagramQueue *datagramQueue + connStateMutex sync.Mutex + connState ConnectionState + logID string tracer logging.ConnectionTracer logger utils.Logger } var ( - _ Connection = &connection{} - _ EarlyConnection = &connection{} - _ streamSender = &connection{} - deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine + _ Connection = &connection{} + _ EarlyConnection = &connection{} + _ streamSender = &connection{} ) var newConnection = func( @@ -237,11 +240,12 @@ var newConnection = func( clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, statelessResetToken protocol.StatelessResetToken, conf *Config, tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, - enable0RTT bool, + clientAddressValidated bool, tracer logging.ConnectionTracer, tracingID uint64, logger utils.Logger, @@ -260,7 +264,7 @@ var newConnection = func( logger: logger, version: v, } - if origDestConnID != nil { + if origDestConnID.Len() > 0 { s.logID = origDestConnID.String() } else { s.logID = destConnID.String() @@ -273,15 +277,14 @@ var newConnection = func( ) s.connIDGenerator = newConnIDGenerator( srcConnID, - clientDestConnID, + &clientDestConnID, func(connID protocol.ConnectionID) { runner.Add(connID, s) }, runner.GetStatelessResetToken, runner.Remove, runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, - s.config.ConnectionIDGenerator, - s.version, + connIDGenerator, ) s.preSetup() s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) @@ -289,10 +292,10 @@ var newConnection = func( 0, getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, + clientAddressValidated, s.perspective, s.tracer, s.logger, - s.version, ) initialStream := newCryptoStream() handshakeStream := newCryptoStream() @@ -315,9 +318,8 @@ var newConnection = func( } if s.config.EnableDatagrams { params.MaxDatagramFrameSize = protocol.ByteCount(s.config.MaxDatagramFrameSize) - if params.MaxDatagramFrameSize == 0 { - params.MaxDatagramFrameSize = protocol.DefaultMaxDatagramFrameSize - } + } else { + params.MaxDatagramFrameSize = protocol.InvalidByteCount } if s.tracer != nil { s.tracer.SentTransportParameters(params) @@ -339,29 +341,15 @@ var newConnection = func( }, }, tlsConf, - enable0RTT, + conf.Allow0RTT, s.rttStats, tracer, logger, s.version, ) s.cryptoStreamHandler = cs - s.packer = newPacketPacker( - srcConnID, - s.connIDManager.Get, - initialStream, - handshakeStream, - s.sentPacketHandler, - s.retransmissionQueue, - s.RemoteAddr(), - cs, - s.framer, - s.receivedPacketHandler, - 
s.datagramQueue, - s.perspective, - s.version, - ) - s.unpacker = newPacketUnpacker(cs, s.version) + s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) + s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) return s } @@ -372,6 +360,7 @@ var newClientConnection = func( runner connRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, conf *Config, tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, @@ -411,8 +400,7 @@ var newClientConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, - s.config.ConnectionIDGenerator, - s.version, + connIDGenerator, ) s.preSetup() s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) @@ -420,10 +408,10 @@ var newClientConnection = func( initialPacketNumber, getMaxPacketSize(s.conn.RemoteAddr()), s.rttStats, + false, /* has no effect */ s.perspective, s.tracer, s.logger, - s.version, ) initialStream := newCryptoStream() handshakeStream := newCryptoStream() @@ -443,6 +431,8 @@ var newClientConnection = func( } if s.config.EnableDatagrams { params.MaxDatagramFrameSize = protocol.ByteCount(s.config.MaxDatagramFrameSize) + } else { + params.MaxDatagramFrameSize = protocol.InvalidByteCount } if s.tracer != nil { s.tracer.SentTransportParameters(params) @@ -470,22 +460,8 @@ var newClientConnection = func( s.clientHelloWritten = clientHelloWritten s.cryptoStreamHandler = cs s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) - s.unpacker = newPacketUnpacker(cs, s.version) - s.packer = newPacketPacker( - srcConnID, - s.connIDManager.Get, - initialStream, - handshakeStream, - s.sentPacketHandler, - s.retransmissionQueue, - s.RemoteAddr(), - cs, - s.framer, - s.receivedPacketHandler, - s.datagramQueue, - s.perspective, - s.version, - ) + s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) + s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) if len(tlsConf.ServerName) > 0 { s.tokenStoreKey = tlsConf.ServerName } else { @@ -501,8 +477,8 @@ var newClientConnection = func( func (s *connection) preSetup() { s.sendQueue = newSendQueue(s.conn) - s.retransmissionQueue = newRetransmissionQueue(s.version) - s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version) + s.retransmissionQueue = newRetransmissionQueue() + s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) s.rttStats = &utils.RTTStats{} s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ByteCount(s.config.InitialConnectionReceiveWindow), @@ -524,9 +500,8 @@ func (s *connection) preSetup() { uint64(s.config.MaxIncomingStreams), uint64(s.config.MaxIncomingUniStreams), s.perspective, - s.version, ) - s.framer = newFramer(s.streamsMap, s.version) + s.framer = newFramer(s.streamsMap) s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) @@ -537,18 +512,21 @@ func (s *connection) preSetup() { s.creationTime = now 
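For reference, the datagram hunks above change how the max_datagram_frame_size transport parameter is advertised: when EnableDatagrams is not set, the parameter is now protocol.InvalidByteCount, and the server-side constructor no longer falls back to protocol.DefaultMaxDatagramFrameSize when the configured size is zero. A rough sketch (not part of this patch) of how a caller opts in, assuming the fork's Config keeps the MaxDatagramFrameSize field referenced above and that DialAddrContext remains available as a dial entry point:

package quicexample

import (
	"context"
	"crypto/tls"

	"github.com/quic-go/quic-go"
)

// DialWithDatagrams is an illustrative sketch only; the 1350-byte limit is a
// hypothetical value, not something prescribed by the patch.
func DialWithDatagrams(ctx context.Context, addr string, tlsConf *tls.Config) (quic.Connection, error) {
	conf := &quic.Config{
		EnableDatagrams: true,
		// The fallback to a default frame size was removed above, so an
		// explicit limit has to be chosen here.
		MaxDatagramFrameSize: 1350,
	}
	return quic.DialAddrContext(ctx, addr, tlsConf, conf)
}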
s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) - if s.config.EnableDatagrams { - s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) - } + s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) + s.connState.Version = s.version } // run the connection main loop func (s *connection) run() error { defer s.ctxCancel() - s.timer = utils.NewTimer() + s.timer = *newTimer() - go s.cryptoStreamHandler.RunHandshake() + handshaking := make(chan struct{}) + go func() { + defer close(handshaking) + s.cryptoStreamHandler.RunHandshake() + }() go func() { if err := s.sendQueue.Run(); err != nil { s.destroyImpl(err) @@ -699,12 +677,13 @@ runLoop: } } + s.cryptoStreamHandler.Close() + <-handshaking s.handleCloseError(&closeErr) if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { s.tracer.Close() } s.logger.Infof("Connection %s closed.", s.logID) - s.cryptoStreamHandler.Close() s.sendQueue.Close() s.timer.Stop() return closeErr.err @@ -715,8 +694,8 @@ func (s *connection) earlyConnReady() <-chan struct{} { return s.earlyConnReadyChan } -func (s *connection) HandshakeComplete() context.Context { - return s.handshakeCtx +func (s *connection) HandshakeComplete() <-chan struct{} { + return s.handshakeCtx.Done() } func (s *connection) Context() context.Context { @@ -724,14 +703,14 @@ func (s *connection) Context() context.Context { } func (s *connection) supportsDatagrams() bool { - return s.peerParams.MaxDatagramFrameSize != protocol.InvalidByteCount + return s.peerParams.MaxDatagramFrameSize > 0 } func (s *connection) ConnectionState() ConnectionState { - return ConnectionState{ - TLS: s.cryptoStreamHandler.ConnectionState(), - SupportsDatagrams: s.supportsDatagrams(), - } + s.connStateMutex.Lock() + defer s.connStateMutex.Unlock() + s.connState.TLS = s.cryptoStreamHandler.ConnectionState() + return s.connState } // Time when the next keep-alive packet should be sent. 
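The run/HandshakeComplete changes above alter the public behavior in two ways: HandshakeComplete now returns a plain channel rather than a context.Context, and ConnectionState is served from a mutex-protected copy instead of blocking until the handshake finishes. A minimal sketch (not part of this patch) of how a caller adapts, assuming conn comes from one of the package's early-dial functions:

package quicexample

import (
	"context"
	"fmt"

	"github.com/quic-go/quic-go"
)

// WaitForHandshake is an illustrative sketch only.
func WaitForHandshake(ctx context.Context, conn quic.EarlyConnection) error {
	select {
	case <-conn.HandshakeComplete(): // previously: <-conn.HandshakeComplete().Done()
		// ConnectionState no longer blocks here; it reports, among other
		// things, whether the peer advertised DATAGRAM support.
		fmt.Printf("datagrams negotiated: %v\n", conn.ConnectionState().SupportsDatagrams)
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}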
@@ -758,17 +737,12 @@ func (s *connection) maybeResetTimer() { } } - if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() { - deadline = utils.MinTime(deadline, ackAlarm) - } - if lossTime := s.sentPacketHandler.GetLossDetectionTimeout(); !lossTime.IsZero() { - deadline = utils.MinTime(deadline, lossTime) - } - if !s.pacingDeadline.IsZero() { - deadline = utils.MinTime(deadline, s.pacingDeadline) - } - - s.timer.Reset(deadline) + s.timer.SetTimer( + deadline, + s.receivedPacketHandler.GetAlarmTimeout(), + s.sentPacketHandler.GetLossDetectionTimeout(), + s.pacingDeadline, + ) } func (s *connection) idleTimeoutStartTime() time.Time { @@ -821,7 +795,7 @@ func (s *connection) handleHandshakeConfirmed() { if maxPacketSize == 0 { maxPacketSize = protocol.MaxByteCount } - maxPacketSize = utils.MinByteCount(maxPacketSize, protocol.MaxPacketBufferSize) + maxPacketSize = utils.Min(maxPacketSize, protocol.MaxPacketBufferSize) s.mtuDiscoverer = newMTUDiscoverer( s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), @@ -848,61 +822,133 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { data := rp.data p := rp for len(data) > 0 { + var destConnID protocol.ConnectionID if counter > 0 { p = p.Clone() p.data = data - } - hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen) - if err != nil { - if s.tracer != nil { - dropReason := logging.PacketDropHeaderParseError - if err == wire.ErrUnsupportedVersion { - dropReason = logging.PacketDropUnsupportedVersion + var err error + destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) } - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) + s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err) + break + } + if destConnID != lastConnID { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) + } + s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID) + break } - s.logger.Debugf("error parsing packet: %s", err) - break } - if hdr.IsLongHeader && hdr.Version != s.version { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) + if wire.IsLongHeaderPacket(p.data[0]) { + hdr, packetData, rest, err := wire.ParsePacket(p.data) + if err != nil { + if s.tracer != nil { + dropReason := logging.PacketDropHeaderParseError + if err == wire.ErrUnsupportedVersion { + dropReason = logging.PacketDropUnsupportedVersion + } + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) + } + s.logger.Debugf("error parsing packet: %s", err) + break } - s.logger.Debugf("Dropping packet with version %x. 
Expected %x.", hdr.Version, s.version) - break - } + lastConnID = hdr.DestConnectionID - if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) + if hdr.Version != s.version { + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) + } + s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) + break + } + + if counter > 0 { + p.buffer.Split() } - s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) + counter++ + + // only log if this actually a coalesced packet + if s.logger.Debug() && (counter > 1 || len(rest) > 0) { + s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) + } + + p.data = packetData + + if wasProcessed := s.handleLongHeaderPacket(p, hdr); wasProcessed { + processed = true + } + data = rest + } else { + if counter > 0 { + p.buffer.Split() + } + processed = s.handleShortHeaderPacket(p, destConnID) break } - lastConnID = hdr.DestConnectionID + } - if counter > 0 { - p.buffer.Split() - } - counter++ + p.buffer.MaybeRelease() + return processed +} + +func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool { + var wasQueued bool - // only log if this actually a coalesced packet - if s.logger.Debug() && (counter > 1 || len(rest) > 0) { - s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) + defer func() { + // Put back the packet buffer if the packet wasn't queued for later decryption. 
+ if !wasQueued { + p.buffer.Decrement() } - p.data = packetData - if wasProcessed := s.handleSinglePacket(p, hdr); wasProcessed { - processed = true + }() + + pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data) + if err != nil { + wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT) + return false + } + + if s.logger.Debug() { + s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", pn, p.Size(), destConnID) + wire.LogShortHeader(s.logger, destConnID, pn, pnLen, keyPhase) + } + + if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) { + s.logger.Debugf("Dropping (potentially) duplicate packet.") + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) } - data = rest + return false } - p.buffer.MaybeRelease() - return processed + + var log func([]logging.Frame) + if s.tracer != nil { + log = func(frames []logging.Frame) { + s.tracer.ReceivedShortHeaderPacket( + &logging.ShortHeader{ + DestConnectionID: destConnID, + PacketNumber: pn, + PacketNumberLen: pnLen, + KeyPhase: keyPhase, + }, + p.Size(), + frames, + ) + } + } + if err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log); err != nil { + s.closeLocal(err) + return false + } + return true } -func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { +func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { var wasQueued bool defer func() { @@ -918,7 +964,7 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && !hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) } @@ -933,53 +979,18 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo return false } - packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data) + packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data, s.version) if err != nil { - switch err { - case handshake.ErrKeysDropped: - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropKeyUnavailable) - } - s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", hdr.PacketType(), p.Size()) - case handshake.ErrKeysNotYetAvailable: - // Sealer for this encryption level not yet available. - // Try again later. - wasQueued = true - s.tryQueueingUndecryptablePacket(p, hdr) - case wire.ErrInvalidReservedBits: - s.closeLocal(&qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: err.Error(), - }) - case handshake.ErrDecryptionFailed: - // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropPayloadDecryptError) - } - s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. 
Error: %s", hdr.PacketType(), p.Size(), err) - default: - var headerErr *headerParseError - if errors.As(err, &headerErr) { - // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropHeaderParseError) - } - s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", hdr.PacketType(), p.Size(), err) - } else { - // This is an error returned by the AEAD (other than ErrDecryptionFailed). - // For example, a PROTOCOL_VIOLATION due to key updates. - s.closeLocal(err) - } - } + wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr)) return false } if s.logger.Debug() { - s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, %s", packet.packetNumber, p.Size(), hdr.DestConnectionID, packet.encryptionLevel) + s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, %s", packet.hdr.PacketNumber, p.Size(), hdr.DestConnectionID, packet.encryptionLevel) packet.hdr.Log(s.logger) } - if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.packetNumber, packet.encryptionLevel) { + if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.hdr.PacketNumber, packet.encryptionLevel) { s.logger.Debugf("Dropping (potentially) duplicate packet.") if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) @@ -987,13 +998,53 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo return false } - if err := s.handleUnpackedPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil { + if err := s.handleUnpackedLongHeaderPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil { s.closeLocal(err) return false } return true } +func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) { + switch err { + case handshake.ErrKeysDropped: + if s.tracer != nil { + s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable) + } + s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) + case handshake.ErrKeysNotYetAvailable: + // Sealer for this encryption level not yet available. + // Try again later. + s.tryQueueingUndecryptablePacket(p, pt) + return true + case wire.ErrInvalidReservedBits: + s.closeLocal(&qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: err.Error(), + }) + case handshake.ErrDecryptionFailed: + // This might be a packet injected by an attacker. Drop it. + if s.tracer != nil { + s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError) + } + s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) + default: + var headerErr *headerParseError + if errors.As(err, &headerErr) { + // This might be a packet injected by an attacker. Drop it. + if s.tracer != nil { + s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) + } else { + // This is an error returned by the AEAD (other than ErrDecryptionFailed). + // For example, a PROTOCOL_VIOLATION due to key updates. 
+ s.closeLocal(err) + } + } + return false +} + func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { if s.tracer != nil { @@ -1010,7 +1061,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa return false } destConnID := s.connIDManager.Get() - if hdr.SrcConnectionID.Equal(destConnID) { + if hdr.SrcConnectionID == destConnID { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } @@ -1065,7 +1116,7 @@ func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { return } - hdr, supportedVersions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(p.data)) + src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) if err != nil { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) @@ -1087,7 +1138,7 @@ func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) if s.tracer != nil { - s.tracer.ReceivedVersionNegotiationPacket(hdr, supportedVersions) + s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions) } newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) if !ok { @@ -1110,19 +1161,12 @@ func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { }) } -func (s *connection) handleUnpackedPacket( +func (s *connection) handleUnpackedLongHeaderPacket( packet *unpackedPacket, ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount, // only for logging ) error { - if len(packet.data) == 0 { - return &qerr.TransportError{ - ErrorCode: qerr.ProtocolViolation, - ErrorMessage: "empty packet", - } - } - if !s.receivedFirstPacket { s.receivedFirstPacket = true if !s.versionNegotiated && s.tracer != nil { @@ -1136,7 +1180,7 @@ func (s *connection) handleUnpackedPacket( s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions) } // The server can change the source connection ID with the first Handshake packet. - if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if s.perspective == protocol.PerspectiveClient && packet.hdr.SrcConnectionID != s.handshakeDestConnID { cid := packet.hdr.SrcConnectionID s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid) s.handshakeDestConnID = cid @@ -1145,10 +1189,10 @@ func (s *connection) handleUnpackedPacket( // We create the connection as soon as we receive the first packet from the client. // We do that before authenticating the packet. // That means that if the source connection ID was corrupted, - // we might have create a connection with an incorrect source connection ID. + // we might have created a connection with an incorrect source connection ID. // Once we authenticate the first packet, we need to update it. 
if s.perspective == protocol.PerspectiveServer { - if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if packet.hdr.SrcConnectionID != s.handshakeDestConnID { s.handshakeDestConnID = packet.hdr.SrcConnectionID s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) } @@ -1167,16 +1211,53 @@ func (s *connection) handleUnpackedPacket( s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.keepAlivePingSent = false + var log func([]logging.Frame) + if s.tracer != nil { + log = func(frames []logging.Frame) { + s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames) + } + } + isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log) + if err != nil { + return err + } + return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) +} + +func (s *connection) handleUnpackedShortHeaderPacket( + destConnID protocol.ConnectionID, + pn protocol.PacketNumber, + data []byte, + ecn protocol.ECN, + rcvTime time.Time, + log func([]logging.Frame), +) error { + s.lastPacketReceivedTime = rcvTime + s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} + s.keepAlivePingSent = false + + isAckEliciting, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log) + if err != nil { + return err + } + return s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting) +} + +func (s *connection) handleFrames( + data []byte, + destConnID protocol.ConnectionID, + encLevel protocol.EncryptionLevel, + log func([]logging.Frame), +) (isAckEliciting bool, _ error) { // Only used for tracing. // If we're not tracing, this slice will always remain empty. var frames []wire.Frame - r := bytes.NewReader(packet.data) - var isAckEliciting bool - for { - frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel) + for len(data) > 0 { + l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version) if err != nil { - return err + return false, err } + data = data[l:] if frame == nil { break } @@ -1185,29 +1266,28 @@ func (s *connection) handleUnpackedPacket( } // Only process frames now if we're not logging. // If we're logging, we need to make sure that the packet_received event is logged first. 
- if s.tracer == nil { - if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { - return err + if log == nil { + if err := s.handleFrame(frame, encLevel, destConnID); err != nil { + return false, err } } else { frames = append(frames, frame) } } - if s.tracer != nil { + if log != nil { fs := make([]logging.Frame, len(frames)) for i, frame := range frames { fs[i] = logutils.ConvertFrame(frame) } - s.tracer.ReceivedPacket(packet.hdr, packetSize, fs) + log(fs) for _, frame := range frames { - if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil { - return err + if err := s.handleFrame(frame, encLevel, destConnID); err != nil { + return false, err } } } - - return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) + return } func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { @@ -1220,6 +1300,7 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel err = s.handleStreamFrame(frame) case *wire.AckFrame: err = s.handleAckFrame(frame, encLevel) + wire.PutAckFrame(frame) case *wire.ConnectionCloseFrame: s.handleConnectionCloseFrame(frame) case *wire.ResetStreamFrame: @@ -1518,7 +1599,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { // If this is a remote close we're done here if closeErr.remote { - s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective)) + s.connIDGenerator.ReplaceWithClosed(s.perspective, nil) return } if closeErr.immediate { @@ -1535,8 +1616,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { if err != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } - cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger) - s.connIDGenerator.ReplaceWithClosed(cs) + s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket) } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { @@ -1566,6 +1646,9 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit) s.connFlowController.UpdateSendWindow(params.InitialMaxData) s.streamsMap.UpdateLimits(params) + s.connStateMutex.Lock() + s.connState.SupportsDatagrams = s.supportsDatagrams() + s.connStateMutex.Unlock() } func (s *connection) handleTransportParameters(params *wire.TransportParameters) { @@ -1574,6 +1657,7 @@ func (s *connection) handleTransportParameters(params *wire.TransportParameters) ErrorCode: qerr.TransportParameterError, ErrorMessage: err.Error(), }) + return } s.peerParams = params // On the client side we have to wait for handshake completion. @@ -1584,6 +1668,10 @@ func (s *connection) handleTransportParameters(params *wire.TransportParameters) // the client's transport parameters. 
close(s.earlyConnReadyChan) } + + s.connStateMutex.Lock() + s.connState.SupportsDatagrams = s.supportsDatagrams() + s.connStateMutex.Unlock() } func (s *connection) checkTransportParameters(params *wire.TransportParameters) error { @@ -1595,7 +1683,7 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters) } // check the initial_source_connection_id - if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) { + if params.InitialSourceConnectionID != s.handshakeDestConnID { return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID) } @@ -1603,14 +1691,14 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters) return nil } // check the original_destination_connection_id - if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { + if params.OriginalDestinationConnectionID != s.origDestConnID { return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID) } if s.retrySrcConnID != nil { // a Retry was performed if params.RetrySourceConnectionID == nil { return errors.New("missing retry_source_connection_id") } - if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { + if *params.RetrySourceConnectionID != *s.retrySrcConnID { return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID) } } else if params.RetrySourceConnectionID != nil { @@ -1623,7 +1711,7 @@ func (s *connection) applyTransportParameters() { params := s.peerParams // Our local idle timeout will always be > 0. s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) - s.keepAliveInterval = utils.MinDuration(s.config.KeepAlivePeriod, utils.MinDuration(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) + s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) s.streamsMap.UpdateLimits(params) s.packer.HandleTransportParameters(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) @@ -1672,7 +1760,7 @@ func (s *connection) sendPackets() error { } // We can at most send a single ACK only packet. // There will only be a new ACK after receiving new packets. - // SendAck is only returned when we're congestion limited, so we don't need to set the pacingt timer. + // SendAck is only returned when we're congestion limited, so we don't need to set the pacinggs timer. 
return s.maybeSendAckOnlyPacket() case ackhandler.SendPTOInitial: if err := s.sendProbePacket(protocol.EncryptionInitial); err != nil { @@ -1707,27 +1795,41 @@ func (s *connection) sendPackets() error { } func (s *connection) maybeSendAckOnlyPacket() error { - packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed) + if !s.handshakeConfirmed { + packet, err := s.packer.PackCoalescedPacket(true, s.version) + if err != nil { + return err + } + if packet == nil { + return nil + } + s.sendPackedCoalescedPacket(packet, time.Now()) + return nil + } + + now := time.Now() + p, buffer, err := s.packer.PackPacket(true, now, s.version) if err != nil { + if err == errNothingToPack { + return nil + } return err } - if packet == nil { - return nil - } - s.sendPackedPacket(packet, time.Now()) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return nil } func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { // Queue probe packets until we actually send out a packet, // or until there are no more packets to queue. - var packet *packedPacket + var packet *coalescedPacket for { if wasQueued := s.sentPacketHandler.QueueProbePacket(encLevel); !wasQueued { break } var err error - packet, err = s.packer.MaybePackProbePacket(encLevel) + packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) if err != nil { return err } @@ -1748,15 +1850,15 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { panic("unexpected encryption level") } var err error - packet, err = s.packer.MaybePackProbePacket(encLevel) + packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) if err != nil { return err } } - if packet == nil || packet.packetContents == nil { + if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } - s.sendPackedPacket(packet, time.Now()) + s.sendPackedCoalescedPacket(packet, time.Now()) return nil } @@ -1768,44 +1870,59 @@ func (s *connection) sendPacket() (bool, error) { now := time.Now() if !s.handshakeConfirmed { - packet, err := s.packer.PackCoalescedPacket() + packet, err := s.packer.PackCoalescedPacket(false, s.version) if err != nil || packet == nil { return false, err } s.sentFirstPacket = true - s.logCoalescedPacket(packet) - for _, p := range packet.packets { - if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { - s.firstAckElicitingPacketAfterIdleSentTime = now - } - s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) - } - s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer) + s.sendPackedCoalescedPacket(packet, now) return true, nil - } - if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) { - packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing()) + } else if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) { + ping, size := s.mtuDiscoverer.GetPing() + p, buffer, err := s.packer.PackMTUProbePacket(ping, size, now, s.version) if err != nil { return false, err } - s.sendPackedPacket(packet, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return true, nil } - packet, err := s.packer.PackPacket() - if err != nil || packet == 
nil { + p, buffer, err := s.packer.PackPacket(false, now, s.version) + if err != nil { + if err == errNothingToPack { + return false, nil + } return false, err } - s.sendPackedPacket(packet, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return true, nil } -func (s *connection) sendPackedPacket(packet *packedPacket, now time.Time) { - if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() { +func (s *connection) sendPackedShortHeaderPacket(buffer *packetBuffer, p *ackhandler.Packet, now time.Time) { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && ackhandler.HasAckElicitingFrames(p.Frames) { s.firstAckElicitingPacketAfterIdleSentTime = now } - s.logPacket(packet) - s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(now, s.retransmissionQueue)) + + s.sentPacketHandler.SentPacket(p) + s.connIDManager.SentPacket() + s.sendQueue.Send(buffer) +} + +func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) { + s.logCoalescedPacket(packet) + for _, p := range packet.longHdrPackets { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { + s.firstAckElicitingPacketAfterIdleSentTime = now + } + s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) + } + if p := packet.shortHdrPacket; p != nil { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { + s.firstAckElicitingPacketAfterIdleSentTime = now + } + s.sentPacketHandler.SentPacket(p.Packet) + } s.connIDManager.SentPacket() s.sendQueue.Send(packet.buffer) } @@ -1816,14 +1933,14 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { var transportErr *qerr.TransportError var applicationErr *qerr.ApplicationError if errors.As(e, &transportErr) { - packet, err = s.packer.PackConnectionClose(transportErr) + packet, err = s.packer.PackConnectionClose(transportErr, s.version) } else if errors.As(e, &applicationErr) { - packet, err = s.packer.PackApplicationClose(applicationErr) + packet, err = s.packer.PackApplicationClose(applicationErr, s.version) } else { packet, err = s.packer.PackConnectionClose(&qerr.TransportError{ ErrorCode: qerr.InternalError, ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()), - }) + }, s.version) } if err != nil { return nil, err @@ -1832,47 +1949,109 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { return packet.buffer.Data, s.conn.Write(packet.buffer.Data) } -func (s *connection) logPacketContents(p *packetContents) { +func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { + // quic-go logging + if s.logger.Debug() { + p.header.Log(s.logger) + if p.ack != nil { + wire.LogFrame(s.logger, p.ack, true) + } + for _, frame := range p.frames { + wire.LogFrame(s.logger, frame.Frame, true) + } + } + // tracing if s.tracer != nil { frames := make([]logging.Frame, 0, len(p.frames)) for _, f := range p.frames { frames = append(frames, logutils.ConvertFrame(f.Frame)) } - s.tracer.SentPacket(p.header, p.length, p.ack, frames) + var ack *logging.AckFrame + if p.ack != nil { + ack = logutils.ConvertAckFrame(p.ack) + } + s.tracer.SentLongHeaderPacket(p.header, p.length, ack, frames) } +} - // quic-go logging - if !s.logger.Debug() { - return +func (s *connection) logShortHeaderPacket( + destConnID protocol.ConnectionID, + ackFrame *wire.AckFrame, + frames 
[]*ackhandler.Frame, + pn protocol.PacketNumber, + pnLen protocol.PacketNumberLen, + kp protocol.KeyPhaseBit, + size protocol.ByteCount, + isCoalesced bool, +) { + if s.logger.Debug() && !isCoalesced { + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT", pn, size, s.logID) } - p.header.Log(s.logger) - if p.ack != nil { - wire.LogFrame(s.logger, p.ack, true) + // quic-go logging + if s.logger.Debug() { + wire.LogShortHeader(s.logger, destConnID, pn, pnLen, kp) + if ackFrame != nil { + wire.LogFrame(s.logger, ackFrame, true) + } + for _, frame := range frames { + wire.LogFrame(s.logger, frame.Frame, true) + } } - for _, frame := range p.frames { - wire.LogFrame(s.logger, frame.Frame, true) + + // tracing + if s.tracer != nil { + fs := make([]logging.Frame, 0, len(frames)) + for _, f := range frames { + fs = append(fs, logutils.ConvertFrame(f.Frame)) + } + var ack *logging.AckFrame + if ackFrame != nil { + ack = logutils.ConvertAckFrame(ackFrame) + } + s.tracer.SentShortHeaderPacket( + &logging.ShortHeader{ + DestConnectionID: destConnID, + PacketNumber: pn, + PacketNumberLen: pnLen, + KeyPhase: kp, + }, + size, + ack, + fs, + ) } } func (s *connection) logCoalescedPacket(packet *coalescedPacket) { if s.logger.Debug() { - if len(packet.packets) > 1 { - s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.packets), packet.buffer.Len(), s.logID) + // There's a short period between dropping both Initial and Handshake keys and completion of the handshake, + // during which we might call PackCoalescedPacket but just pack a short header packet. + if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil { + s.logShortHeaderPacket( + packet.shortHdrPacket.DestConnID, + packet.shortHdrPacket.Ack, + packet.shortHdrPacket.Frames, + packet.shortHdrPacket.PacketNumber, + packet.shortHdrPacket.PacketNumberLen, + packet.shortHdrPacket.KeyPhase, + packet.shortHdrPacket.Length, + false, + ) + return + } + if len(packet.longHdrPackets) > 1 { + s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), s.logID) } else { - s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.packets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.packets[0].EncryptionLevel()) + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.longHdrPackets[0].EncryptionLevel()) } } - for _, p := range packet.packets { - s.logPacketContents(p) + for _, p := range packet.longHdrPackets { + s.logLongHeaderPacket(p) } -} - -func (s *connection) logPacket(packet *packedPacket) { - if s.logger.Debug() { - s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.header.PacketNumber, packet.buffer.Len(), s.logID, packet.EncryptionLevel()) + if p := packet.shortHdrPacket; p != nil { + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) } - s.logPacketContents(packet.packetContents) } // AcceptStream returns the next stream openend by the peer @@ -1930,20 +2109,22 @@ func (s *connection) scheduleSending() { } } -func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, hdr *wire.Header) { +// tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys. +// The logging.PacketType is only used for logging purposes. 
+func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, pt logging.PacketType) { if s.handshakeComplete { panic("shouldn't queue undecryptable packets after handshake completion") } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { if s.tracer != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDOSPrevention) + s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropDOSPrevention) } s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) return } s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) if s.tracer != nil { - s.tracer.BufferedPacket(logging.PacketTypeFromHeader(hdr)) + s.tracer.BufferedPacket(pt, p.Size()) } s.undecryptablePackets = append(s.undecryptablePackets, p) } @@ -1975,6 +2156,10 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) { } func (s *connection) SendMessage(p []byte) error { + if !s.supportsDatagrams() { + return errors.New("datagram support disabled") + } + f := &wire.DatagramFrame{DataLenPresent: true} if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { return errors.New("message too large") @@ -1985,6 +2170,9 @@ func (s *connection) SendMessage(p []byte) error { } func (s *connection) ReceiveMessage() ([]byte, error) { + if !s.config.EnableDatagrams { + return nil, errors.New("datagram support disabled") + } return s.datagramQueue.Receive() } @@ -2005,7 +2193,7 @@ func (s *connection) GetVersion() protocol.VersionNumber { } func (s *connection) NextConnection() Connection { - <-s.HandshakeComplete().Done() + <-s.HandshakeComplete() s.streamsMap.UseResetMaps() return s } diff --git a/vendor/github.com/quic-go/quic-go/connection_timer.go b/vendor/github.com/quic-go/quic-go/connection_timer.go new file mode 100644 index 00000000000..171fdd01384 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/connection_timer.go @@ -0,0 +1,51 @@ +package quic + +import ( + "time" + + "github.com/quic-go/quic-go/internal/utils" +) + +var deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine + +type connectionTimer struct { + timer *utils.Timer + last time.Time +} + +func newTimer() *connectionTimer { + return &connectionTimer{timer: utils.NewTimer()} +} + +func (t *connectionTimer) SetRead() { + if deadline := t.timer.Deadline(); deadline != deadlineSendImmediately { + t.last = deadline + } + t.timer.SetRead() +} + +func (t *connectionTimer) Chan() <-chan time.Time { + return t.timer.Chan() +} + +// SetTimer resets the timer. +// It makes sure that the deadline is strictly increasing. +// This prevents busy-looping in cases where the timer fires, but we can't actually send out a packet. +// This doesn't apply to the pacing deadline, which can be set multiple times to deadlineSendImmediately. 
+func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, ackAlarm, lossTime, pacing time.Time) { + deadline := idleTimeoutOrKeepAlive + if !ackAlarm.IsZero() && ackAlarm.Before(deadline) && ackAlarm.After(t.last) { + deadline = ackAlarm + } + if !lossTime.IsZero() && lossTime.Before(deadline) && lossTime.After(t.last) { + deadline = lossTime + } + if !pacing.IsZero() && pacing.Before(deadline) { + deadline = pacing + } + t.timer.Reset(deadline) +} + +func (t *connectionTimer) Stop() { + t.timer.Stop() +} diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go b/vendor/github.com/quic-go/quic-go/crypto_stream.go similarity index 88% rename from vendor/github.com/lucas-clemente/quic-go/crypto_stream.go rename to vendor/github.com/quic-go/quic-go/crypto_stream.go index 36e21d3301e..f10e91202fa 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto_stream.go +++ b/vendor/github.com/quic-go/quic-go/crypto_stream.go @@ -4,10 +4,10 @@ import ( "fmt" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" ) type cryptoStream interface { @@ -56,7 +56,7 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // could e.g. be a retransmission return nil } - s.highestOffset = utils.MaxByteCount(s.highestOffset, highestOffset) + s.highestOffset = utils.Max(s.highestOffset, highestOffset) if err := s.queue.Push(f.Data, f.Offset, nil); err != nil { return err } @@ -107,7 +107,7 @@ func (s *cryptoStreamImpl) HasData() bool { func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { f := &wire.CryptoFrame{Offset: s.writeOffset} - n := utils.MinByteCount(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) + n := utils.Min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) f.Data = s.writeBuf[:n] s.writeBuf = s.writeBuf[n:] s.writeOffset += n diff --git a/vendor/github.com/lucas-clemente/quic-go/crypto_stream_manager.go b/vendor/github.com/quic-go/quic-go/crypto_stream_manager.go similarity index 93% rename from vendor/github.com/lucas-clemente/quic-go/crypto_stream_manager.go rename to vendor/github.com/quic-go/quic-go/crypto_stream_manager.go index 66f9004905e..91946acfa52 100644 --- a/vendor/github.com/lucas-clemente/quic-go/crypto_stream_manager.go +++ b/vendor/github.com/quic-go/quic-go/crypto_stream_manager.go @@ -3,8 +3,8 @@ package quic import ( "fmt" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" ) type cryptoDataHandler interface { diff --git a/vendor/github.com/quic-go/quic-go/datagram_queue.go b/vendor/github.com/quic-go/quic-go/datagram_queue.go new file mode 100644 index 00000000000..59c7d069b0f --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/datagram_queue.go @@ -0,0 +1,123 @@ +package quic + +import ( + "sync" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" +) + +type datagramQueue struct { + sendQueue chan *wire.DatagramFrame + nextFrame *wire.DatagramFrame + + rcvMx sync.Mutex + rcvQueue [][]byte + rcvd chan 
struct{} // used to notify Receive that a new datagram was received + + closeErr error + closed chan struct{} + + hasData func() + + dequeued chan struct{} + + logger utils.Logger +} + +func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { + return &datagramQueue{ + hasData: hasData, + sendQueue: make(chan *wire.DatagramFrame, 1), + rcvd: make(chan struct{}, 1), + dequeued: make(chan struct{}), + closed: make(chan struct{}), + logger: logger, + } +} + +// AddAndWait queues a new DATAGRAM frame for sending. +// It blocks until the frame has been dequeued. +func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { + select { + case h.sendQueue <- f: + h.hasData() + case <-h.closed: + return h.closeErr + } + + select { + case <-h.dequeued: + return nil + case <-h.closed: + return h.closeErr + } +} + +// Peek gets the next DATAGRAM frame for sending. +// If actually sent out, Pop needs to be called before the next call to Peek. +func (h *datagramQueue) Peek() *wire.DatagramFrame { + if h.nextFrame != nil { + return h.nextFrame + } + select { + case h.nextFrame = <-h.sendQueue: + h.dequeued <- struct{}{} + default: + return nil + } + return h.nextFrame +} + +func (h *datagramQueue) Pop() { + if h.nextFrame == nil { + panic("datagramQueue BUG: Pop called for nil frame") + } + h.nextFrame = nil +} + +// HandleDatagramFrame handles a received DATAGRAM frame. +func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { + data := make([]byte, len(f.Data)) + copy(data, f.Data) + var queued bool + h.rcvMx.Lock() + if len(h.rcvQueue) < protocol.DatagramRcvQueueLen { + h.rcvQueue = append(h.rcvQueue, data) + queued = true + select { + case h.rcvd <- struct{}{}: + default: + } + } + h.rcvMx.Unlock() + if !queued && h.logger.Debug() { + h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) + } +} + +// Receive gets a received DATAGRAM frame. 
+func (h *datagramQueue) Receive() ([]byte, error) { + for { + h.rcvMx.Lock() + if len(h.rcvQueue) > 0 { + data := h.rcvQueue[0] + h.rcvQueue = h.rcvQueue[1:] + h.rcvMx.Unlock() + return data, nil + } + h.rcvMx.Unlock() + select { + case <-h.rcvd: + continue + case <-h.closed: + return nil, h.closeErr + } + } +} + +func (h *datagramQueue) CloseWithError(e error) { + h.closeErr = e + close(h.closed) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/errors.go b/vendor/github.com/quic-go/quic-go/errors.go similarity index 89% rename from vendor/github.com/lucas-clemente/quic-go/errors.go rename to vendor/github.com/quic-go/quic-go/errors.go index 0c9f0004e54..c9fb0a07b08 100644 --- a/vendor/github.com/lucas-clemente/quic-go/errors.go +++ b/vendor/github.com/quic-go/quic-go/errors.go @@ -3,7 +3,7 @@ package quic import ( "fmt" - "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/qerr" ) type ( @@ -46,6 +46,7 @@ const ( type StreamError struct { StreamID StreamID ErrorCode StreamErrorCode + Remote bool } func (e *StreamError) Is(target error) bool { @@ -54,5 +55,9 @@ func (e *StreamError) Is(target error) bool { } func (e *StreamError) Error() string { - return fmt.Sprintf("stream %d canceled with error code %d", e.StreamID, e.ErrorCode) + pers := "local" + if e.Remote { + pers = "remote" + } + return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode) } diff --git a/vendor/github.com/lucas-clemente/quic-go/frame_sorter.go b/vendor/github.com/quic-go/quic-go/frame_sorter.go similarity index 86% rename from vendor/github.com/lucas-clemente/quic-go/frame_sorter.go rename to vendor/github.com/quic-go/quic-go/frame_sorter.go index aeafa7d4209..bee0abadb53 100644 --- a/vendor/github.com/lucas-clemente/quic-go/frame_sorter.go +++ b/vendor/github.com/quic-go/quic-go/frame_sorter.go @@ -2,11 +2,24 @@ package quic import ( "errors" + "sync" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + list "github.com/quic-go/quic-go/internal/utils/linkedlist" ) +// byteInterval is an interval from one ByteCount to the other +type byteInterval struct { + Start protocol.ByteCount + End protocol.ByteCount +} + +var byteIntervalElementPool sync.Pool + +func init() { + byteIntervalElementPool = *list.NewPool[byteInterval]() +} + type frameSorterEntry struct { Data []byte DoneCb func() @@ -15,17 +28,17 @@ type frameSorterEntry struct { type frameSorter struct { queue map[protocol.ByteCount]frameSorterEntry readPos protocol.ByteCount - gaps *utils.ByteIntervalList + gaps *list.List[byteInterval] } var errDuplicateStreamData = errors.New("duplicate stream data") func newFrameSorter() *frameSorter { s := frameSorter{ - gaps: utils.NewByteIntervalList(), + gaps: list.NewWithPool[byteInterval](&byteIntervalElementPool), queue: make(map[protocol.ByteCount]frameSorterEntry), } - s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount}) + s.gaps.PushFront(byteInterval{Start: 0, End: protocol.MaxByteCount}) return &s } @@ -118,7 +131,7 @@ func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func() if !startGapEqualsEndGap { s.deleteConsecutive(startGapEnd) - var nextGap *utils.ByteIntervalElement + var nextGap *list.Element[byteInterval] for gap := startGapNext; gap.Value.End < endGapStart; gap = nextGap { nextGap = gap.Next() s.deleteConsecutive(gap.Value.End) @@ -140,7 +153,7 @@ func (s *frameSorter) 
push(data []byte, offset protocol.ByteCount, doneCb func() } else { if startGapEqualsEndGap && adjustedStartGapEnd { // The frame split the existing gap into two. - s.gaps.InsertAfter(utils.ByteInterval{Start: end, End: startGapEnd}, startGap) + s.gaps.InsertAfter(byteInterval{Start: end, End: startGapEnd}, startGap) } else if !startGapEqualsEndGap { endGap.Value.Start = end } @@ -164,7 +177,7 @@ func (s *frameSorter) push(data []byte, offset protocol.ByteCount, doneCb func() return nil } -func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) { +func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*list.Element[byteInterval], bool) { for gap := s.gaps.Front(); gap != nil; gap = gap.Next() { if offset >= gap.Value.Start && offset <= gap.Value.End { return gap, true @@ -176,7 +189,7 @@ func (s *frameSorter) findStartGap(offset protocol.ByteCount) (*utils.ByteInterv panic("no gap found") } -func (s *frameSorter) findEndGap(startGap *utils.ByteIntervalElement, offset protocol.ByteCount) (*utils.ByteIntervalElement, bool) { +func (s *frameSorter) findEndGap(startGap *list.Element[byteInterval], offset protocol.ByteCount) (*list.Element[byteInterval], bool) { for gap := startGap; gap != nil; gap = gap.Next() { if offset >= gap.Value.Start && offset < gap.Value.End { return gap, true diff --git a/vendor/github.com/lucas-clemente/quic-go/framer.go b/vendor/github.com/quic-go/quic-go/framer.go similarity index 77% rename from vendor/github.com/lucas-clemente/quic-go/framer.go rename to vendor/github.com/quic-go/quic-go/framer.go index 29d36b85c2a..0b2059164ab 100644 --- a/vendor/github.com/lucas-clemente/quic-go/framer.go +++ b/vendor/github.com/quic-go/quic-go/framer.go @@ -4,20 +4,20 @@ import ( "errors" "sync" - "github.com/lucas-clemente/quic-go/internal/ackhandler" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/quicvarint" ) type framer interface { HasData() bool QueueControlFrame(wire.Frame) - AppendControlFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) + AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AddActiveStream(protocol.StreamID) - AppendStreamFrames([]ackhandler.Frame, protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) + AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) Handle0RTTRejection() error } @@ -26,7 +26,6 @@ type framerI struct { mutex sync.Mutex streamGetter streamGetter - version protocol.VersionNumber activeStreams map[protocol.StreamID]struct{} streamQueue []protocol.StreamID @@ -37,14 +36,10 @@ type framerI struct { var _ framer = &framerI{} -func newFramer( - streamGetter streamGetter, - v protocol.VersionNumber, -) framer { +func newFramer(streamGetter streamGetter) framer { return &framerI{ streamGetter: streamGetter, activeStreams: make(map[protocol.StreamID]struct{}), - version: v, } } @@ -67,16 +62,18 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) { f.controlFrameMutex.Unlock() } -func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { +func (f 
*framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { var length protocol.ByteCount f.controlFrameMutex.Lock() for len(f.controlFrames) > 0 { frame := f.controlFrames[len(f.controlFrames)-1] - frameLen := frame.Length(f.version) + frameLen := frame.Length(v) if length+frameLen > maxLen { break } - frames = append(frames, ackhandler.Frame{Frame: frame}) + af := ackhandler.GetFrame() + af.Frame = frame + frames = append(frames, af) length += frameLen f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] } @@ -93,7 +90,7 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Unlock() } -func (f *framerI) AppendStreamFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { +func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { var length protocol.ByteCount var lastFrame *ackhandler.Frame f.mutex.Lock() @@ -118,7 +115,7 @@ func (f *framerI) AppendStreamFrames(frames []ackhandler.Frame, maxLen protocol. // Therefore, we can pretend to have more bytes available when popping // the STREAM frame (which will always have the DataLen set). remainingLen += quicvarint.Len(uint64(remainingLen)) - frame, hasMoreData := str.popStreamFrame(remainingLen) + frame, hasMoreData := str.popStreamFrame(remainingLen, v) if hasMoreData { // put the stream back in the queue (at the end) f.streamQueue = append(f.streamQueue, id) } else { // no more data to send. Stream is not active any more @@ -130,16 +127,16 @@ func (f *framerI) AppendStreamFrames(frames []ackhandler.Frame, maxLen protocol. if frame == nil { continue } - frames = append(frames, *frame) - length += frame.Length(f.version) + frames = append(frames, frame) + length += frame.Length(v) lastFrame = frame } f.mutex.Unlock() if lastFrame != nil { - lastFrameLen := lastFrame.Length(f.version) + lastFrameLen := lastFrame.Length(v) // account for the smaller size of the last STREAM frame lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false - length += lastFrame.Length(f.version) - lastFrameLen + length += lastFrame.Length(v) - lastFrameLen } return frames, length } diff --git a/vendor/github.com/lucas-clemente/quic-go/interface.go b/vendor/github.com/quic-go/quic-go/interface.go similarity index 78% rename from vendor/github.com/lucas-clemente/quic-go/interface.go rename to vendor/github.com/quic-go/quic-go/interface.go index 4ca98609acf..ed6af8ae179 100644 --- a/vendor/github.com/lucas-clemente/quic-go/interface.go +++ b/vendor/github.com/quic-go/quic-go/interface.go @@ -7,9 +7,9 @@ import ( "net" "time" - "github.com/lucas-clemente/quic-go/internal/handshake" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/logging" + "github.com/quic-go/quic-go/internal/handshake" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/logging" ) // The StreamID is the ID of a QUIC stream. @@ -26,16 +26,6 @@ const ( Version2 = protocol.Version2 ) -// A Token can be used to verify the ownership of the client address. -type Token struct { - // IsRetryToken encodes how the client received the token. There are two ways: - // * In a Retry packet sent when trying to establish a new connection. - // * In a NEW_TOKEN frame on a previous connection. 
- IsRetryToken bool - RemoteAddr string - SentTime time.Time -} - // A ClientToken is a token received by the client. // It can be used to skip address validation on future connection attempts. type ClientToken struct { @@ -67,6 +57,10 @@ var ConnectionTracingKey = connTracingCtxKey{} type connTracingCtxKey struct{} +// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the +// context returned by tls.Config.ClientHelloInfo.Context. +var QUICVersionContextKey = handshake.QUICVersionContextKey + // Stream is the interface implemented by QUIC streams // In addition to the errors listed on the Connection, // calls to stream functions can return a StreamError if the stream is canceled. @@ -183,10 +177,9 @@ type Connection interface { // CloseWithError closes the connection with an error. // The error string will be sent to the peer. CloseWithError(ApplicationErrorCode, string) error - // The context is cancelled when the connection is closed. + // Context returns a context that is cancelled when the connection is closed. Context() context.Context // ConnectionState returns basic details about the QUIC connection. - // It blocks until the handshake completes. // Warning: This API should not be considered stable and might change soon. ConnectionState() ConnectionState @@ -204,42 +197,54 @@ type EarlyConnection interface { Connection // HandshakeComplete blocks until the handshake completes (or fails). - // Data sent before completion of the handshake is encrypted with 1-RTT keys. - // Note that the client's identity hasn't been verified yet. - HandshakeComplete() context.Context + // For the client, data sent before completion of the handshake is encrypted with 0-RTT keys. + // For the serfer, data sent before completion of the handshake is encrypted with 1-RTT keys, + // however the client's identity is only verified once the handshake completes. + HandshakeComplete() <-chan struct{} NextConnection() Connection } -// A ConnectionIDGenerator is a interface that allows clients to implement their own format +// StatelessResetKey is a key used to derive stateless reset tokens. +type StatelessResetKey [32]byte + +// A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. +// It is not able to handle QUIC Connection IDs longer than 20 bytes, +// as they are allowed by RFC 8999. +type ConnectionID = protocol.ConnectionID + +// ConnectionIDFromBytes interprets b as a Connection ID. It panics if b is +// longer than 20 bytes. +func ConnectionIDFromBytes(b []byte) ConnectionID { + return protocol.ParseConnectionID(b) +} + +// A ConnectionIDGenerator is an interface that allows clients to implement their own format // for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets. +// +// Connection IDs generated by an implementation should always produce IDs of constant size. type ConnectionIDGenerator interface { - // GenerateConnectionID generates a new ConnectionID. - GenerateConnectionID() ([]byte, error) + // Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs. + GenerateConnectionID() (ConnectionID, error) // ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of // this interface. + // Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size + // connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID. 
+ // 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections + // in the presence of a connection migration environment. ConnectionIDLen() int } // Config contains all configuration data needed for a QUIC server or client. type Config struct { + // GetConfigForClient is called for incoming connections. + // If the error is not nil, the connection attempt is refused. + GetConfigForClient func(info *ClientHelloInfo) (*Config, error) // The QUIC versions that can be negotiated. // If not set, it uses all versions available. Versions []VersionNumber - // The length of the connection ID in bytes. - // It can be 0, or any value between 4 and 18. - // If not set, the interpretation depends on where the Config is used: - // If used for dialing an address, a 0 byte connection ID will be used. - // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. - // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. - ConnectionIDLength int - // An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection. - // The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers. - // By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used. - // Otherwise, if one is provided, then ConnectionIDLength is ignored. - ConnectionIDGenerator ConnectionIDGenerator // HandshakeIdleTimeout is the idle timeout before completion of the handshake. // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. // If this value is zero, the timeout is set to 5 seconds. @@ -250,14 +255,18 @@ type Config struct { // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 30 seconds. MaxIdleTimeout time.Duration - // AcceptToken determines if a Token is accepted. - // It is called with token = nil if the client didn't send a token. - // If not set, a default verification function is used: - // * it verifies that the address matches, and - // * if the token is a retry token, that it was issued within the last 5 seconds - // * else, that it was issued within the last 24 hours. - // This option is only valid for the server. - AcceptToken func(clientAddr net.Addr, token *Token) bool + // RequireAddressValidation determines if a QUIC Retry packet is sent. + // This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT. + // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. + // If not set, every client is forced to prove its remote address. + RequireAddressValidation func(net.Addr) bool + // MaxRetryTokenAge is the maximum age of a Retry token. + // If not set, it defaults to 5 seconds. Only valid for a server. + MaxRetryTokenAge time.Duration + // MaxTokenAge is the maximum age of the token presented during the handshake, + // for tokens that were issued on a previous connection. + // If not set, it defaults to 24 hours. Only valid for a server. + MaxTokenAge time.Duration // The TokenStore stores tokens received from the server. // Tokens are used to skip address validation on future connection attempts. // The key used to store tokens is the ServerName from the tls.Config, if set @@ -285,7 +294,7 @@ type Config struct { // limit the memory usage. 
// To avoid deadlocks, it is not valid to call other functions on the connection or on streams // in this callback. - AllowConnectionWindowIncrease func(sess Connection, delta uint64) bool + AllowConnectionWindowIncrease func(conn Connection, delta uint64) bool // MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open. // Values above 2^60 are invalid. // If not set, it will default to 100. @@ -296,9 +305,6 @@ type Config struct { // If not set, it will default to 100. // If set to a negative value, it doesn't allow any unidirectional streams. MaxIncomingUniStreams int64 - // The StatelessResetKey is used to generate stateless reset tokens. - // If no key is configured, sending of stateless resets is disabled. - StatelessResetKey []byte // KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive. // If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most // every half of MaxIdleTimeout, whichever is smaller). @@ -311,36 +317,23 @@ type Config struct { // This can be useful if version information is exchanged out-of-band. // It has no effect for a client. DisableVersionNegotiationPackets bool - // See https://datatracker.ietf.org/doc/draft-ietf-quic-datagram/. - // Datagrams will only be available when both peers enable datagram support. - EnableDatagrams bool + // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. + // Only valid for the server. + Allow0RTT bool + // Enable QUIC datagram support (RFC 9221). + EnableDatagrams bool + // Maximum size of QUIC datagram frames (RFC 9221). MaxDatagramFrameSize int64 - Tracer logging.Tracer + Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer +} + +type ClientHelloInfo struct { + RemoteAddr net.Addr } // ConnectionState records basic details about a QUIC connection type ConnectionState struct { TLS handshake.ConnectionState SupportsDatagrams bool -} - -// A Listener for incoming QUIC connections -type Listener interface { - // Close the server. All active connections will be closed. - Close() error - // Addr returns the local network addr that the server is listening on. - Addr() net.Addr - // Accept returns new connections. It should be called in a loop. - Accept(context.Context) (Connection, error) -} - -// An EarlyListener listens for incoming QUIC connections, -// and returns them before the handshake completes. -type EarlyListener interface { - // Close the server. All active connections will be closed. - Close() error - // Addr returns the local network addr that the server is listening on. - Addr() net.Addr - // Accept returns new early connections. It should be called in a loop. 
- Accept(context.Context) (EarlyConnection, error) + Version VersionNumber } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ack_eliciting.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go similarity index 80% rename from vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ack_eliciting.go rename to vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go index b8cd558ae64..4bab419013a 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/ack_eliciting.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go @@ -1,6 +1,6 @@ package ackhandler -import "github.com/lucas-clemente/quic-go/internal/wire" +import "github.com/quic-go/quic-go/internal/wire" // IsFrameAckEliciting returns true if the frame is ack-eliciting. func IsFrameAckEliciting(f wire.Frame) bool { @@ -10,7 +10,7 @@ func IsFrameAckEliciting(f wire.Frame) bool { } // HasAckElicitingFrames returns true if at least one frame is ack-eliciting. -func HasAckElicitingFrames(fs []Frame) bool { +func HasAckElicitingFrames(fs []*Frame) bool { for _, f := range fs { if IsFrameAckEliciting(f.Frame) { return true diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go new file mode 100644 index 00000000000..2c7cc4fcf0b --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ackhandler.go @@ -0,0 +1,23 @@ +package ackhandler + +import ( + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler. +// clientAddressValidated indicates whether the address was validated beforehand by an address validation token. +// clientAddressValidated has no effect for a client. 
+func NewAckHandler( + initialPacketNumber protocol.PacketNumber, + initialMaxDatagramSize protocol.ByteCount, + rttStats *utils.RTTStats, + clientAddressValidated bool, + pers protocol.Perspective, + tracer logging.ConnectionTracer, + logger utils.Logger, +) (SentPacketHandler, ReceivedPacketHandler) { + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) + return sph, newReceivedPacketHandler(sph, rttStats, logger) +} diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go new file mode 100644 index 00000000000..deb23cfcb18 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go @@ -0,0 +1,29 @@ +package ackhandler + +import ( + "sync" + + "github.com/quic-go/quic-go/internal/wire" +) + +type Frame struct { + wire.Frame // nil if the frame has already been acknowledged in another packet + OnLost func(wire.Frame) + OnAcked func(wire.Frame) +} + +var framePool = sync.Pool{New: func() any { return &Frame{} }} + +func GetFrame() *Frame { + f := framePool.Get().(*Frame) + f.OnLost = nil + f.OnAcked = nil + return f +} + +func putFrame(f *Frame) { + f.Frame = nil + f.OnLost = nil + f.OnAcked = nil + framePool.Put(f) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go similarity index 71% rename from vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go rename to vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go index 226bfcbbcc5..5924f84bdaa 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go @@ -3,30 +3,10 @@ package ackhandler import ( "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" ) -// A Packet is a packet -type Packet struct { - PacketNumber protocol.PacketNumber - Frames []Frame - LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK - Length protocol.ByteCount - EncryptionLevel protocol.EncryptionLevel - SendTime time.Time - - IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller. 
- - includedInBytesInFlight bool - declaredLost bool - skippedPacket bool -} - -func (p *Packet) outstanding() bool { - return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket -} - // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go new file mode 100644 index 00000000000..d61783671bf --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/mockgen.go @@ -0,0 +1,6 @@ +//go:build gomock || generate + +package ackhandler + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker" +type SentPacketTracker = sentPacketTracker diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go new file mode 100644 index 00000000000..394ee40a984 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go @@ -0,0 +1,55 @@ +package ackhandler + +import ( + "sync" + "time" + + "github.com/quic-go/quic-go/internal/protocol" +) + +// A Packet is a packet +type Packet struct { + PacketNumber protocol.PacketNumber + Frames []*Frame + LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK + Length protocol.ByteCount + EncryptionLevel protocol.EncryptionLevel + SendTime time.Time + + IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller. + + includedInBytesInFlight bool + declaredLost bool + skippedPacket bool +} + +func (p *Packet) outstanding() bool { + return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket +} + +var packetPool = sync.Pool{New: func() any { return &Packet{} }} + +func GetPacket() *Packet { + p := packetPool.Get().(*Packet) + p.PacketNumber = 0 + p.Frames = nil + p.LargestAcked = 0 + p.Length = 0 + p.EncryptionLevel = protocol.EncryptionLevel(0) + p.SendTime = time.Time{} + p.IsPathMTUProbePacket = false + p.includedInBytesInFlight = false + p.declaredLost = false + p.skippedPacket = false + return p +} + +// We currently only return Packets back into the pool when they're acknowledged (not when they're lost). +// This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool. 
+func putPacket(p *Packet) { + for _, f := range p.Frames { + putFrame(f) + } + p.Frames = nil + packetPool.Put(p) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_number_generator.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go similarity index 92% rename from vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_number_generator.go rename to vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go index 7d58650cc8a..9cf20a0b004 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/packet_number_generator.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go @@ -1,8 +1,8 @@ package ackhandler import ( - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) type packetNumberGenerator interface { @@ -72,5 +72,5 @@ func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { func (p *skippingPacketNumberGenerator) generateNewSkip() { // make sure that there are never two consecutive packet numbers that are skipped p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) - p.period = utils.MinPacketNumber(2*p.period, p.maxPeriod) + p.period = utils.Min(2*p.period, p.maxPeriod) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go similarity index 84% rename from vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go rename to vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go index 89fb30d3116..3675694f4ff 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_handler.go @@ -4,9 +4,9 @@ import ( "fmt" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" ) type receivedPacketHandler struct { @@ -25,13 +25,12 @@ func newReceivedPacketHandler( sentPackets sentPacketTracker, rttStats *utils.RTTStats, logger utils.Logger, - version protocol.VersionNumber, ) ReceivedPacketHandler { return &receivedPacketHandler{ sentPackets: sentPackets, - initialPackets: newReceivedPacketTracker(rttStats, logger, version), - handshakePackets: newReceivedPacketTracker(rttStats, logger, version), - appDataPackets: newReceivedPacketTracker(rttStats, logger, version), + initialPackets: newReceivedPacketTracker(rttStats, logger), + handshakePackets: newReceivedPacketTracker(rttStats, logger), + appDataPackets: newReceivedPacketTracker(rttStats, logger), lowest1RTTPacket: protocol.InvalidPacketNumber, } } @@ -46,24 +45,26 @@ func (h *receivedPacketHandler) ReceivedPacket( h.sentPackets.ReceivedPacket(encLevel) switch encLevel { case protocol.EncryptionInitial: - h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) case protocol.EncryptionHandshake: - h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, 
shouldInstigateAck) + return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) case protocol.Encryption0RTT: if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) } - h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) case protocol.Encryption1RTT: if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { h.lowest1RTTPacket = pn } + if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck); err != nil { + return err + } h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) - h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) + return nil default: panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel)) } - return nil } func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_history.go similarity index 74% rename from vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go rename to vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_history.go index 5a0391ef2c0..3143bfe1204 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_history.go @@ -1,23 +1,37 @@ package ackhandler import ( - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + "sync" + + "github.com/quic-go/quic-go/internal/protocol" + list "github.com/quic-go/quic-go/internal/utils/linkedlist" + "github.com/quic-go/quic-go/internal/wire" ) +// interval is an interval from one PacketNumber to the other +type interval struct { + Start protocol.PacketNumber + End protocol.PacketNumber +} + +var intervalElementPool sync.Pool + +func init() { + intervalElementPool = *list.NewPool[interval]() +} + // The receivedPacketHistory stores if a packet number has already been received. // It generates ACK ranges which can be used to assemble an ACK frame. // It does not store packet contents. 
type receivedPacketHistory struct { - ranges *utils.PacketIntervalList + ranges *list.List[interval] deletedBelow protocol.PacketNumber } func newReceivedPacketHistory() *receivedPacketHistory { return &receivedPacketHistory{ - ranges: utils.NewPacketIntervalList(), + ranges: list.NewWithPool[interval](&intervalElementPool), } } @@ -34,7 +48,7 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ { if h.ranges.Len() == 0 { - h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) + h.ranges.PushBack(interval{Start: p, End: p}) return true } @@ -61,13 +75,13 @@ func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is // create a new range at the end if p > el.Value.End { - h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el) + h.ranges.InsertAfter(interval{Start: p, End: p}, el) return true } } // create a new range at the beginning - h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front()) + h.ranges.InsertBefore(interval{Start: p, End: p}, h.ranges.Front()) return true } @@ -101,17 +115,12 @@ func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) { } } -// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame -func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange { - if h.ranges.Len() == 0 { - return nil - } - - ackRanges := make([]wire.AckRange, h.ranges.Len()) - i := 0 - for el := h.ranges.Back(); el != nil; el = el.Prev() { - ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End} - i++ +// AppendAckRanges appends to a slice of all AckRanges that can be used in an AckFrame +func (h *receivedPacketHistory) AppendAckRanges(ackRanges []wire.AckRange) []wire.AckRange { + if h.ranges.Len() > 0 { + for el := h.ranges.Back(); el != nil; el = el.Prev() { + ackRanges = append(ackRanges, wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}) + } } return ackRanges } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_tracker.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go similarity index 86% rename from vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_tracker.go rename to vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go index 56e79269520..7132ccaade2 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_tracker.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go @@ -1,11 +1,12 @@ package ackhandler import ( + "fmt" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" ) // number of ack-eliciting packets received before sending an ack. 
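Note on the receivedPacketHistory change above: GetAckRanges, which allocated a fresh []wire.AckRange for every ACK frame, is replaced by AppendAckRanges, which appends into a slice supplied by the caller so the backing array can be reused (or pooled) across frames. Below is a minimal sketch of that allocate-vs-append pattern, using only the standard library and hypothetical names rather than the vendored types:

package main

import "fmt"

// ackRange stands in for wire.AckRange; illustration only.
type ackRange struct{ smallest, largest uint64 }

type history struct{ ranges []ackRange }

// appendRanges appends every stored range to dst and returns the grown slice,
// leaving the choice of backing memory (fresh, reused, or pooled) to the caller.
func (h *history) appendRanges(dst []ackRange) []ackRange {
	return append(dst, h.ranges...)
}

func main() {
	h := &history{ranges: []ackRange{{1, 3}, {7, 9}}}

	// Reuse one buffer across "frames": reset the length, keep the capacity.
	buf := make([]ackRange, 0, 8)
	for i := 0; i < 2; i++ {
		buf = h.appendRanges(buf[:0])
		fmt.Println(buf)
	}
}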
@@ -30,27 +31,23 @@ type receivedPacketTracker struct { lastAck *wire.AckFrame logger utils.Logger - - version protocol.VersionNumber } func newReceivedPacketTracker( rttStats *utils.RTTStats, logger utils.Logger, - version protocol.VersionNumber, ) *receivedPacketTracker { return &receivedPacketTracker{ packetHistory: newReceivedPacketHistory(), maxAckDelay: protocol.MaxAckDelay, rttStats: rttStats, logger: logger, - version: version, } } -func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) { - if packetNumber < h.ignoreBelow { - return +func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) error { + if isNew := h.packetHistory.ReceivedPacket(packetNumber); !isNew { + return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", packetNumber) } isMissing := h.isMissing(packetNumber) @@ -59,7 +56,7 @@ func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumbe h.largestObservedReceivedTime = rcvTime } - if isNew := h.packetHistory.ReceivedPacket(packetNumber); isNew && shouldInstigateAck { + if shouldInstigateAck { h.hasNewAck = true } if shouldInstigateAck { @@ -74,6 +71,7 @@ func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumbe case protocol.ECNCE: h.ecnce++ } + return nil } // IgnoreBelow sets a lower limit for acknowledging packets. @@ -171,16 +169,16 @@ func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { } } - ack := &wire.AckFrame{ - AckRanges: h.packetHistory.GetAckRanges(), - // Make sure that the DelayTime is always positive. - // This is not guaranteed on systems that don't have a monotonic clock. 
- DelayTime: utils.MaxDuration(0, now.Sub(h.largestObservedReceivedTime)), - ECT0: h.ect0, - ECT1: h.ect1, - ECNCE: h.ecnce, - } + ack := wire.GetAckFrame() + ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime)) + ack.ECT0 = h.ect0 + ack.ECT1 = h.ect1 + ack.ECNCE = h.ecnce + ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) + if h.lastAck != nil { + wire.PutAckFrame(h.lastAck) + } h.lastAck = ack h.ackAlarm = time.Time{} h.ackQueued = false diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/send_mode.go rename to vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go similarity index 91% rename from vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go rename to vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go index ff8cc8d2bdf..732bbc3a1d6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go @@ -5,12 +5,12 @@ import ( "fmt" "time" - "github.com/lucas-clemente/quic-go/internal/congestion" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/logging" + "github.com/quic-go/quic-go/internal/congestion" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" ) const ( @@ -23,6 +23,8 @@ const ( amplificationFactor = 3 // We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet. minRTTAfterRetry = 5 * time.Millisecond + // The PTO duration uses exponential backoff, but is truncated to a maximum value, as allowed by RFC 8961, section 4.4. + maxPTODuration = 60 * time.Second ) type packetNumberSpace struct { @@ -101,10 +103,13 @@ var ( _ sentPacketTracker = &sentPacketHandler{} ) +// clientAddressValidated indicates whether the address was validated beforehand by an address validation token. +// If the address was validated, the amplification limit doesn't apply. It has no effect for a client. 
func newSentPacketHandler( initialPN protocol.PacketNumber, initialMaxDatagramSize protocol.ByteCount, rttStats *utils.RTTStats, + clientAddressValidated bool, pers protocol.Perspective, tracer logging.ConnectionTracer, logger utils.Logger, @@ -119,7 +124,7 @@ func newSentPacketHandler( return &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, - peerAddressValidated: pers == protocol.PerspectiveClient, + peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated, initialPackets: newPacketNumberSpace(initialPN, false, rttStats), handshakePackets: newPacketNumberSpace(0, false, rttStats), appDataPackets: newPacketNumberSpace(0, true, rttStats), @@ -223,14 +228,20 @@ func (h *sentPacketHandler) packetsInFlight() int { return packetsInFlight } -func (h *sentPacketHandler) SentPacket(packet *Packet) { - h.bytesSent += packet.Length +func (h *sentPacketHandler) SentPacket(p *Packet) { + h.bytesSent += p.Length // For the client, drop the Initial packet number space when the first Handshake packet is sent. - if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { + if h.perspective == protocol.PerspectiveClient && p.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { h.dropPackets(protocol.EncryptionInitial) } - isAckEliciting := h.sentPacketImpl(packet) - h.getPacketNumberSpace(packet.EncryptionLevel).history.SentPacket(packet, isAckEliciting) + isAckEliciting := h.sentPacketImpl(p) + if isAckEliciting { + h.getPacketNumberSpace(p.EncryptionLevel).history.SentAckElicitingPacket(p) + } else { + h.getPacketNumberSpace(p.EncryptionLevel).history.SentNonAckElicitingPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime) + putPacket(p) + p = nil //nolint:ineffassign // This is just to be on the safe side. + } if h.tracer != nil && isAckEliciting { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } @@ -256,7 +267,7 @@ func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-elicit pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel) if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { - for p := utils.MaxPacketNumber(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ { + for p := utils.Max(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ { h.logger.Debugf("Skipping packet number %d", p) } } @@ -288,7 +299,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } - pnSpace.largestAcked = utils.MaxPacketNumber(pnSpace.largestAcked, largestAcked) + pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) // Servers complete address validation when a protected packet is received. 
if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation && @@ -310,7 +321,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En // don't use the ack delay for Initial and Handshake packets var ackDelay time.Duration if encLevel == protocol.Encryption1RTT { - ackDelay = utils.MinDuration(ack.DelayTime, h.rttStats.MaxAckDelay()) + ackDelay = utils.Min(ack.DelayTime, h.rttStats.MaxAckDelay()) } h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime) if h.logger.Debug() { @@ -331,7 +342,11 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En acked1RTTPacket = true } h.removeFromBytesInFlight(p) + putPacket(p) } + // After this point, we must not use ackedPackets any longer! + // We've already returned the buffers. + ackedPackets = nil //nolint:ineffassign // This is just to be on the safe side. // Reset the pto_count unless the client is unsure if the server has validated the client's address. if h.peerCompletedAddressValidation { @@ -406,7 +421,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL for _, p := range h.ackedPackets { if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { - h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1) + h.lowestNotConfirmedAcked = utils.Max(h.lowestNotConfirmedAcked, p.LargestAcked+1) } for _, f := range p.Frames { @@ -444,6 +459,14 @@ func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.Encryptio return lossTime, encLevel } +func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration { + pto := h.rttStats.PTO(includeMaxAckDelay) << h.ptoCount + if pto > maxPTODuration || pto <= 0 { + return maxPTODuration + } + return pto +} + // same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { // We only send application data probe packets once the handshake is confirmed, @@ -452,7 +475,7 @@ func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protoc if h.peerCompletedAddressValidation { return } - t := time.Now().Add(h.rttStats.PTO(false) << h.ptoCount) + t := time.Now().Add(h.getScaledPTO(false)) if h.initialPackets != nil { return t, protocol.EncryptionInitial, true } @@ -462,18 +485,18 @@ func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protoc if h.initialPackets != nil { encLevel = protocol.EncryptionInitial if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() { - pto = t.Add(h.rttStats.PTO(false) << h.ptoCount) + pto = t.Add(h.getScaledPTO(false)) } } if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() { - t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(false) << h.ptoCount) + t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false)) if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { pto = t encLevel = protocol.EncryptionHandshake } } if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { - t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) + t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true)) if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { pto = t encLevel = protocol.Encryption1RTT @@ -554,11 +577,11 @@ func (h *sentPacketHandler) 
detectLostPackets(now time.Time, encLevel protocol.E pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} - maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) + maxRTT := float64(utils.Max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) lossDelay := time.Duration(timeThreshold * maxRTT) // Minimum time of granularity before packets are deemed lost. - lossDelay = utils.MaxDuration(lossDelay, protocol.TimerGranularity) + lossDelay = utils.Max(lossDelay, protocol.TimerGranularity) // Packets sent before this time are deemed lost. lostSendTime := now.Add(-lossDelay) @@ -808,7 +831,7 @@ func (h *sentPacketHandler) ResetForRetry() error { if h.ptoCount == 0 { // Don't set the RTT to a value lower than 5ms here. now := time.Now() - h.rttStats.UpdateRTT(utils.MaxDuration(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) + h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go similarity index 60% rename from vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go rename to vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go index d5704dd2c75..06478399143 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_history.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go @@ -2,55 +2,69 @@ package ackhandler import ( "fmt" + "sync" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + list "github.com/quic-go/quic-go/internal/utils/linkedlist" ) type sentPacketHistory struct { rttStats *utils.RTTStats - outstandingPacketList *PacketList - etcPacketList *PacketList - packetMap map[protocol.PacketNumber]*PacketElement + outstandingPacketList *list.List[*Packet] + etcPacketList *list.List[*Packet] + packetMap map[protocol.PacketNumber]*list.Element[*Packet] highestSent protocol.PacketNumber } +var packetElementPool sync.Pool + +func init() { + packetElementPool = *list.NewPool[*Packet]() +} + func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory { return &sentPacketHistory{ rttStats: rttStats, - outstandingPacketList: NewPacketList(), - etcPacketList: NewPacketList(), - packetMap: make(map[protocol.PacketNumber]*PacketElement), + outstandingPacketList: list.NewWithPool[*Packet](&packetElementPool), + etcPacketList: list.NewWithPool[*Packet](&packetElementPool), + packetMap: make(map[protocol.PacketNumber]*list.Element[*Packet]), highestSent: protocol.InvalidPacketNumber, } } -func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { - if p.PacketNumber <= h.highestSent { +func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) { + h.registerSentPacket(pn, encLevel, t) +} + +func (h *sentPacketHistory) SentAckElicitingPacket(p *Packet) { + h.registerSentPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime) + + var el *list.Element[*Packet] + if p.outstanding() { + el = h.outstandingPacketList.PushBack(p) + } else { + el = 
h.etcPacketList.PushBack(p) + } + h.packetMap[p.PacketNumber] = el +} + +func (h *sentPacketHistory) registerSentPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) { + if pn <= h.highestSent { panic("non-sequential packet number use") } // Skipped packet numbers. - for pn := h.highestSent + 1; pn < p.PacketNumber; pn++ { - el := h.etcPacketList.PushBack(Packet{ - PacketNumber: pn, - EncryptionLevel: p.EncryptionLevel, - SendTime: p.SendTime, + for p := h.highestSent + 1; p < pn; p++ { + el := h.etcPacketList.PushBack(&Packet{ + PacketNumber: p, + EncryptionLevel: encLevel, + SendTime: t, skippedPacket: true, }) - h.packetMap[pn] = el - } - h.highestSent = p.PacketNumber - - if isAckEliciting { - var el *PacketElement - if p.outstanding() { - el = h.outstandingPacketList.PushBack(*p) - } else { - el = h.etcPacketList.PushBack(*p) - } - h.packetMap[p.PacketNumber] = el + h.packetMap[p] = el } + h.highestSent = pn } // Iterate iterates through all packets. @@ -58,7 +72,7 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err cont := true outstandingEl := h.outstandingPacketList.Front() etcEl := h.etcPacketList.Front() - var el *PacketElement + var el *list.Element[*Packet] // whichever has the next packet number is returned first for cont { if outstandingEl == nil || (etcEl != nil && etcEl.Value.PacketNumber < outstandingEl.Value.PacketNumber) { @@ -75,7 +89,7 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err etcEl = etcEl.Next() } var err error - cont, err = cb(&el.Value) + cont, err = cb(el.Value) if err != nil { return err } @@ -89,7 +103,7 @@ func (h *sentPacketHistory) FirstOutstanding() *Packet { if el == nil { return nil } - return &el.Value + return el.Value } func (h *sentPacketHistory) Len() int { @@ -101,8 +115,7 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { if !ok { return fmt.Errorf("packet %d not found in sent packet history", p) } - h.outstandingPacketList.Remove(el) - h.etcPacketList.Remove(el) + el.List().Remove(el) delete(h.packetMap, p) return nil } @@ -113,7 +126,7 @@ func (h *sentPacketHistory) HasOutstandingPackets() bool { func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { maxAge := 3 * h.rttStats.PTO(false) - var nextEl *PacketElement + var nextEl *list.Element[*Packet] // we don't iterate outstandingPacketList, as we should not delete outstanding packets. // being outstanding for more than 3*PTO should only happen in the case of drastic RTT changes. for el := h.etcPacketList.Front(); el != nil; el = nextEl { @@ -132,10 +145,7 @@ func (h *sentPacketHistory) DeclareLost(p *Packet) *Packet { if !ok { return nil } - // try to remove it from both lists, as we don't know which one it currently belongs to. - // Remove is a no-op for elements that are not in the list. 
- h.outstandingPacketList.Remove(el) - h.etcPacketList.Remove(el) + el.List().Remove(el) p.declaredLost = true // move it to the correct position in the etc list (based on the packet number) for el = h.etcPacketList.Back(); el != nil; el = el.Prev() { @@ -144,10 +154,10 @@ func (h *sentPacketHistory) DeclareLost(p *Packet) *Packet { } } if el == nil { - el = h.etcPacketList.PushFront(*p) + el = h.etcPacketList.PushFront(p) } else { - el = h.etcPacketList.InsertAfter(*p, el) + el = h.etcPacketList.InsertAfter(p, el) } h.packetMap[p.PacketNumber] = el - return &el.Value + return el.Value } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go b/vendor/github.com/quic-go/quic-go/internal/congestion/bandwidth.go similarity index 91% rename from vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go rename to vendor/github.com/quic-go/quic-go/internal/congestion/bandwidth.go index 96b1c5aa84a..1d03abbb8ad 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/bandwidth.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/bandwidth.go @@ -4,7 +4,7 @@ import ( "math" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) // Bandwidth of a connection diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go b/vendor/github.com/quic-go/quic-go/internal/congestion/clock.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/congestion/clock.go rename to vendor/github.com/quic-go/quic-go/internal/congestion/clock.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go similarity index 97% rename from vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go rename to vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go index beadd627cc3..a73cf82aa5e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic.go @@ -4,8 +4,8 @@ import ( "math" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) // This cubic implementation is based on the one found in Chromiums's QUIC @@ -187,7 +187,7 @@ func (c *Cubic) CongestionWindowAfterAck( targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow } // Limit the CWND increase to half the acked bytes. - targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) + targetCongestionWindow = utils.Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) // Increase the window by approximately Alpha * 1 MSS of bytes every // time we ack an estimated tcp window of bytes. 
For small diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go similarity index 95% rename from vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go rename to vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go index 059b8f6a505..dac3118e3d4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/cubic_sender.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go @@ -4,9 +4,9 @@ import ( "fmt" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/logging" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" ) const ( @@ -178,7 +178,7 @@ func (c *cubicSender) OnPacketAcked( priorInFlight protocol.ByteCount, eventTime time.Time, ) { - c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber) + c.largestAckedPacketNumber = utils.Max(ackedPacketNumber, c.largestAckedPacketNumber) if c.InRecovery() { return } @@ -246,7 +246,7 @@ func (c *cubicSender) maybeIncreaseCwnd( c.numAckedPackets = 0 } } else { - c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) + c.congestionWindow = utils.Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) } } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go b/vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go similarity index 91% rename from vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go rename to vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go index b5ae3d5eb10..b2f7c908ed1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/hybrid_slow_start.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/hybrid_slow_start.go @@ -3,8 +3,8 @@ package congestion import ( "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) // Note(pwestin): the magic clamping numbers come from the original code in @@ -75,8 +75,8 @@ func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT ti // Divide minRTT by 8 to get a rtt increase threshold for exiting. minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) // Ensure the rtt threshold is never less than 2ms or more than 16ms. 
- minRTTincreaseThresholdUs = utils.MinInt64(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) - minRTTincreaseThreshold := time.Duration(utils.MaxInt64(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond + minRTTincreaseThresholdUs = utils.Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) + minRTTincreaseThreshold := time.Duration(utils.Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { s.hystartFound = true diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go b/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go similarity index 94% rename from vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go rename to vendor/github.com/quic-go/quic-go/internal/congestion/interface.go index 5157383f39d..5db3ebae0c7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/interface.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go @@ -3,7 +3,7 @@ package congestion import ( "time" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) // A SendAlgorithm performs congestion control diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/pacer.go b/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go similarity index 71% rename from vendor/github.com/lucas-clemente/quic-go/internal/congestion/pacer.go rename to vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go index 7ec4d8f5749..09ea268066d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/congestion/pacer.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go @@ -4,24 +4,24 @@ import ( "math" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) const maxBurstSizePackets = 10 // The pacer implements a token bucket pacing algorithm. type pacer struct { - budgetAtLastSent protocol.ByteCount - maxDatagramSize protocol.ByteCount - lastSentTime time.Time - getAdjustedBandwidth func() uint64 // in bytes/s + budgetAtLastSent protocol.ByteCount + maxDatagramSize protocol.ByteCount + lastSentTime time.Time + adjustedBandwidth func() uint64 // in bytes/s } func newPacer(getBandwidth func() Bandwidth) *pacer { p := &pacer{ maxDatagramSize: initialMaxDatagramSize, - getAdjustedBandwidth: func() uint64 { + adjustedBandwidth: func() uint64 { // Bandwidth is in bits/s. We need the value in bytes/s. bw := uint64(getBandwidth() / BytesPerSecond) // Use a slightly higher value than the actual measured bandwidth. 
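The pacer above implements token-bucket pacing: the send budget accrues at the adjusted bandwidth (in bytes/s) since the last send and is capped at a burst size, as the Budget hunk below shows. A minimal sketch of that idea, using only the standard library and illustrative names and numbers rather than the vendored pacer's API:

package main

import (
	"fmt"
	"time"
)

type pacer struct {
	bandwidth    func() uint64 // adjusted bandwidth in bytes per second
	burstSize    uint64        // upper bound on accumulated credit
	lastSent     time.Time
	budgetAtSend uint64
}

// budget returns how many bytes may be sent at time now: the credit left after
// the last send plus what the link earned since then, capped at one burst.
func (p *pacer) budget(now time.Time) uint64 {
	if p.lastSent.IsZero() {
		return p.burstSize
	}
	b := p.budgetAtSend + p.bandwidth()*uint64(now.Sub(p.lastSent).Nanoseconds())/1e9
	if b > p.burstSize {
		b = p.burstSize
	}
	return b
}

// sentPacket charges size bytes against the budget at time now.
func (p *pacer) sentPacket(now time.Time, size uint64) {
	b := p.budget(now)
	if size > b {
		b = size // don't let the budget underflow
	}
	p.budgetAtSend = b - size
	p.lastSent = now
}

func main() {
	// ~10 Mbit/s and a 10 x 1200-byte burst, both illustrative values.
	p := &pacer{bandwidth: func() uint64 { return 1_250_000 }, burstSize: 12_000}
	now := time.Now()
	p.sentPacket(now, 1200)
	fmt.Println(p.budget(now.Add(2 * time.Millisecond))) // refilled toward the burst cap after 2ms
}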
@@ -49,13 +49,16 @@ func (p *pacer) Budget(now time.Time) protocol.ByteCount { if p.lastSentTime.IsZero() { return p.maxBurstSize() } - budget := p.budgetAtLastSent + (protocol.ByteCount(p.getAdjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 - return utils.MinByteCount(p.maxBurstSize(), budget) + budget := p.budgetAtLastSent + (protocol.ByteCount(p.adjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + if budget < 0 { // protect against overflows + budget = protocol.MaxByteCount + } + return utils.Min(p.maxBurstSize(), budget) } func (p *pacer) maxBurstSize() protocol.ByteCount { - return utils.MaxByteCount( - protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9, + return utils.Max( + protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9, maxBurstSizePackets*p.maxDatagramSize, ) } @@ -66,9 +69,9 @@ func (p *pacer) TimeUntilSend() time.Time { if p.budgetAtLastSent >= p.maxDatagramSize { return time.Time{} } - return p.lastSentTime.Add(utils.MaxDuration( + return p.lastSentTime.Add(utils.Max( protocol.MinPacingDelay, - time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond, + time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.adjustedBandwidth())))*time.Nanosecond, )) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go similarity index 93% rename from vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go rename to vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go index 2bf14fdc055..f3f24a60edd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/base_flow_controller.go @@ -4,8 +4,8 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) type baseFlowController struct { @@ -47,7 +47,7 @@ func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { c.bytesSent += n } -// UpdateSendWindow is be called after receiving a MAX_{STREAM_}DATA frame. +// UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame. 
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { if offset > c.sendWindow { c.sendWindow = offset @@ -107,7 +107,7 @@ func (c *baseFlowController) maybeAdjustWindowSize() { now := time.Now() if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { // window is consumed too fast, try to increase the window size - newSize := utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize) + newSize := utils.Min(2*c.receiveWindowSize, c.maxReceiveWindowSize) if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) { c.receiveWindowSize = newSize } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go similarity index 94% rename from vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go rename to vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go index 6bf2241b993..13e69d6c437 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go @@ -5,9 +5,9 @@ import ( "fmt" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" ) type connectionFlowController struct { @@ -87,7 +87,7 @@ func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCoun c.mutex.Lock() if inc > c.receiveWindowSize { c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10)) - newSize := utils.MinByteCount(inc, c.maxReceiveWindowSize) + newSize := utils.Min(inc, c.maxReceiveWindowSize) if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) { c.receiveWindowSize = newSize } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go similarity index 95% rename from vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go rename to vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go index 1eeaee9fe97..946519d520b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go +++ b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/interface.go @@ -1,6 +1,6 @@ package flowcontrol -import "github.com/lucas-clemente/quic-go/internal/protocol" +import "github.com/quic-go/quic-go/internal/protocol" type flowController interface { // for sending diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go similarity index 94% rename from vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go rename to vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go index aa66aef1a76..1770a9c8487 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go +++ 
b/vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go @@ -3,9 +3,9 @@ package flowcontrol import ( "fmt" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" ) type streamFlowController struct { @@ -123,7 +123,7 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { } func (c *streamFlowController) SendWindowSize() protocol.ByteCount { - return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) + return utils.Min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) } func (c *streamFlowController) shouldQueueWindowUpdate() bool { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go b/vendor/github.com/quic-go/quic-go/internal/handshake/aead.go similarity index 95% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/aead.go index 03b039289ce..410745f1a81 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/aead.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/aead.go @@ -4,9 +4,9 @@ import ( "crypto/cipher" "encoding/binary" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qtls" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/utils" ) func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { @@ -83,7 +83,7 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad [] // It uses the nonce provided here and XOR it with the IV. 
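Note on the version plumbing: in the crypto_setup.go hunks that follow, the custom ConnWithVersion is removed; newConn returns a plain net.Conn and RunHandshake passes the QUIC version to the TLS stack through HandshakeContext, keyed by the new QUICVersionContextKey, so callbacks that need the version can recover it from the handshake context. A sketch of that pattern using only crypto/tls; quicVersionKey and handshakeWithVersion are illustrative stand-ins, not names from the patch:

    package example

    import (
        "context"
        "crypto/tls"
        "fmt"
        "net"
    )

    // quicVersionKey stands in for the patch's QUICVersionContextKey.
    type quicVersionKey struct{}

    // handshakeWithVersion attaches the QUIC version to the handshake context
    // and reads it back inside a TLS callback via ClientHelloInfo.Context().
    func handshakeWithVersion(conn net.Conn, conf *tls.Config, version uint32) error {
        conf = conf.Clone()
        conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
            if v, ok := info.Context().Value(quicVersionKey{}).(uint32); ok {
                fmt.Println("ClientHello for QUIC version", v)
            }
            return nil, nil // keep using conf
        }
        ctx := context.WithValue(context.Background(), quicVersionKey{}, version)
        return tls.Server(conn, conf).HandshakeContext(ctx)
    }
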
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) if err == nil { - o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn) + o.highestRcvdPN = utils.Max(o.highestRcvdPN, pn) } else { err = ErrDecryptionFailed } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go b/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go similarity index 88% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go index 31d9bf0aa40..8c9c2a8f805 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go @@ -2,23 +2,29 @@ package handshake import ( "bytes" + "context" "crypto/tls" "errors" "fmt" "io" + "math" "net" "sync" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/qtls" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/logging" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" + "github.com/quic-go/quic-go/quicvarint" ) +type quicVersionContextKey struct{} + +var QUICVersionContextKey = &quicVersionContextKey{} + // TLS unexpected_message alert const alertUnexpectedMessage uint8 = 10 @@ -63,30 +69,25 @@ const clientSessionStateRevision = 3 type conn struct { localAddr, remoteAddr net.Addr - version protocol.VersionNumber } -var _ ConnWithVersion = &conn{} +var _ net.Conn = &conn{} -func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion { +func newConn(local, remote net.Addr) net.Conn { return &conn{ localAddr: local, remoteAddr: remote, - version: version, } } -var _ net.Conn = &conn{} - -func (c *conn) Read([]byte) (int, error) { return 0, nil } -func (c *conn) Write([]byte) (int, error) { return 0, nil } -func (c *conn) Close() error { return nil } -func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *conn) SetReadDeadline(time.Time) error { return nil } -func (c *conn) SetWriteDeadline(time.Time) error { return nil } -func (c *conn) SetDeadline(time.Time) error { return nil } -func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version } +func (c *conn) Read([]byte) (int, error) { return 0, nil } +func (c *conn) Write([]byte) (int, error) { return 0, nil } +func (c *conn) Close() error { return nil } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) SetReadDeadline(time.Time) error { return nil } +func (c *conn) SetWriteDeadline(time.Time) error { return nil } +func (c *conn) SetDeadline(time.Time) error { return nil } type cryptoSetup struct { tlsConf *tls.Config @@ -115,6 +116,7 @@ type cryptoSetup struct { clientHelloWritten bool clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written zeroRTTParametersChan chan<- *wire.TransportParameters + allow0RTT bool rttStats *utils.RTTStats @@ -181,7 +183,7 @@ func NewCryptoSetupClient( 
protocol.PerspectiveClient, version, ) - cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) return cs, clientHelloWritten } @@ -195,7 +197,7 @@ func NewCryptoSetupServer( tp *wire.TransportParameters, runner handshakeRunner, tlsConf *tls.Config, - enable0RTT bool, + allow0RTT bool, rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, @@ -208,14 +210,14 @@ func NewCryptoSetupServer( tp, runner, tlsConf, - enable0RTT, + allow0RTT, rttStats, tracer, logger, protocol.PerspectiveServer, version, ) - cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) return cs } @@ -250,6 +252,7 @@ func newCryptoSetup( readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, runner: runner, + allow0RTT: enable0RTT, ourParams: tp, paramsChan: extHandler.TransportParameters(), rttStats: rttStats, @@ -260,14 +263,14 @@ func newCryptoSetup( alertChan: make(chan uint8), clientHelloWrittenChan: make(chan struct{}), zeroRTTParametersChan: zeroRTTParametersChan, - messageChan: make(chan []byte, 100), + messageChan: make(chan []byte, 1), isReadingHandshakeMessage: make(chan struct{}), closeChan: make(chan struct{}), version: version, } var maxEarlyData uint32 if enable0RTT { - maxEarlyData = 0xffffffff + maxEarlyData = math.MaxUint32 } cs.extraConf = &qtls.ExtraConfig{ GetExtensions: extHandler.GetExtensions, @@ -304,7 +307,7 @@ func (h *cryptoSetup) RunHandshake() { handshakeErrChan := make(chan error, 1) go func() { defer close(h.handshakeDone) - if err := h.conn.Handshake(); err != nil { + if err := h.conn.HandshakeContext(context.WithValue(context.Background(), QUICVersionContextKey, h.version)); err != nil { handshakeErrChan <- err return } @@ -340,7 +343,7 @@ func (h *cryptoSetup) onError(alert uint8, message string) { if alert == 0 { err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message} } else { - err = qerr.NewCryptoError(alert, message) + err = qerr.NewLocalCryptoError(alert, message) } h.runner.OnError(err) } @@ -365,8 +368,15 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev h.onError(alertUnexpectedMessage, err.Error()) return false } - h.messageChan <- data + if encLevel != protocol.Encryption1RTT { + select { + case h.messageChan <- data: + case <-h.handshakeDone: // handshake errored, nobody is going to consume this message + return false + } + } if encLevel == protocol.Encryption1RTT { + h.messageChan <- data h.handlePostHandshakeMessage() return false } @@ -398,8 +408,7 @@ readLoop: func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { var expected protocol.EncryptionLevel switch msgType { - case typeClientHello, - typeServerHello: + case typeClientHello, typeServerHello: expected = protocol.EncryptionInitial case typeEncryptedExtensions, typeCertificate, @@ -432,11 +441,10 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) { // must be called after receiving the transport parameters func (h *cryptoSetup) marshalDataForSessionState() []byte { - buf := &bytes.Buffer{} - quicvarint.Write(buf, clientSessionStateRevision) - quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds())) - h.peerParams.MarshalForSessionTicket(buf) - return buf.Bytes() + b := make([]byte, 0, 256) + b = 
quicvarint.Append(b, clientSessionStateRevision) + b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds())) + return h.peerParams.MarshalForSessionTicket(b) } func (h *cryptoSetup) handleDataFromSessionState(data []byte) { @@ -491,13 +499,17 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { return false } valid := h.ourParams.ValidFor0RTT(t.Parameters) - if valid { - h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) - h.rttStats.SetInitialRTT(t.RTT) - } else { + if !valid { h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") + return false } - return valid + if !h.allow0RTT { + h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.") + return false + } + h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) + h.rttStats.SetInitialRTT(t.RTT) + return true } // rejected0RTT is called for the client when the server rejects 0-RTT. @@ -576,7 +588,9 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph newHeaderProtector(suite, trafficSecret, true, h.version), ) h.mutex.Unlock() - h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.logger.Debug() { + h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + } if h.tracer != nil { h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite()) } @@ -589,12 +603,16 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph h.dropInitialKeys, h.perspective, ) - h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.logger.Debug() { + h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + } case qtls.EncryptionApplication: h.readEncLevel = protocol.Encryption1RTT h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true - h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.logger.Debug() { + h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) + } default: panic("unexpected read encryption level") } @@ -616,7 +634,9 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip newHeaderProtector(suite, trafficSecret, true, h.version), ) h.mutex.Unlock() - h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.logger.Debug() { + h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + } if h.tracer != nil { h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) } @@ -629,12 +649,16 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip h.dropInitialKeys, h.perspective, ) - h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.logger.Debug() { + h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + } case qtls.EncryptionApplication: h.writeEncLevel = protocol.Encryption1RTT h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true - h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + if h.logger.Debug() { + h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) + } if h.zeroRTTSealer != nil { h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/header_protector.go 
b/vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go similarity index 97% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/header_protector.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go index 1f800c50fec..274fb30cbd4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/header_protector.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/header_protector.go @@ -9,8 +9,8 @@ import ( "golang.org/x/crypto/chacha20" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qtls" ) type headerProtector interface { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/hkdf.go b/vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/hkdf.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/hkdf.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/initial_aead.go b/vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go similarity index 88% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/initial_aead.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go index 00ed243c75f..3967fdb83ab 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/initial_aead.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/initial_aead.go @@ -6,14 +6,14 @@ import ( "golang.org/x/crypto/hkdf" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qtls" ) var ( quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99} quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} - quicSaltV2 = []byte{0xa7, 0x07, 0xc2, 0x03, 0xa5, 0x9b, 0x47, 0x18, 0x4a, 0x1d, 0x62, 0xca, 0x57, 0x04, 0x06, 0xea, 0x7a, 0xe3, 0xe5, 0xd3} + quicSaltV2 = []byte{0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9} ) const ( @@ -62,7 +62,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p } func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { - initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v)) + initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v)) clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) return diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go b/vendor/github.com/quic-go/quic-go/internal/handshake/interface.go similarity index 89% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/interface.go index 112f6c258b0..f80b6e0e36f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go +++ 
b/vendor/github.com/quic-go/quic-go/internal/handshake/interface.go @@ -3,12 +3,11 @@ package handshake import ( "errors" "io" - "net" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qtls" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/wire" ) var ( @@ -93,10 +92,3 @@ type CryptoSetup interface { Get0RTTSealer() (LongHeaderSealer, error) Get1RTTSealer() (ShortHeaderSealer, error) } - -// ConnWithVersion is the connection used in the ClientHelloInfo. -// It can be used to determine the QUIC version in use. -type ConnWithVersion interface { - net.Conn - GetQUICVersion() protocol.VersionNumber -} diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/mockgen.go b/vendor/github.com/quic-go/quic-go/internal/handshake/mockgen.go new file mode 100644 index 00000000000..68b0988c6ac --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/mockgen.go @@ -0,0 +1,6 @@ +//go:build gomock || generate + +package handshake + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package handshake -destination mock_handshake_runner_test.go github.com/quic-go/quic-go/internal/handshake HandshakeRunner" +type HandshakeRunner = handshakeRunner diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go b/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go new file mode 100644 index 00000000000..ff14f7e0d24 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/retry.go @@ -0,0 +1,70 @@ +package handshake + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "fmt" + "sync" + + "github.com/quic-go/quic-go/internal/protocol" +) + +var ( + retryAEADdraft29 cipher.AEAD // used for QUIC draft versions up to 34 + retryAEADv1 cipher.AEAD // used for QUIC v1 (RFC 9000) + retryAEADv2 cipher.AEAD // used for QUIC v2 +) + +func init() { + retryAEADdraft29 = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1}) + retryAEADv1 = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e}) + retryAEADv2 = initAEAD([16]byte{0x8f, 0xb4, 0xb0, 0x1b, 0x56, 0xac, 0x48, 0xe2, 0x60, 0xfb, 0xcb, 0xce, 0xad, 0x7c, 0xcc, 0x92}) +} + +func initAEAD(key [16]byte) cipher.AEAD { + aes, err := aes.NewCipher(key[:]) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + return aead +} + +var ( + retryBuf bytes.Buffer + retryMutex sync.Mutex + retryNonceDraft29 = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c} + retryNonceV1 = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb} + retryNonceV2 = [12]byte{0xd8, 0x69, 0x69, 0xbc, 0x2d, 0x7c, 0x6d, 0x99, 0x90, 0xef, 0xb0, 0x4a} +) + +// GetRetryIntegrityTag calculates the integrity tag on a Retry packet +func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte { + retryMutex.Lock() + defer retryMutex.Unlock() + + retryBuf.WriteByte(uint8(origDestConnID.Len())) + retryBuf.Write(origDestConnID.Bytes()) + retryBuf.Write(retry) + defer retryBuf.Reset() + + var tag [16]byte + var sealed []byte + //nolint:exhaustive // These are all the versions we support + switch version { + case protocol.Version1: + sealed = 
retryAEADv1.Seal(tag[:0], retryNonceV1[:], nil, retryBuf.Bytes()) + case protocol.Version2: + sealed = retryAEADv2.Seal(tag[:0], retryNonceV2[:], nil, retryBuf.Bytes()) + default: + sealed = retryAEADdraft29.Seal(tag[:0], retryNonceDraft29[:], nil, retryBuf.Bytes()) + } + if len(sealed) != 16 { + panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed))) + } + return &tag +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/session_ticket.go b/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go similarity index 77% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/session_ticket.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go index 75cc04f987d..56bcbcd5d06 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/session_ticket.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/session_ticket.go @@ -6,8 +6,8 @@ import ( "fmt" "time" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/quicvarint" ) const sessionTicketRevision = 2 @@ -18,11 +18,10 @@ type sessionTicket struct { } func (t *sessionTicket) Marshal() []byte { - b := &bytes.Buffer{} - quicvarint.Write(b, sessionTicketRevision) - quicvarint.Write(b, uint64(t.RTT.Microseconds())) - t.Parameters.MarshalForSessionTicket(b) - return b.Bytes() + b := make([]byte, 0, 256) + b = quicvarint.Append(b, sessionTicketRevision) + b = quicvarint.Append(b, uint64(t.RTT.Microseconds())) + return t.Parameters.MarshalForSessionTicket(b) } func (t *sessionTicket) Unmarshal(b []byte) error { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler.go b/vendor/github.com/quic-go/quic-go/internal/handshake/tls_extension_handler.go similarity index 92% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/tls_extension_handler.go index 3a6790341ae..6105fe40109 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/tls_extension_handler.go @@ -1,8 +1,8 @@ package handshake import ( - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qtls" ) const ( @@ -24,7 +24,7 @@ var _ tlsExtensionHandler = &extensionHandler{} // newExtensionHandler creates a new extension handler func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler { et := uint16(quicTLSExtensionType) - if v != protocol.Version1 { + if v == protocol.VersionDraft29 { et = quicTLSExtensionTypeOldDrafts } return &extensionHandler{ diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_generator.go b/vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go similarity index 76% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_generator.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go index 2df5fcd8c3e..e5e90bb3ba8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_generator.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/token_generator.go @@ -1,13 +1,14 @@ 
package handshake import ( + "bytes" "encoding/asn1" "fmt" "io" "net" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) const ( @@ -17,14 +18,19 @@ const ( // A Token is derived from the client address and can be used to verify the ownership of this address. type Token struct { - IsRetryToken bool - RemoteAddr string - SentTime time.Time + IsRetryToken bool + SentTime time.Time + encodedRemoteAddr []byte // only set for retry tokens OriginalDestConnectionID protocol.ConnectionID RetrySrcConnectionID protocol.ConnectionID } +// ValidateRemoteAddr validates the address, but does not check expiration +func (t *Token) ValidateRemoteAddr(addr net.Addr) bool { + return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr) +} + // token is the struct that is used for ASN1 serialization and deserialization type token struct { IsRetryToken bool @@ -59,8 +65,8 @@ func (g *TokenGenerator) NewRetryToken( data, err := asn1.Marshal(token{ IsRetryToken: true, RemoteAddr: encodeRemoteAddr(raddr), - OriginalDestConnectionID: origDestConnID, - RetrySrcConnectionID: retrySrcConnID, + OriginalDestConnectionID: origDestConnID.Bytes(), + RetrySrcConnectionID: retrySrcConnID.Bytes(), Timestamp: time.Now().UnixNano(), }) if err != nil { @@ -101,13 +107,13 @@ func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) } token := &Token{ - IsRetryToken: t.IsRetryToken, - RemoteAddr: decodeRemoteAddr(t.RemoteAddr), - SentTime: time.Unix(0, t.Timestamp), + IsRetryToken: t.IsRetryToken, + SentTime: time.Unix(0, t.Timestamp), + encodedRemoteAddr: t.RemoteAddr, } if t.IsRetryToken { - token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) - token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID) + token.OriginalDestConnectionID = protocol.ParseConnectionID(t.OriginalDestConnectionID) + token.RetrySrcConnectionID = protocol.ParseConnectionID(t.RetrySrcConnectionID) } return token, nil } @@ -119,16 +125,3 @@ func encodeRemoteAddr(remoteAddr net.Addr) []byte { } return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...) } - -// decodeRemoteAddr decodes the remote address saved in the token -func decodeRemoteAddr(data []byte) string { - // data will never be empty for a token that we generated. 
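Note on the token change: the token_generator.go hunk above stops storing a decoded RemoteAddr string on Token; the token keeps the encoded address bytes and exposes ValidateRemoteAddr, so the address check becomes a byte comparison and decodeRemoteAddr can be deleted. A rough usage sketch; checkRetryToken is an illustrative helper, not part of the patch:

    package handshake

    import "net"

    // checkRetryToken shows the new flow: decode the token, then ask it whether
    // it was issued for this remote address instead of comparing strings.
    func checkRetryToken(g *TokenGenerator, raw []byte, remote net.Addr) bool {
        tok, err := g.DecodeToken(raw)
        if err != nil || tok == nil || !tok.IsRetryToken {
            return false
        }
        return tok.ValidateRemoteAddr(remote)
    }
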
- // Check it to be on the safe side - if len(data) == 0 { - return "" - } - if data[0] == tokenPrefixIP { - return net.IP(data[1:]).String() - } - return string(data[1:]) -} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_protector.go b/vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/token_protector.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/token_protector.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/updatable_aead.go b/vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go similarity index 92% rename from vendor/github.com/lucas-clemente/quic-go/internal/handshake/updatable_aead.go rename to vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go index 1532e7b5a5b..ac01acdb1c1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/handshake/updatable_aead.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/updatable_aead.go @@ -8,17 +8,21 @@ import ( "fmt" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/qtls" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/logging" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" ) // KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. // It's a package-level variable to allow modifying it for testing purposes. var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval +// FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update. +// It's a package-level variable to allow modifying it for testing purposes. +var FirstKeyUpdateInterval uint64 = 100 + type updatableAEAD struct { suite *qtls.CipherSuiteTLS13 @@ -27,7 +31,6 @@ type updatableAEAD struct { firstPacketNumber protocol.PacketNumber handshakeConfirmed bool - keyUpdateInterval uint64 invalidPacketLimit uint64 invalidPacketCount uint64 @@ -74,7 +77,6 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, largestAcked: protocol.InvalidPacketNumber, firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, firstSentWithCurrentKey: protocol.InvalidPacketNumber, - keyUpdateInterval: KeyUpdateInterval, rttStats: rttStats, tracer: tracer, logger: logger, @@ -116,6 +118,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) } +// SetReadKey sets the read key. // For the client, this function is called before SetWriteKey. // For the server, this function is called after SetWriteKey. func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { @@ -129,6 +132,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [ a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) } +// SetWriteKey sets the write key. // For the client, this function is called after SetReadKey. // For the server, this function is called before SetWriteKey. 
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { @@ -169,7 +173,7 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac } } if err == nil { - a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn) + a.highestRcvdPN = utils.Max(a.highestRcvdPN, pn) } return dec, err } @@ -284,11 +288,17 @@ func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { if !a.updateAllowed() { return false } - if a.numRcvdWithCurrentKey >= a.keyUpdateInterval { + // Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism. + if a.keyPhase == 0 { + if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval { + return true + } + } + if a.numRcvdWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) return true } - if a.numSentWithCurrentKey >= a.keyUpdateInterval { + if a.numSentWithCurrentKey >= KeyUpdateInterval { a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) return true } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/logutils/frame.go b/vendor/github.com/quic-go/quic-go/internal/logutils/frame.go similarity index 54% rename from vendor/github.com/lucas-clemente/quic-go/internal/logutils/frame.go rename to vendor/github.com/quic-go/quic-go/internal/logutils/frame.go index 6e0fd311b9f..a6032fc20d7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/logutils/frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/logutils/frame.go @@ -1,9 +1,9 @@ package logutils import ( - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go/logging" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" ) // ConvertFrame converts a wire.Frame into a logging.Frame. @@ -11,6 +11,10 @@ import ( // Furthermore, it removes the data slices from CRYPTO and STREAM frames. func ConvertFrame(frame wire.Frame) logging.Frame { switch f := frame.(type) { + case *wire.AckFrame: + // We use a pool for ACK frames. + // Implementations of the tracer interface may hold on to frames, so we need to make a copy here. + return ConvertAckFrame(f) case *wire.CryptoFrame: return &logging.CryptoFrame{ Offset: f.Offset, @@ -31,3 +35,16 @@ func ConvertFrame(frame wire.Frame) logging.Frame { return logging.Frame(frame) } } + +func ConvertAckFrame(f *wire.AckFrame) *logging.AckFrame { + ranges := make([]wire.AckRange, 0, len(f.AckRanges)) + ranges = append(ranges, f.AckRanges...) 
+ ack := &logging.AckFrame{ + AckRanges: ranges, + DelayTime: f.DelayTime, + ECNCE: f.ECNCE, + ECT0: f.ECT0, + ECT1: f.ECT1, + } + return ack +} diff --git a/vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go b/vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go new file mode 100644 index 00000000000..77259b5fa58 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/connection_id.go @@ -0,0 +1,116 @@ +package protocol + +import ( + "crypto/rand" + "errors" + "fmt" + "io" +) + +var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length") + +// An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999. +// Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1 +// restricts the length to 20 bytes. +type ArbitraryLenConnectionID []byte + +func (c ArbitraryLenConnectionID) Len() int { + return len(c) +} + +func (c ArbitraryLenConnectionID) Bytes() []byte { + return c +} + +func (c ArbitraryLenConnectionID) String() string { + if c.Len() == 0 { + return "(empty)" + } + return fmt.Sprintf("%x", c.Bytes()) +} + +const maxConnectionIDLen = 20 + +// A ConnectionID in QUIC +type ConnectionID struct { + b [20]byte + l uint8 +} + +// GenerateConnectionID generates a connection ID using cryptographic random +func GenerateConnectionID(l int) (ConnectionID, error) { + var c ConnectionID + c.l = uint8(l) + _, err := rand.Read(c.b[:l]) + return c, err +} + +// ParseConnectionID interprets b as a Connection ID. +// It panics if b is longer than 20 bytes. +func ParseConnectionID(b []byte) ConnectionID { + if len(b) > maxConnectionIDLen { + panic("invalid conn id length") + } + var c ConnectionID + c.l = uint8(len(b)) + copy(c.b[:c.l], b) + return c +} + +// GenerateConnectionIDForInitial generates a connection ID for the Initial packet. +// It uses a length randomly chosen between 8 and 20 bytes. +func GenerateConnectionIDForInitial() (ConnectionID, error) { + r := make([]byte, 1) + if _, err := rand.Read(r); err != nil { + return ConnectionID{}, err + } + l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) + return GenerateConnectionID(l) +} + +// ReadConnectionID reads a connection ID of length len from the given io.Reader. +// It returns io.EOF if there are not enough bytes to read. 
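Note on the connection ID change: connection_id.go above replaces the old byte-slice ConnectionID with a fixed-size value type (a 20-byte array plus a length); callers construct it with ParseConnectionID or GenerateConnectionID and read it back through the Len and Bytes accessors that follow, which is why initial_aead.go and retry.go now call connID.Bytes(). A small illustrative sketch, written as if it lived in the protocol package; exampleConnectionIDs is not part of the patch:

    package protocol

    import "fmt"

    // exampleConnectionIDs builds connection IDs as values and hands their raw
    // bytes to APIs such as the Initial secret derivation.
    func exampleConnectionIDs() error {
        fixed := ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
        random, err := GenerateConnectionID(8) // 8 random bytes
        if err != nil {
            return err
        }
        fmt.Printf("fixed: %s (%d bytes), random: %s (%d bytes)\n",
            fixed, fixed.Len(), random, random.Len())
        _ = fixed.Bytes() // raw bytes, e.g. for hkdf.Extract
        return nil
    }
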
+func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) { + var c ConnectionID + if l == 0 { + return c, nil + } + if l > maxConnectionIDLen { + return c, ErrInvalidConnectionIDLen + } + c.l = uint8(l) + _, err := io.ReadFull(r, c.b[:l]) + if err == io.ErrUnexpectedEOF { + return c, io.EOF + } + return c, err +} + +// Len returns the length of the connection ID in bytes +func (c ConnectionID) Len() int { + return int(c.l) +} + +// Bytes returns the byte representation +func (c ConnectionID) Bytes() []byte { + return c.b[:c.l] +} + +func (c ConnectionID) String() string { + if c.Len() == 0 { + return "(empty)" + } + return fmt.Sprintf("%x", c.Bytes()) +} + +type DefaultConnectionIDGenerator struct { + ConnLen int +} + +func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) { + return GenerateConnectionID(d.ConnLen) +} + +func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int { + return d.ConnLen +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go b/vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/protocol/encryption_level.go rename to vendor/github.com/quic-go/quic-go/internal/protocol/encryption_level.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/key_phase.go b/vendor/github.com/quic-go/quic-go/internal/protocol/key_phase.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/protocol/key_phase.go rename to vendor/github.com/quic-go/quic-go/internal/protocol/key_phase.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go b/vendor/github.com/quic-go/quic-go/internal/protocol/packet_number.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go rename to vendor/github.com/quic-go/quic-go/internal/protocol/packet_number.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/params.go b/vendor/github.com/quic-go/quic-go/internal/protocol/params.go similarity index 98% rename from vendor/github.com/lucas-clemente/quic-go/internal/protocol/params.go rename to vendor/github.com/quic-go/quic-go/internal/protocol/params.go index 684e5dec891..970fe190a60 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/params.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/params.go @@ -132,10 +132,9 @@ const MaxPostHandshakeCryptoFrameSize = 1000 // but must ensure that a maximum size ACK frame fits into one packet. const MaxAckFrameSize ByteCount = 1000 -// DefaultMaxDatagramFrameSize is the maximum size of a DATAGRAM frame as defined in -// https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. +// DefaultMaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221). // The size is chosen such that a DATAGRAM frame fits into a QUIC packet. 
-const DefaultMaxDatagramFrameSize ByteCount = 1220 +const DefaultMaxDatagramFrameSize ByteCount = 1200 // DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221) const DatagramRcvQueueLen = 128 diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go b/vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go rename to vendor/github.com/quic-go/quic-go/internal/protocol/perspective.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go b/vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go rename to vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream.go b/vendor/github.com/quic-go/quic-go/internal/protocol/stream.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/protocol/stream.go rename to vendor/github.com/quic-go/quic-go/internal/protocol/stream.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go b/vendor/github.com/quic-go/quic-go/internal/protocol/version.go similarity index 81% rename from vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go rename to vendor/github.com/quic-go/quic-go/internal/protocol/version.go index dd54dbd3cb5..20e8976e363 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/version.go @@ -18,12 +18,10 @@ const ( // The version numbers, making grepping easier const ( - VersionTLS VersionNumber = 0x1 - VersionWhatever VersionNumber = math.MaxUint32 - 1 // for when the version doesn't matter - VersionUnknown VersionNumber = math.MaxUint32 - VersionDraft29 VersionNumber = 0xff00001d - Version1 VersionNumber = 0x1 - Version2 VersionNumber = 0x709a50c4 + VersionUnknown VersionNumber = math.MaxUint32 + VersionDraft29 VersionNumber = 0xff00001d + Version1 VersionNumber = 0x1 + Version2 VersionNumber = 0x6b3343cf ) // SupportedVersions lists the versions that the server supports @@ -32,19 +30,12 @@ var SupportedVersions = []VersionNumber{Version1, Version2, VersionDraft29} // IsValidVersion says if the version is known to quic-go func IsValidVersion(v VersionNumber) bool { - return v == VersionTLS || IsSupportedVersion(SupportedVersions, v) + return v == Version1 || IsSupportedVersion(SupportedVersions, v) } func (vn VersionNumber) String() string { - // For releases, VersionTLS will be set to a draft version. - // A switch statement can't contain duplicate cases. 
- if vn == VersionTLS && VersionTLS != VersionDraft29 && VersionTLS != Version1 { - return "TLS dev version (WIP)" - } //nolint:exhaustive switch vn { - case VersionWhatever: - return "whatever" case VersionUnknown: return "unknown" case VersionDraft29: diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qerr/error_codes.go b/vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go similarity index 95% rename from vendor/github.com/lucas-clemente/quic-go/internal/qerr/error_codes.go rename to vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go index bee42d51fa2..cc846df6a77 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/qerr/error_codes.go +++ b/vendor/github.com/quic-go/quic-go/internal/qerr/error_codes.go @@ -3,7 +3,7 @@ package qerr import ( "fmt" - "github.com/lucas-clemente/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/qtls" ) // TransportErrorCode is a QUIC transport error. @@ -81,7 +81,7 @@ func (e TransportErrorCode) String() string { return "NO_VIABLE_PATH" default: if e.IsCryptoError() { - return fmt.Sprintf("CRYPTO_ERROR (%#x)", uint16(e)) + return fmt.Sprintf("CRYPTO_ERROR %#x", uint16(e)) } return fmt.Sprintf("unknown error code: %#x", uint16(e)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qerr/errors.go b/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go similarity index 84% rename from vendor/github.com/lucas-clemente/quic-go/internal/qerr/errors.go rename to vendor/github.com/quic-go/quic-go/internal/qerr/errors.go index 8b1cff980e0..26ea344521b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/qerr/errors.go +++ b/vendor/github.com/quic-go/quic-go/internal/qerr/errors.go @@ -4,7 +4,7 @@ import ( "fmt" "net" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) var ( @@ -21,8 +21,8 @@ type TransportError struct { var _ error = &TransportError{} -// NewCryptoError create a new TransportError instance for a crypto error -func NewCryptoError(tlsAlert uint8, errorMessage string) *TransportError { +// NewLocalCryptoError create a new TransportError instance for a crypto error +func NewLocalCryptoError(tlsAlert uint8, errorMessage string) *TransportError { return &TransportError{ ErrorCode: 0x100 + TransportErrorCode(tlsAlert), ErrorMessage: errorMessage, @@ -30,7 +30,7 @@ func NewCryptoError(tlsAlert uint8, errorMessage string) *TransportError { } func (e *TransportError) Error() string { - str := e.ErrorCode.String() + str := fmt.Sprintf("%s (%s)", e.ErrorCode.String(), getRole(e.Remote)) if e.FrameType != 0 { str += fmt.Sprintf(" (frame type: %#x)", e.FrameType) } @@ -68,9 +68,9 @@ var _ error = &ApplicationError{} func (e *ApplicationError) Error() string { if len(e.ErrorMessage) == 0 { - return fmt.Sprintf("Application error %#x", e.ErrorCode) + return fmt.Sprintf("Application error %#x (%s)", e.ErrorCode, getRole(e.Remote)) } - return fmt.Sprintf("Application error %#x: %s", e.ErrorCode, e.ErrorMessage) + return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage) } type IdleTimeoutError struct{} @@ -122,3 +122,10 @@ func (e *StatelessResetError) Is(target error) bool { func (e *StatelessResetError) Timeout() bool { return false } func (e *StatelessResetError) Temporary() bool { return true } + +func getRole(remote bool) string { + if remote { + return "remote" + } + return "local" +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go117.go 
b/vendor/github.com/quic-go/quic-go/internal/qtls/go119.go similarity index 58% rename from vendor/github.com/lucas-clemente/quic-go/internal/qtls/go117.go rename to vendor/github.com/quic-go/quic-go/internal/qtls/go119.go index bc385f19431..f040b859c6e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go117.go +++ b/vendor/github.com/quic-go/quic-go/internal/qtls/go119.go @@ -1,5 +1,4 @@ -//go:build go1.17 && !go1.18 -// +build go1.17,!go1.18 +//go:build go1.19 && !go1.20 package qtls @@ -7,10 +6,11 @@ import ( "crypto" "crypto/cipher" "crypto/tls" + "fmt" "net" "unsafe" - "github.com/marten-seemann/qtls-go1-17" + "github.com/quic-go/qtls-go1-19" ) type ( @@ -18,7 +18,7 @@ type ( Alert = qtls.Alert // A Certificate is qtls.Certificate. Certificate = qtls.Certificate - // CertificateRequestInfo contains inforamtion about a certificate request. + // CertificateRequestInfo contains information about a certificate request. CertificateRequestInfo = qtls.CertificateRequestInfo // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 CipherSuiteTLS13 = qtls.CipherSuiteTLS13 @@ -84,7 +84,7 @@ type cipherSuiteTLS13 struct { Hash crypto.Hash } -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-17.cipherSuiteTLS13ByID +//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-19.cipherSuiteTLS13ByID func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 // CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. @@ -98,3 +98,48 @@ func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { Hash: cs.Hash, } } + +//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-19.cipherSuitesTLS13 +var cipherSuitesTLS13 []unsafe.Pointer + +//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13 +var defaultCipherSuitesTLS13 []uint16 + +//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13NoAES +var defaultCipherSuitesTLS13NoAES []uint16 + +var cipherSuitesModified bool + +// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls +// such that it only contains the cipher suite with the chosen id. +// The reset function returned resets them back to the original value. +func SetCipherSuite(id uint16) (reset func()) { + if cipherSuitesModified { + panic("cipher suites modified multiple times without resetting") + } + cipherSuitesModified = true + + origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) + origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) + origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) + // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. 
+ switch id { + case tls.TLS_AES_128_GCM_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[:1] + case tls.TLS_CHACHA20_POLY1305_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[1:2] + case tls.TLS_AES_256_GCM_SHA384: + cipherSuitesTLS13 = cipherSuitesTLS13[2:] + default: + panic(fmt.Sprintf("unexpected cipher suite: %d", id)) + } + defaultCipherSuitesTLS13 = []uint16{id} + defaultCipherSuitesTLS13NoAES = []uint16{id} + + return func() { + cipherSuitesTLS13 = origCipherSuitesTLS13 + defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 + defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES + cipherSuitesModified = false + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go119.go b/vendor/github.com/quic-go/quic-go/internal/qtls/go120.go similarity index 59% rename from vendor/github.com/lucas-clemente/quic-go/internal/qtls/go119.go rename to vendor/github.com/quic-go/quic-go/internal/qtls/go120.go index 86dcaea3dcb..a40146ab4fa 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/qtls/go119.go +++ b/vendor/github.com/quic-go/quic-go/internal/qtls/go120.go @@ -1,5 +1,4 @@ -//go:build go1.19 -// +build go1.19 +//go:build go1.20 package qtls @@ -7,10 +6,11 @@ import ( "crypto" "crypto/cipher" "crypto/tls" + "fmt" "net" "unsafe" - "github.com/marten-seemann/qtls-go1-19" + "github.com/quic-go/qtls-go1-20" ) type ( @@ -84,7 +84,7 @@ type cipherSuiteTLS13 struct { Hash crypto.Hash } -//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-19.cipherSuiteTLS13ByID +//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-20.cipherSuiteTLS13ByID func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 // CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. @@ -98,3 +98,48 @@ func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { Hash: cs.Hash, } } + +//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-20.cipherSuitesTLS13 +var cipherSuitesTLS13 []unsafe.Pointer + +//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13 +var defaultCipherSuitesTLS13 []uint16 + +//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13NoAES +var defaultCipherSuitesTLS13NoAES []uint16 + +var cipherSuitesModified bool + +// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls +// such that it only contains the cipher suite with the chosen id. +// The reset function returned resets them back to the original value. +func SetCipherSuite(id uint16) (reset func()) { + if cipherSuitesModified { + panic("cipher suites modified multiple times without resetting") + } + cipherSuitesModified = true + + origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) + origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) + origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) + // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. 
+ switch id { + case tls.TLS_AES_128_GCM_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[:1] + case tls.TLS_CHACHA20_POLY1305_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[1:2] + case tls.TLS_AES_256_GCM_SHA384: + cipherSuitesTLS13 = cipherSuitesTLS13[2:] + default: + panic(fmt.Sprintf("unexpected cipher suite: %d", id)) + } + defaultCipherSuitesTLS13 = []uint16{id} + defaultCipherSuitesTLS13NoAES = []uint16{id} + + return func() { + cipherSuitesTLS13 = origCipherSuitesTLS13 + defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 + defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES + cipherSuitesModified = false + } +} diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/go121.go b/vendor/github.com/quic-go/quic-go/internal/qtls/go121.go new file mode 100644 index 00000000000..b33406397b6 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/qtls/go121.go @@ -0,0 +1,5 @@ +//go:build go1.21 + +package qtls + +var _ int = "The version of quic-go you're using can't be built on Go 1.21 yet. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions." diff --git a/vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go b/vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go new file mode 100644 index 00000000000..e15f03629a6 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/qtls/go_oldversion.go @@ -0,0 +1,5 @@ +//go:build !go1.19 + +package qtls + +var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions." diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/buffered_write_closer.go b/vendor/github.com/quic-go/quic-go/internal/utils/buffered_write_closer.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/buffered_write_closer.go rename to vendor/github.com/quic-go/quic-go/internal/utils/buffered_write_closer.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go b/vendor/github.com/quic-go/quic-go/internal/utils/byteorder.go similarity index 85% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go rename to vendor/github.com/quic-go/quic-go/internal/utils/byteorder.go index d1f528429aa..a9b715e2f11 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/byteorder.go @@ -7,6 +7,10 @@ import ( // A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers. 
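Note on the qtls shims: the go119.go and go120.go files above add SetCipherSuite, which narrows qtls to a single TLS 1.3 suite and hands back a reset function; this appears intended for tests that need to pin a suite. A hedged usage sketch, written as a test inside the quic-go tree since internal/qtls cannot be imported from outside; TestPinAES128 is illustrative only:

    package qtls_test

    import (
        "crypto/tls"
        "testing"

        "github.com/quic-go/quic-go/internal/qtls"
    )

    func TestPinAES128(t *testing.T) {
        // Restrict qtls to TLS_AES_128_GCM_SHA256 for this test, then restore
        // the original cipher suite slices via the returned reset function.
        reset := qtls.SetCipherSuite(tls.TLS_AES_128_GCM_SHA256)
        defer reset()
        // ... run handshakes that must negotiate the pinned suite ...
    }
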
type ByteOrder interface { + Uint32([]byte) uint32 + Uint24([]byte) uint32 + Uint16([]byte) uint16 + ReadUint32(io.ByteReader) (uint32, error) ReadUint24(io.ByteReader) (uint32, error) ReadUint16(io.ByteReader) (uint16, error) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go b/vendor/github.com/quic-go/quic-go/internal/utils/byteorder_big_endian.go similarity index 85% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go rename to vendor/github.com/quic-go/quic-go/internal/utils/byteorder_big_endian.go index d05542e1d48..834a711b9ef 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/byteorder_big_endian.go @@ -2,6 +2,7 @@ package utils import ( "bytes" + "encoding/binary" "io" ) @@ -73,6 +74,19 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) { return uint16(b1) + uint16(b2)<<8, nil } +func (bigEndian) Uint32(b []byte) uint32 { + return binary.BigEndian.Uint32(b) +} + +func (bigEndian) Uint24(b []byte) uint32 { + _ = b[2] // bounds check hint to compiler; see golang.org/issue/14808 + return uint32(b[2]) | uint32(b[1])<<8 | uint32(b[0])<<16 +} + +func (bigEndian) Uint16(b []byte) uint16 { + return binary.BigEndian.Uint16(b) +} + // WriteUint32 writes a uint32 func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) { b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)}) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/ip.go b/vendor/github.com/quic-go/quic-go/internal/utils/ip.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/ip.go rename to vendor/github.com/quic-go/quic-go/internal/utils/ip.go diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/README.md b/vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/README.md new file mode 100644 index 00000000000..66482f4fb1b --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/README.md @@ -0,0 +1,6 @@ +# Usage + +This is the Go standard library implementation of a linked list +(https://golang.org/src/container/list/list.go), with the following modifications: +* it uses Go generics +* it allows passing in a `sync.Pool` (via the `NewWithPool` constructor) to reduce allocations of `Element` structs diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go b/vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/linkedlist.go similarity index 58% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go rename to vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/linkedlist.go index 096023ef28d..804a34444a8 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/byteinterval_linkedlist.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/linkedlist/linkedlist.go @@ -1,29 +1,40 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. -// see https://github.com/cheekybits/genny +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. -package utils +// Package list implements a doubly linked list. 
+// +// To iterate over a list (where l is a *List[T]): +// +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e.Value +// } +package list -// Linked list implementation from the Go standard library. +import "sync" -// ByteIntervalElement is an element of a linked list. -type ByteIntervalElement struct { +func NewPool[T any]() *sync.Pool { + return &sync.Pool{New: func() any { return &Element[T]{} }} +} + +// Element is an element of a linked list. +type Element[T any] struct { // Next and previous pointers in the doubly-linked list of elements. // To simplify the implementation, internally a list l is implemented // as a ring, such that &l.root is both the next element of the last // list element (l.Back()) and the previous element of the first list // element (l.Front()). - next, prev *ByteIntervalElement + next, prev *Element[T] // The list to which this element belongs. - list *ByteIntervalList + list *List[T] // The value stored with this element. - Value ByteInterval + Value T } // Next returns the next list element or nil. -func (e *ByteIntervalElement) Next() *ByteIntervalElement { +func (e *Element[T]) Next() *Element[T] { if p := e.next; e.list != nil && p != &e.list.root { return p } @@ -31,36 +42,49 @@ func (e *ByteIntervalElement) Next() *ByteIntervalElement { } // Prev returns the previous list element or nil. -func (e *ByteIntervalElement) Prev() *ByteIntervalElement { +func (e *Element[T]) Prev() *Element[T] { if p := e.prev; e.list != nil && p != &e.list.root { return p } return nil } -// ByteIntervalList is a linked list of ByteIntervals. -type ByteIntervalList struct { - root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used - len int // current list length excluding (this) sentinel element +func (e *Element[T]) List() *List[T] { + return e.list +} + +// List represents a doubly linked list. +// The zero value for List is an empty list ready to use. +type List[T any] struct { + root Element[T] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element + + pool *sync.Pool } // Init initializes or clears list l. -func (l *ByteIntervalList) Init() *ByteIntervalList { +func (l *List[T]) Init() *List[T] { l.root.next = &l.root l.root.prev = &l.root l.len = 0 return l } -// NewByteIntervalList returns an initialized list. -func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init() } +// New returns an initialized list. +func New[T any]() *List[T] { return new(List[T]).Init() } + +// NewWithPool returns an initialized list, using a sync.Pool for list elements. +func NewWithPool[T any](pool *sync.Pool) *List[T] { + l := &List[T]{pool: pool} + return l.Init() +} // Len returns the number of elements of list l. // The complexity is O(1). -func (l *ByteIntervalList) Len() int { return l.len } +func (l *List[T]) Len() int { return l.len } // Front returns the first element of list l or nil if the list is empty. -func (l *ByteIntervalList) Front() *ByteIntervalElement { +func (l *List[T]) Front() *Element[T] { if l.len == 0 { return nil } @@ -68,7 +92,7 @@ func (l *ByteIntervalList) Front() *ByteIntervalElement { } // Back returns the last element of list l or nil if the list is empty. 
-func (l *ByteIntervalList) Back() *ByteIntervalElement { +func (l *List[T]) Back() *Element[T] { if l.len == 0 { return nil } @@ -76,60 +100,83 @@ func (l *ByteIntervalList) Back() *ByteIntervalElement { } // lazyInit lazily initializes a zero List value. -func (l *ByteIntervalList) lazyInit() { +func (l *List[T]) lazyInit() { if l.root.next == nil { l.Init() } } // insert inserts e after at, increments l.len, and returns e. -func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalElement { - n := at.next - at.next = e +func (l *List[T]) insert(e, at *Element[T]) *Element[T] { e.prev = at - e.next = n - n.prev = e + e.next = at.next + e.prev.next = e + e.next.prev = e e.list = l l.len++ return e } // insertValue is a convenience wrapper for insert(&Element{Value: v}, at). -func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement { - return l.insert(&ByteIntervalElement{Value: v}, at) +func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { + var e *Element[T] + if l.pool != nil { + e = l.pool.Get().(*Element[T]) + } else { + e = &Element[T]{} + } + e.Value = v + return l.insert(e, at) } -// remove removes e from its list, decrements l.len, and returns e. -func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement { +// remove removes e from its list, decrements l.len +func (l *List[T]) remove(e *Element[T]) { e.prev.next = e.next e.next.prev = e.prev e.next = nil // avoid memory leaks e.prev = nil // avoid memory leaks e.list = nil + if l.pool != nil { + l.pool.Put(e) + } l.len-- - return e +} + +// move moves e to next to at. +func (l *List[T]) move(e, at *Element[T]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e } // Remove removes e from l if e is an element of list l. // It returns the element value e.Value. // The element must not be nil. -func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval { +func (l *List[T]) Remove(e *Element[T]) T { + v := e.Value if e.list == l { // if e.list == l, l must have been initialized when e was inserted // in l or l == nil (e is a zero Element) and l.remove will crash l.remove(e) } - return e.Value + return v } // PushFront inserts a new element e with value v at the front of list l and returns e. -func (l *ByteIntervalList) PushFront(v ByteInterval) *ByteIntervalElement { +func (l *List[T]) PushFront(v T) *Element[T] { l.lazyInit() return l.insertValue(v, &l.root) } // PushBack inserts a new element e with value v at the back of list l and returns e. -func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement { +func (l *List[T]) PushBack(v T) *Element[T] { l.lazyInit() return l.insertValue(v, l.root.prev) } @@ -137,7 +184,7 @@ func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement { // InsertBefore inserts a new element e with value v immediately before mark and returns e. // If mark is not an element of l, the list is not modified. // The mark must not be nil. -func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { +func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { if mark.list != l { return nil } @@ -148,7 +195,7 @@ func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElemen // InsertAfter inserts a new element e with value v immediately after mark and returns e. 
// If mark is not an element of l, the list is not modified. // The mark must not be nil. -func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement { +func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { if mark.list != l { return nil } @@ -159,57 +206,57 @@ func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement // MoveToFront moves element e to the front of list l. // If e is not an element of l, the list is not modified. // The element must not be nil. -func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) { +func (l *List[T]) MoveToFront(e *Element[T]) { if e.list != l || l.root.next == e { return } // see comment in List.Remove about initialization of l - l.insert(l.remove(e), &l.root) + l.move(e, &l.root) } // MoveToBack moves element e to the back of list l. // If e is not an element of l, the list is not modified. // The element must not be nil. -func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) { +func (l *List[T]) MoveToBack(e *Element[T]) { if e.list != l || l.root.prev == e { return } // see comment in List.Remove about initialization of l - l.insert(l.remove(e), l.root.prev) + l.move(e, l.root.prev) } // MoveBefore moves element e to its new position before mark. // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. -func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) { +func (l *List[T]) MoveBefore(e, mark *Element[T]) { if e.list != l || e == mark || mark.list != l { return } - l.insert(l.remove(e), mark.prev) + l.move(e, mark.prev) } // MoveAfter moves element e to its new position after mark. // If e or mark is not an element of l, or e == mark, the list is not modified. // The element and mark must not be nil. -func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) { +func (l *List[T]) MoveAfter(e, mark *Element[T]) { if e.list != l || e == mark || mark.list != l { return } - l.insert(l.remove(e), mark) + l.move(e, mark) } -// PushBackList inserts a copy of an other list at the back of list l. +// PushBackList inserts a copy of another list at the back of list l. // The lists l and other may be the same. They must not be nil. -func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) { +func (l *List[T]) PushBackList(other *List[T]) { l.lazyInit() for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { l.insertValue(e.Value, l.root.prev) } } -// PushFrontList inserts a copy of an other list at the front of list l. +// PushFrontList inserts a copy of another list at the front of list l. // The lists l and other may be the same. They must not be nil. 
-func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) { +func (l *List[T]) PushFrontList(other *List[T]) { l.lazyInit() for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { l.insertValue(e.Value, &l.root) diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go b/vendor/github.com/quic-go/quic-go/internal/utils/log.go similarity index 98% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go rename to vendor/github.com/quic-go/quic-go/internal/utils/log.go index e27f01b4a84..89b52c0d9a7 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/log.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/log.go @@ -125,7 +125,7 @@ func readLoggingEnv() LogLevel { case "error": return LogLevelError default: - fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging") + fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/quic-go/quic-go/wiki/Logging") return LogLevelNothing } } diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/minmax.go b/vendor/github.com/quic-go/quic-go/internal/utils/minmax.go new file mode 100644 index 00000000000..d191f751583 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/utils/minmax.go @@ -0,0 +1,72 @@ +package utils + +import ( + "math" + "time" + + "golang.org/x/exp/constraints" +) + +// InfDuration is a duration of infinite length +const InfDuration = time.Duration(math.MaxInt64) + +func Max[T constraints.Ordered](a, b T) T { + if a < b { + return b + } + return a +} + +func Min[T constraints.Ordered](a, b T) T { + if a < b { + return a + } + return b +} + +// MinNonZeroDuration return the minimum duration that's not zero. +func MinNonZeroDuration(a, b time.Duration) time.Duration { + if a == 0 { + return b + } + if b == 0 { + return a + } + return Min(a, b) +} + +// AbsDuration returns the absolute value of a time duration +func AbsDuration(d time.Duration) time.Duration { + if d >= 0 { + return d + } + return -d +} + +// MinTime returns the earlier time +func MinTime(a, b time.Time) time.Time { + if a.After(b) { + return b + } + return a +} + +// MinNonZeroTime returns the earlist time that is not time.Time{} +// If both a and b are time.Time{}, it returns time.Time{} +func MinNonZeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b + } + if b.IsZero() { + return a + } + return MinTime(a, b) +} + +// MaxTime returns the later time +func MaxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/rand.go b/vendor/github.com/quic-go/quic-go/internal/utils/rand.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/rand.go rename to vendor/github.com/quic-go/quic-go/internal/utils/rand.go diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/rtt_stats.go b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go similarity index 92% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/rtt_stats.go rename to vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go index 66642ba8f97..527539e1e2d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/rtt_stats.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go @@ -3,7 +3,7 @@ package utils import ( "time" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) const ( @@ 
-55,7 +55,7 @@ func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { if r.SmoothedRTT() == 0 { return 2 * defaultInitialRTT } - pto := r.SmoothedRTT() + MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity) + pto := r.SmoothedRTT() + Max(4*r.MeanDeviation(), protocol.TimerGranularity) if includeMaxAckDelay { pto += r.MaxAckDelay() } @@ -122,6 +122,6 @@ func (r *RTTStats) OnConnectionMigration() { // is larger. The mean deviation is increased to the most recent deviation if // it's larger. func (r *RTTStats) ExpireSmoothedMetrics() { - r.meanDeviation = MaxDuration(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT)) - r.smoothedRTT = MaxDuration(r.smoothedRTT, r.latestRTT) + r.meanDeviation = Max(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT)) + r.smoothedRTT = Max(r.smoothedRTT, r.latestRTT) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go b/vendor/github.com/quic-go/quic-go/internal/utils/timer.go similarity index 94% rename from vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go rename to vendor/github.com/quic-go/quic-go/internal/utils/timer.go index a4f5e67aa0e..361106c8a95 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/utils/timer.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/timer.go @@ -47,6 +47,10 @@ func (t *Timer) SetRead() { t.read = true } +func (t *Timer) Deadline() time.Time { + return t.deadline +} + // Stop stops the timer func (t *Timer) Stop() { t.t.Stop() diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go similarity index 85% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go index e5280a73726..f145c8b4d92 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go @@ -6,9 +6,9 @@ import ( "sort" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/quicvarint" ) var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges") @@ -22,14 +22,10 @@ type AckFrame struct { } // parseAckFrame reads an ACK frame -func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - ecn := typeByte&0x1 > 0 +func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) { + ecn := typ == ackECNFrameType - frame := &AckFrame{} + frame := GetAckFrame() la, err := quicvarint.Read(r) if err != nil { @@ -106,41 +102,41 @@ func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNu return frame, nil } -// Write writes an ACK frame. -func (f *AckFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { +// Append appends an ACK frame. 
+func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 if hasECN { - b.WriteByte(0x3) + b = append(b, ackECNFrameType) } else { - b.WriteByte(0x2) + b = append(b, ackFrameType) } - quicvarint.Write(b, uint64(f.LargestAcked())) - quicvarint.Write(b, encodeAckDelay(f.DelayTime)) + b = quicvarint.Append(b, uint64(f.LargestAcked())) + b = quicvarint.Append(b, encodeAckDelay(f.DelayTime)) numRanges := f.numEncodableAckRanges() - quicvarint.Write(b, uint64(numRanges-1)) + b = quicvarint.Append(b, uint64(numRanges-1)) // write the first range _, firstRange := f.encodeAckRange(0) - quicvarint.Write(b, firstRange) + b = quicvarint.Append(b, firstRange) // write all the other range for i := 1; i < numRanges; i++ { gap, len := f.encodeAckRange(i) - quicvarint.Write(b, gap) - quicvarint.Write(b, len) + b = quicvarint.Append(b, gap) + b = quicvarint.Append(b, len) } if hasECN { - quicvarint.Write(b, f.ECT0) - quicvarint.Write(b, f.ECT1) - quicvarint.Write(b, f.ECNCE) + b = quicvarint.Append(b, f.ECT0) + b = quicvarint.Append(b, f.ECT1) + b = quicvarint.Append(b, f.ECNCE) } - return nil + return b, nil } // Length of a written frame -func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { largestAcked := f.AckRanges[0].Largest numRanges := f.numEncodableAckRanges() diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame_pool.go b/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame_pool.go new file mode 100644 index 00000000000..a0c6a21d7d3 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame_pool.go @@ -0,0 +1,24 @@ +package wire + +import "sync" + +var ackFramePool = sync.Pool{New: func() any { + return &AckFrame{} +}} + +func GetAckFrame() *AckFrame { + f := ackFramePool.Get().(*AckFrame) + f.AckRanges = f.AckRanges[:0] + f.ECNCE = 0 + f.ECT0 = 0 + f.ECT1 = 0 + f.DelayTime = 0 + return f +} + +func PutAckFrame(f *AckFrame) { + if cap(f.AckRanges) > 4 { + return + } + ackFramePool.Put(f) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go b/vendor/github.com/quic-go/quic-go/internal/wire/ack_range.go similarity index 82% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go rename to vendor/github.com/quic-go/quic-go/internal/wire/ack_range.go index 0f4185801dc..03a1235ee0d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_range.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/ack_range.go @@ -1,6 +1,6 @@ package wire -import "github.com/lucas-clemente/quic-go/internal/protocol" +import "github.com/quic-go/quic-go/internal/protocol" // AckRange is an ACK range type AckRange struct { diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/connection_close_frame.go similarity index 70% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/connection_close_frame.go index 4ce49af6e42..f56c2c0d845 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/connection_close_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/connection_close_frame.go @@ -4,8 +4,8 @@ import ( "bytes" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + 
"github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A ConnectionCloseFrame is a CONNECTION_CLOSE frame @@ -16,13 +16,8 @@ type ConnectionCloseFrame struct { ReasonPhrase string } -func parseConnectionCloseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d} +func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) { + f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType} ec, err := quicvarint.Read(r) if err != nil { return nil, err @@ -66,18 +61,18 @@ func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount return length } -func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { if f.IsApplicationError { - b.WriteByte(0x1d) + b = append(b, applicationCloseFrameType) } else { - b.WriteByte(0x1c) + b = append(b, connectionCloseFrameType) } - quicvarint.Write(b, f.ErrorCode) + b = quicvarint.Append(b, f.ErrorCode) if !f.IsApplicationError { - quicvarint.Write(b, f.FrameType) + b = quicvarint.Append(b, f.FrameType) } - quicvarint.Write(b, uint64(len(f.ReasonPhrase))) - b.WriteString(f.ReasonPhrase) - return nil + b = quicvarint.Append(b, uint64(len(f.ReasonPhrase))) + b = append(b, []byte(f.ReasonPhrase)...) + return b, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/crypto_frame.go similarity index 86% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/crypto_frame.go index 6301c8783dd..0f005c5ba4b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/crypto_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/crypto_frame.go @@ -4,8 +4,8 @@ import ( "bytes" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A CryptoFrame is a CRYPTO frame @@ -15,10 +15,6 @@ type CryptoFrame struct { } func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - frame := &CryptoFrame{} offset, err := quicvarint.Read(r) if err != nil { @@ -42,12 +38,12 @@ func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, return frame, nil } -func (f *CryptoFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x6) - quicvarint.Write(b, uint64(f.Offset)) - quicvarint.Write(b, uint64(len(f.Data))) - b.Write(f.Data) - return nil +func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, cryptoFrameType) + b = quicvarint.Append(b, uint64(f.Offset)) + b = quicvarint.Append(b, uint64(len(f.Data))) + b = append(b, f.Data...) 
+ return b, nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/data_blocked_frame.go similarity index 53% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/data_blocked_frame.go index 459f04d12ec..0d4d1f5657f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/data_blocked_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/data_blocked_frame.go @@ -3,8 +3,8 @@ package wire import ( "bytes" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A DataBlockedFrame is a DATA_BLOCKED frame @@ -13,23 +13,16 @@ type DataBlockedFrame struct { } func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } offset, err := quicvarint.Read(r) if err != nil { return nil, err } - return &DataBlockedFrame{ - MaximumData: protocol.ByteCount(offset), - }, nil + return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil } -func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - typeByte := uint8(0x14) - b.WriteByte(typeByte) - quicvarint.Write(b, uint64(f.MaximumData)) - return nil +func (f *DataBlockedFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) { + b = append(b, dataBlockedFrameType) + return quicvarint.Append(b, uint64(f.MaximumData)), nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/datagram_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/datagram_frame.go similarity index 73% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/datagram_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/datagram_frame.go index 9d6e55cb0e0..e6c45196181 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/datagram_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/datagram_frame.go @@ -4,8 +4,8 @@ import ( "bytes" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A DatagramFrame is a DATAGRAM frame @@ -14,14 +14,9 @@ type DatagramFrame struct { Data []byte } -func parseDatagramFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DatagramFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - +func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*DatagramFrame, error) { f := &DatagramFrame{} - f.DataLenPresent = typeByte&0x1 > 0 + f.DataLenPresent = typ&0x1 > 0 var length uint64 if f.DataLenPresent { @@ -44,17 +39,17 @@ func parseDatagramFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DatagramFra return f, nil } -func (f *DatagramFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - typeByte := uint8(0x30) +func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + typ := uint8(0x30) if f.DataLenPresent { - typeByte ^= 0x1 + typ ^= 0b1 } - b.WriteByte(typeByte) + b = append(b, typ) if f.DataLenPresent { - quicvarint.Write(b, uint64(len(f.Data))) + b = quicvarint.Append(b, 
uint64(len(f.Data))) } - b.Write(f.Data) - return nil + b = append(b, f.Data...) + return b, nil } // MaxDataLen returns the maximum data length diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/extended_header.go b/vendor/github.com/quic-go/quic-go/internal/wire/extended_header.go new file mode 100644 index 00000000000..d10820d6d90 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/wire/extended_header.go @@ -0,0 +1,210 @@ +package wire + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/quicvarint" +) + +// ErrInvalidReservedBits is returned when the reserved bits are incorrect. +// When this error is returned, parsing continues, and an ExtendedHeader is returned. +// This is necessary because we need to decrypt the packet in that case, +// in order to avoid a timing side-channel. +var ErrInvalidReservedBits = errors.New("invalid reserved bits") + +// ExtendedHeader is the header of a QUIC packet. +type ExtendedHeader struct { + Header + + typeByte byte + + KeyPhase protocol.KeyPhaseBit + + PacketNumberLen protocol.PacketNumberLen + PacketNumber protocol.PacketNumber + + parsedLen protocol.ByteCount +} + +func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { + startLen := b.Len() + // read the (now unencrypted) first byte + var err error + h.typeByte, err = b.ReadByte() + if err != nil { + return false, err + } + if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil { + return false, err + } + reservedBitsValid, err := h.parseLongHeader(b, v) + if err != nil { + return false, err + } + h.parsedLen = protocol.ByteCount(startLen - b.Len()) + return reservedBitsValid, err +} + +func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { + if err := h.readPacketNumber(b); err != nil { + return false, err + } + if h.typeByte&0xc != 0 { + return false, nil + } + return true, nil +} + +func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { + h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1 + switch h.PacketNumberLen { + case protocol.PacketNumberLen1: + n, err := b.ReadByte() + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen2: + n, err := utils.BigEndian.ReadUint16(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen3: + n, err := utils.BigEndian.ReadUint24(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + case protocol.PacketNumberLen4: + n, err := utils.BigEndian.ReadUint32(b) + if err != nil { + return err + } + h.PacketNumber = protocol.PacketNumber(n) + default: + return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) + } + return nil +} + +// Append appends the Header. 
+func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) { + if h.DestConnectionID.Len() > protocol.MaxConnIDLen { + return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) + } + if h.SrcConnectionID.Len() > protocol.MaxConnIDLen { + return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len()) + } + + var packetType uint8 + if v == protocol.Version2 { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b01 + case protocol.PacketType0RTT: + packetType = 0b10 + case protocol.PacketTypeHandshake: + packetType = 0b11 + case protocol.PacketTypeRetry: + packetType = 0b00 + } + } else { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b00 + case protocol.PacketType0RTT: + packetType = 0b01 + case protocol.PacketTypeHandshake: + packetType = 0b10 + case protocol.PacketTypeRetry: + packetType = 0b11 + } + } + firstByte := 0xc0 | packetType<<4 + if h.Type != protocol.PacketTypeRetry { + // Retry packets don't have a packet number + firstByte |= uint8(h.PacketNumberLen - 1) + } + + b = append(b, firstByte) + b = append(b, make([]byte, 4)...) + binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version)) + b = append(b, uint8(h.DestConnectionID.Len())) + b = append(b, h.DestConnectionID.Bytes()...) + b = append(b, uint8(h.SrcConnectionID.Len())) + b = append(b, h.SrcConnectionID.Bytes()...) + + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeRetry: + b = append(b, h.Token...) + return b, nil + case protocol.PacketTypeInitial: + b = quicvarint.Append(b, uint64(len(h.Token))) + b = append(b, h.Token...) + } + b = quicvarint.AppendWithLen(b, uint64(h.Length), 2) + return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen) +} + +// ParsedLen returns the number of bytes that were consumed when parsing the header +func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { + return h.parsedLen +} + +// GetLength determines the length of the Header. 
+func (h *ExtendedHeader) GetLength(_ protocol.VersionNumber) protocol.ByteCount { + length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ + if h.Type == protocol.PacketTypeInitial { + length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) + } + return length +} + +// Log logs the Header +func (h *ExtendedHeader) Log(logger utils.Logger) { + var token string + if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry { + if len(h.Token) == 0 { + token = "Token: (empty), " + } else { + token = fmt.Sprintf("Token: %#x, ", h.Token) + } + if h.Type == protocol.PacketTypeRetry { + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version) + return + } + } + logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version) +} + +func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) { + switch pnLen { + case protocol.PacketNumberLen1: + b = append(b, uint8(pn)) + case protocol.PacketNumberLen2: + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, uint16(pn)) + b = append(b, buf...) + case protocol.PacketNumberLen3: + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(pn)) + b = append(b, buf[1:]...) + case protocol.PacketNumberLen4: + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(pn)) + b = append(b, buf...) + default: + return nil, fmt.Errorf("invalid packet number length: %d", pnLen) + } + return b, nil +} diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go b/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go new file mode 100644 index 00000000000..e624df94e80 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go @@ -0,0 +1,185 @@ +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/quicvarint" +) + +const ( + pingFrameType = 0x1 + ackFrameType = 0x2 + ackECNFrameType = 0x3 + resetStreamFrameType = 0x4 + stopSendingFrameType = 0x5 + cryptoFrameType = 0x6 + newTokenFrameType = 0x7 + maxDataFrameType = 0x10 + maxStreamDataFrameType = 0x11 + bidiMaxStreamsFrameType = 0x12 + uniMaxStreamsFrameType = 0x13 + dataBlockedFrameType = 0x14 + streamDataBlockedFrameType = 0x15 + bidiStreamBlockedFrameType = 0x16 + uniStreamBlockedFrameType = 0x17 + newConnectionIDFrameType = 0x18 + retireConnectionIDFrameType = 0x19 + pathChallengeFrameType = 0x1a + pathResponseFrameType = 0x1b + connectionCloseFrameType = 0x1c + applicationCloseFrameType = 0x1d + handshakeDoneFrameType = 0x1e +) + +type frameParser struct { + r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them + + ackDelayExponent uint8 + + supportsDatagrams bool +} + +var _ FrameParser = &frameParser{} + +// NewFrameParser creates a new frame parser. 
+func NewFrameParser(supportsDatagrams bool) *frameParser { + return &frameParser{ + r: *bytes.NewReader(nil), + supportsDatagrams: supportsDatagrams, + } +} + +// ParseNext parses the next frame. +// It skips PADDING frames. +func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (int, Frame, error) { + startLen := len(data) + p.r.Reset(data) + frame, err := p.parseNext(&p.r, encLevel, v) + n := startLen - p.r.Len() + p.r.Reset(nil) + return n, frame, err +} + +func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) { + for r.Len() != 0 { + typ, err := quicvarint.Read(r) + if err != nil { + return nil, &qerr.TransportError{ + ErrorCode: qerr.FrameEncodingError, + ErrorMessage: err.Error(), + } + } + if typ == 0x0 { // skip PADDING frames + continue + } + + f, err := p.parseFrame(r, typ, encLevel, v) + if err != nil { + return nil, &qerr.TransportError{ + FrameType: typ, + ErrorCode: qerr.FrameEncodingError, + ErrorMessage: err.Error(), + } + } + return f, nil + } + return nil, nil +} + +func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) { + var frame Frame + var err error + if typ&0xf8 == 0x8 { + frame, err = parseStreamFrame(r, typ, v) + } else { + switch typ { + case pingFrameType: + frame = &PingFrame{} + case ackFrameType, ackECNFrameType: + ackDelayExponent := p.ackDelayExponent + if encLevel != protocol.Encryption1RTT { + ackDelayExponent = protocol.DefaultAckDelayExponent + } + frame, err = parseAckFrame(r, typ, ackDelayExponent, v) + case resetStreamFrameType: + frame, err = parseResetStreamFrame(r, v) + case stopSendingFrameType: + frame, err = parseStopSendingFrame(r, v) + case cryptoFrameType: + frame, err = parseCryptoFrame(r, v) + case newTokenFrameType: + frame, err = parseNewTokenFrame(r, v) + case maxDataFrameType: + frame, err = parseMaxDataFrame(r, v) + case maxStreamDataFrameType: + frame, err = parseMaxStreamDataFrame(r, v) + case bidiMaxStreamsFrameType, uniMaxStreamsFrameType: + frame, err = parseMaxStreamsFrame(r, typ, v) + case dataBlockedFrameType: + frame, err = parseDataBlockedFrame(r, v) + case streamDataBlockedFrameType: + frame, err = parseStreamDataBlockedFrame(r, v) + case bidiStreamBlockedFrameType, uniStreamBlockedFrameType: + frame, err = parseStreamsBlockedFrame(r, typ, v) + case newConnectionIDFrameType: + frame, err = parseNewConnectionIDFrame(r, v) + case retireConnectionIDFrameType: + frame, err = parseRetireConnectionIDFrame(r, v) + case pathChallengeFrameType: + frame, err = parsePathChallengeFrame(r, v) + case pathResponseFrameType: + frame, err = parsePathResponseFrame(r, v) + case connectionCloseFrameType, applicationCloseFrameType: + frame, err = parseConnectionCloseFrame(r, typ, v) + case handshakeDoneFrameType: + frame = &HandshakeDoneFrame{} + case 0x30, 0x31: + if p.supportsDatagrams { + frame, err = parseDatagramFrame(r, typ, v) + break + } + fallthrough + default: + err = errors.New("unknown frame type") + } + } + if err != nil { + return nil, err + } + if !p.isAllowedAtEncLevel(frame, encLevel) { + return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel) + } + return frame, nil +} + +func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { + switch encLevel { + case protocol.EncryptionInitial, protocol.EncryptionHandshake: + switch f.(type) { + 
case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame: + return true + default: + return false + } + case protocol.Encryption0RTT: + switch f.(type) { + case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame: + return false + default: + return true + } + case protocol.Encryption1RTT: + return true + default: + panic("unknown encryption level") + } +} + +func (p *frameParser) SetAckDelayExponent(exp uint8) { + p.ackDelayExponent = exp +} diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/handshake_done_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/handshake_done_frame.go new file mode 100644 index 00000000000..29521bc9e9f --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/wire/handshake_done_frame.go @@ -0,0 +1,17 @@ +package wire + +import ( + "github.com/quic-go/quic-go/internal/protocol" +) + +// A HandshakeDoneFrame is a HANDSHAKE_DONE frame +type HandshakeDoneFrame struct{} + +func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + return append(b, handshakeDoneFrameType), nil +} + +// Length of a written frame +func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go b/vendor/github.com/quic-go/quic-go/internal/wire/header.go similarity index 60% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go rename to vendor/github.com/quic-go/quic-go/internal/wire/header.go index f6a31ee0ec4..6e8d4f9f664 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/header.go @@ -7,9 +7,9 @@ import ( "fmt" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/quicvarint" ) // ParseConnectionID parses the destination connection ID of a packet. @@ -17,23 +17,77 @@ import ( // That means that the connection ID must not be used after the packet buffer is released. func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { if len(data) == 0 { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } - isLongHeader := data[0]&0x80 > 0 - if !isLongHeader { + if !IsLongHeaderPacket(data[0]) { if len(data) < shortHeaderConnIDLen+1 { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } - return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil + return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil } if len(data) < 6 { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } destConnIDLen := int(data[5]) + if destConnIDLen > protocol.MaxConnIDLen { + return protocol.ConnectionID{}, protocol.ErrInvalidConnectionIDLen + } if len(data) < 6+destConnIDLen { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF + } + return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil +} + +// ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet, +// using only the version-independent packet format as described in Section 5.1 of RFC 8999: +// https://datatracker.ietf.org/doc/html/rfc8999#section-5.1. +// This function should only be called on Long Header packets for which we don't support the version. 
+func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) { + r := bytes.NewReader(data) + remaining := r.Len() + src, dest, err := parseArbitraryLenConnectionIDs(r) + return remaining - r.Len(), src, dest, err +} + +func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) { + r.Seek(5, io.SeekStart) // skip first byte and version field + destConnIDLen, err := r.ReadByte() + if err != nil { + return nil, nil, err + } + destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen) + if _, err := io.ReadFull(r, destConnID); err != nil { + if err == io.ErrUnexpectedEOF { + err = io.EOF + } + return nil, nil, err + } + srcConnIDLen, err := r.ReadByte() + if err != nil { + return nil, nil, err + } + srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen) + if _, err := io.ReadFull(r, srcConnID); err != nil { + if err == io.ErrUnexpectedEOF { + err = io.EOF + } + return nil, nil, err + } + return destConnID, srcConnID, nil +} + +// IsLongHeaderPacket says if this is a Long Header packet +func IsLongHeaderPacket(firstByte byte) bool { + return firstByte&0x80 > 0 +} + +// ParseVersion parses the QUIC version. +// It should only be called for Long Header packets (Short Header packets don't contain a version number). +func ParseVersion(data []byte) (protocol.VersionNumber, error) { + if len(data) < 5 { + return 0, io.EOF } - return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil + return protocol.VersionNumber(binary.BigEndian.Uint32(data[1:5])), nil } // IsVersionNegotiationPacket says if this is a version negotiation packet @@ -41,7 +95,7 @@ func IsVersionNegotiationPacket(b []byte) bool { if len(b) < 5 { return false } - return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 + return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 } // Is0RTTPacket says if this is a 0-RTT packet. @@ -50,26 +104,27 @@ func Is0RTTPacket(b []byte) bool { if len(b) < 5 { return false } - if b[0]&0x80 == 0 { + if !IsLongHeaderPacket(b[0]) { return false } version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5])) - if !protocol.IsSupportedVersion(protocol.SupportedVersions, version) { - return false - } - if version == protocol.Version2 { + //nolint:exhaustive // We only need to test QUIC versions that we support. + switch version { + case protocol.Version1, protocol.VersionDraft29: + return b[0]>>4&0b11 == 0b01 + case protocol.Version2: return b[0]>>4&0b11 == 0b10 + default: + return false } - return b[0]>>4&0b11 == 0b01 } var ErrUnsupportedVersion = errors.New("unsupported version") // The Header is the version independent part of the header type Header struct { - IsLongHeader bool - typeByte byte - Type protocol.PacketType + typeByte byte + Type protocol.PacketType Version protocol.VersionNumber SrcConnectionID protocol.ConnectionID @@ -86,24 +141,22 @@ type Header struct { // If the packet has a long header, the packet is cut according to the length field. // If we understand the version, the packet is header up unto the packet number. // Otherwise, only the invariant part of the header is parsed. 
-func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) { - hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen) +func ParsePacket(data []byte) (*Header, []byte, []byte, error) { + if len(data) == 0 || !IsLongHeaderPacket(data[0]) { + return nil, nil, nil, errors.New("not a long header packet") + } + hdr, err := parseHeader(bytes.NewReader(data)) if err != nil { if err == ErrUnsupportedVersion { return hdr, nil, nil, ErrUnsupportedVersion } return nil, nil, nil, err } - var rest []byte - if hdr.IsLongHeader { - if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { - return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) - } - packetLen := int(hdr.ParsedLen() + hdr.Length) - rest = data[packetLen:] - data = data[:packetLen] + if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length { + return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) } - return hdr, data, rest, nil + packetLen := int(hdr.ParsedLen() + hdr.Length) + return hdr, data[:packetLen], data[packetLen:], nil } // ParseHeader parses the header. @@ -111,43 +164,17 @@ func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* pack // For long header packets: // * if we understand the version: up to the packet number // * if not, only the invariant part of the header -func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { +func parseHeader(b *bytes.Reader) (*Header, error) { startLen := b.Len() - h, err := parseHeaderImpl(b, shortHeaderConnIDLen) - if err != nil { - return h, err - } - h.parsedLen = protocol.ByteCount(startLen - b.Len()) - return h, err -} - -func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) { typeByte, err := b.ReadByte() if err != nil { return nil, err } - h := &Header{ - typeByte: typeByte, - IsLongHeader: typeByte&0x80 > 0, - } - - if !h.IsLongHeader { - if h.typeByte&0x40 == 0 { - return nil, errors.New("not a QUIC packet") - } - if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil { - return nil, err - } - return h, nil - } - return h, h.parseLongHeader(b) -} - -func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error { - var err error - h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen) - return err + h := &Header{typeByte: typeByte} + err = h.parseLongHeader(b) + h.parsedLen = protocol.ByteCount(startLen - b.Len()) + return h, err } func (h *Header) parseLongHeader(b *bytes.Reader) error { @@ -267,8 +294,5 @@ func (h *Header) toExtendedHeader() *ExtendedHeader { // PacketType is the type of the packet, for logging purposes func (h *Header) PacketType() string { - if h.IsLongHeader { - return h.Type.String() - } - return "1-RTT" + return h.Type.String() } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/interface.go b/vendor/github.com/quic-go/quic-go/internal/wire/interface.go similarity index 53% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/interface.go rename to vendor/github.com/quic-go/quic-go/internal/wire/interface.go index 99fdc80fb24..7e0f9a03e0f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/interface.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/interface.go @@ -1,19 +1,17 @@ package wire import ( - "bytes" 
- - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) // A Frame in QUIC type Frame interface { - Write(b *bytes.Buffer, version protocol.VersionNumber) error + Append(b []byte, version protocol.VersionNumber) ([]byte, error) Length(version protocol.VersionNumber) protocol.ByteCount } // A FrameParser parses QUIC frames, one by one. type FrameParser interface { - ParseNext(*bytes.Reader, protocol.EncryptionLevel) (Frame, error) + ParseNext([]byte, protocol.EncryptionLevel, protocol.VersionNumber) (int, Frame, error) SetAckDelayExponent(uint8) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go b/vendor/github.com/quic-go/quic-go/internal/wire/log.go similarity index 96% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go rename to vendor/github.com/quic-go/quic-go/internal/wire/log.go index 30cf94243f7..ec7d45d861c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/log.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/log.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) // LogFrame logs a frame, either sent or received diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/max_data_frame.go similarity index 55% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/max_data_frame.go index a9a092482b1..e61b0f9f3b0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_data_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/max_data_frame.go @@ -3,8 +3,8 @@ package wire import ( "bytes" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A MaxDataFrame carries flow control information for the connection @@ -14,10 +14,6 @@ type MaxDataFrame struct { // parseMaxDataFrame parses a MAX_DATA frame func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - frame := &MaxDataFrame{} byteOffset, err := quicvarint.Read(r) if err != nil { @@ -27,14 +23,13 @@ func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame return frame, nil } -// Write writes a MAX_STREAM_DATA frame -func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x10) - quicvarint.Write(b, uint64(f.MaximumData)) - return nil +func (f *MaxDataFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, maxDataFrameType) + b = quicvarint.Append(b, uint64(f.MaximumData)) + return b, nil } // Length of a written frame -func (f *MaxDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *MaxDataFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.MaximumData)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/max_stream_data_frame.go similarity index 67% rename from 
vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/max_stream_data_frame.go index 728ecbe8b92..fe3d1e3f70b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_stream_data_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/max_stream_data_frame.go @@ -3,8 +3,8 @@ package wire import ( "bytes" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A MaxStreamDataFrame is a MAX_STREAM_DATA frame @@ -14,10 +14,6 @@ type MaxStreamDataFrame struct { } func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - sid, err := quicvarint.Read(r) if err != nil { return nil, err @@ -33,11 +29,11 @@ func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStr }, nil } -func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x11) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.MaximumStreamData)) - return nil +func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) { + b = append(b, maxStreamDataFrameType) + b = quicvarint.Append(b, uint64(f.StreamID)) + b = quicvarint.Append(b, uint64(f.MaximumStreamData)) + return b, nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_streams_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/max_streams_frame.go similarity index 61% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/max_streams_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/max_streams_frame.go index 73d7e13eaae..bd278c02a87 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/max_streams_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/max_streams_frame.go @@ -4,8 +4,8 @@ import ( "bytes" "fmt" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A MaxStreamsFrame is a MAX_STREAMS frame @@ -14,17 +14,12 @@ type MaxStreamsFrame struct { MaxStreamNum protocol.StreamNum } -func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamsFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - +func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*MaxStreamsFrame, error) { f := &MaxStreamsFrame{} - switch typeByte { - case 0x12: + switch typ { + case bidiMaxStreamsFrameType: f.Type = protocol.StreamTypeBidi - case 0x13: + case uniMaxStreamsFrameType: f.Type = protocol.StreamTypeUni } streamID, err := quicvarint.Read(r) @@ -38,15 +33,15 @@ func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStream return f, nil } -func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { +func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: - b.WriteByte(0x12) + b = append(b, bidiMaxStreamsFrameType) case protocol.StreamTypeUni: - b.WriteByte(0x13) + b = append(b, uniMaxStreamsFrameType) } - quicvarint.Write(b, uint64(f.MaxStreamNum)) - return nil + b 
= quicvarint.Append(b, uint64(f.MaxStreamNum)) + return b, nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_connection_id_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/new_connection_id_frame.go similarity index 73% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/new_connection_id_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/new_connection_id_frame.go index 1a017ba991d..83102d5d144 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_connection_id_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/new_connection_id_frame.go @@ -5,8 +5,8 @@ import ( "fmt" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A NewConnectionIDFrame is a NEW_CONNECTION_ID frame @@ -18,10 +18,6 @@ type NewConnectionIDFrame struct { } func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - seq, err := quicvarint.Read(r) if err != nil { return nil, err @@ -38,9 +34,6 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC if err != nil { return nil, err } - if connIDLen > protocol.MaxConnIDLen { - return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) - } connID, err := protocol.ReadConnectionID(r, int(connIDLen)) if err != nil { return nil, err @@ -60,18 +53,18 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC return frame, nil } -func (f *NewConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x18) - quicvarint.Write(b, f.SequenceNumber) - quicvarint.Write(b, f.RetirePriorTo) +func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, newConnectionIDFrameType) + b = quicvarint.Append(b, f.SequenceNumber) + b = quicvarint.Append(b, f.RetirePriorTo) connIDLen := f.ConnectionID.Len() if connIDLen > protocol.MaxConnIDLen { - return fmt.Errorf("invalid connection ID length: %d", connIDLen) + return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) } - b.WriteByte(uint8(connIDLen)) - b.Write(f.ConnectionID.Bytes()) - b.Write(f.StatelessResetToken[:]) - return nil + b = append(b, uint8(connIDLen)) + b = append(b, f.ConnectionID.Bytes()...) + b = append(b, f.StatelessResetToken[:]...) 
+ return b, nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_token_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/new_token_frame.go similarity index 69% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/new_token_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/new_token_frame.go index 3d5d5c3a10f..c3fa178c17f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/new_token_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/new_token_frame.go @@ -5,8 +5,8 @@ import ( "errors" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A NewTokenFrame is a NEW_TOKEN frame @@ -15,9 +15,6 @@ type NewTokenFrame struct { } func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } tokenLen, err := quicvarint.Read(r) if err != nil { return nil, err @@ -35,11 +32,11 @@ func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFra return &NewTokenFrame{Token: token}, nil } -func (f *NewTokenFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x7) - quicvarint.Write(b, uint64(len(f.Token))) - b.Write(f.Token) - return nil +func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, newTokenFrameType) + b = quicvarint.Append(b, uint64(len(f.Token))) + b = append(b, f.Token...) + return b, nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/path_challenge_frame.go similarity index 69% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/path_challenge_frame.go index 5ec8217723c..ad024330aca 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_challenge_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/path_challenge_frame.go @@ -4,7 +4,7 @@ import ( "bytes" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) // A PathChallengeFrame is a PATH_CHALLENGE frame @@ -13,9 +13,6 @@ type PathChallengeFrame struct { } func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } frame := &PathChallengeFrame{} if _, err := io.ReadFull(r, frame.Data[:]); err != nil { if err == io.ErrUnexpectedEOF { @@ -26,10 +23,10 @@ func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathCh return frame, nil } -func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x1a) - b.Write(f.Data[:]) - return nil +func (f *PathChallengeFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, pathChallengeFrameType) + b = append(b, f.Data[:]...) 
+ return b, nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/path_response_frame.go similarity index 68% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/path_response_frame.go index 262819f8989..76e65104963 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/path_response_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/path_response_frame.go @@ -4,7 +4,7 @@ import ( "bytes" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) // A PathResponseFrame is a PATH_RESPONSE frame @@ -13,9 +13,6 @@ type PathResponseFrame struct { } func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } frame := &PathResponseFrame{} if _, err := io.ReadFull(r, frame.Data[:]); err != nil { if err == io.ErrUnexpectedEOF { @@ -26,10 +23,10 @@ func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathRes return frame, nil } -func (f *PathResponseFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x1b) - b.Write(f.Data[:]) - return nil +func (f *PathResponseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, pathResponseFrameType) + b = append(b, f.Data[:]...) + return b, nil } // Length of a written frame diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/ping_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/ping_frame.go new file mode 100644 index 00000000000..dd24edc0c37 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/wire/ping_frame.go @@ -0,0 +1,17 @@ +package wire + +import ( + "github.com/quic-go/quic-go/internal/protocol" +) + +// A PingFrame is a PING frame +type PingFrame struct{} + +func (f *PingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + return append(b, pingFrameType), nil +} + +// Length of a written frame +func (f *PingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { + return 1 +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/pool.go b/vendor/github.com/quic-go/quic-go/internal/wire/pool.go similarity index 90% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/pool.go rename to vendor/github.com/quic-go/quic-go/internal/wire/pool.go index c057395e768..18ab437937d 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/pool.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/pool.go @@ -3,7 +3,7 @@ package wire import ( "sync" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) var pool sync.Pool diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/reset_stream_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/reset_stream_frame.go similarity index 68% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/reset_stream_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/reset_stream_frame.go index 69bbc2b9d89..cd94c9408fc 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/reset_stream_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/reset_stream_frame.go @@ -3,9 +3,9 @@ package wire import ( "bytes" - 
"github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/quicvarint" ) // A ResetStreamFrame is a RESET_STREAM frame in QUIC @@ -16,10 +16,6 @@ type ResetStreamFrame struct { } func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) { - if _, err := r.ReadByte(); err != nil { // read the TypeByte - return nil, err - } - var streamID protocol.StreamID var byteOffset protocol.ByteCount sid, err := quicvarint.Read(r) @@ -44,12 +40,12 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr }, nil } -func (f *ResetStreamFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x4) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.ErrorCode)) - quicvarint.Write(b, uint64(f.FinalSize)) - return nil +func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, resetStreamFrameType) + b = quicvarint.Append(b, uint64(f.StreamID)) + b = quicvarint.Append(b, uint64(f.ErrorCode)) + b = quicvarint.Append(b, uint64(f.FinalSize)) + return b, nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/retire_connection_id_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/retire_connection_id_frame.go similarity index 63% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/retire_connection_id_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/retire_connection_id_frame.go index 0f7e58c8775..8e9a41d89b5 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/retire_connection_id_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/retire_connection_id_frame.go @@ -3,8 +3,8 @@ package wire import ( "bytes" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame @@ -13,10 +13,6 @@ type RetireConnectionIDFrame struct { } func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - seq, err := quicvarint.Read(r) if err != nil { return nil, err @@ -24,10 +20,10 @@ func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*R return &RetireConnectionIDFrame{SequenceNumber: seq}, nil } -func (f *RetireConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x19) - quicvarint.Write(b, f.SequenceNumber) - return nil +func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, retireConnectionIDFrameType) + b = quicvarint.Append(b, f.SequenceNumber) + return b, nil } // Length of a written frame diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/short_header.go b/vendor/github.com/quic-go/quic-go/internal/wire/short_header.go new file mode 100644 index 00000000000..69aa8341185 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/wire/short_header.go @@ -0,0 +1,73 @@ +package wire + +import ( + "errors" + "fmt" + "io" + + "github.com/quic-go/quic-go/internal/protocol" + 
"github.com/quic-go/quic-go/internal/utils" +) + +// ParseShortHeader parses a short header packet. +// It must be called after header protection was removed. +// Otherwise, the check for the reserved bits will (most likely) fail. +func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.PacketNumber, _ protocol.PacketNumberLen, _ protocol.KeyPhaseBit, _ error) { + if len(data) == 0 { + return 0, 0, 0, 0, io.EOF + } + if data[0]&0x80 > 0 { + return 0, 0, 0, 0, errors.New("not a short header packet") + } + if data[0]&0x40 == 0 { + return 0, 0, 0, 0, errors.New("not a QUIC packet") + } + pnLen := protocol.PacketNumberLen(data[0]&0b11) + 1 + if len(data) < 1+int(pnLen)+connIDLen { + return 0, 0, 0, 0, io.EOF + } + + pos := 1 + connIDLen + var pn protocol.PacketNumber + switch pnLen { + case protocol.PacketNumberLen1: + pn = protocol.PacketNumber(data[pos]) + case protocol.PacketNumberLen2: + pn = protocol.PacketNumber(utils.BigEndian.Uint16(data[pos : pos+2])) + case protocol.PacketNumberLen3: + pn = protocol.PacketNumber(utils.BigEndian.Uint24(data[pos : pos+3])) + case protocol.PacketNumberLen4: + pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[pos : pos+4])) + default: + return 0, 0, 0, 0, fmt.Errorf("invalid packet number length: %d", pnLen) + } + kp := protocol.KeyPhaseZero + if data[0]&0b100 > 0 { + kp = protocol.KeyPhaseOne + } + + var err error + if data[0]&0x18 != 0 { + err = ErrInvalidReservedBits + } + return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err +} + +// AppendShortHeader writes a short header. +func AppendShortHeader(b []byte, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) ([]byte, error) { + typeByte := 0x40 | uint8(pnLen-1) + if kp == protocol.KeyPhaseOne { + typeByte |= byte(1 << 2) + } + b = append(b, typeByte) + b = append(b, connID.Bytes()...) 
+ return appendPacketNumber(b, pn, pnLen) +} + +func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount { + return 1 + protocol.ByteCount(dest.Len()) + protocol.ByteCount(pnLen) +} + +func LogShortHeader(logger utils.Logger, dest protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) { + logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", dest, pn, pnLen, kp) +} diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/stop_sending_frame.go similarity index 66% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/stop_sending_frame.go index fb1160c1b47..d7b8b240290 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stop_sending_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/stop_sending_frame.go @@ -3,9 +3,9 @@ package wire import ( "bytes" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/quicvarint" ) // A StopSendingFrame is a STOP_SENDING frame @@ -16,10 +16,6 @@ type StopSendingFrame struct { // parseStopSendingFrame parses a STOP_SENDING frame func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - streamID, err := quicvarint.Read(r) if err != nil { return nil, err @@ -40,9 +36,9 @@ func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) } -func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { - b.WriteByte(0x5) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.ErrorCode)) - return nil +func (f *StopSendingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, stopSendingFrameType) + b = quicvarint.Append(b, uint64(f.StreamID)) + b = quicvarint.Append(b, uint64(f.ErrorCode)) + return b, nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_data_blocked_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/stream_data_blocked_frame.go similarity index 68% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_data_blocked_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/stream_data_blocked_frame.go index dc6d631a573..d42e59a2d58 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_data_blocked_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/stream_data_blocked_frame.go @@ -3,8 +3,8 @@ package wire import ( "bytes" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame @@ -14,10 +14,6 @@ type StreamDataBlockedFrame struct { } func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) { - if _, err := r.ReadByte(); err != nil { - return nil, err - } - sid, err := 
quicvarint.Read(r) if err != nil { return nil, err @@ -33,11 +29,11 @@ func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*St }, nil } -func (f *StreamDataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { - b.WriteByte(0x15) - quicvarint.Write(b, uint64(f.StreamID)) - quicvarint.Write(b, uint64(f.MaximumStreamData)) - return nil +func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { + b = append(b, 0x15) + b = quicvarint.Append(b, uint64(f.StreamID)) + b = quicvarint.Append(b, uint64(f.MaximumStreamData)) + return b, nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go similarity index 83% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go index 66340d1697b..d22e1c0526e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/stream_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/stream_frame.go @@ -5,8 +5,8 @@ import ( "errors" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A StreamFrame of QUIC @@ -20,15 +20,10 @@ type StreamFrame struct { fromPool bool } -func parseStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - - hasOffset := typeByte&0x4 > 0 - fin := typeByte&0x1 > 0 - hasDataLen := typeByte&0x2 > 0 +func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamFrame, error) { + hasOffset := typ&0b100 > 0 + fin := typ&0b1 > 0 + hasDataLen := typ&0b10 > 0 streamID, err := quicvarint.Read(r) if err != nil { @@ -84,32 +79,32 @@ func parseStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, } // Write writes a STREAM frame -func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error { +func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { if len(f.Data) == 0 && !f.Fin { - return errors.New("StreamFrame: attempting to write empty frame without FIN") + return nil, errors.New("StreamFrame: attempting to write empty frame without FIN") } - typeByte := byte(0x8) + typ := byte(0x8) if f.Fin { - typeByte ^= 0x1 + typ ^= 0b1 } hasOffset := f.Offset != 0 if f.DataLenPresent { - typeByte ^= 0x2 + typ ^= 0b10 } if hasOffset { - typeByte ^= 0x4 + typ ^= 0b100 } - b.WriteByte(typeByte) - quicvarint.Write(b, uint64(f.StreamID)) + b = append(b, typ) + b = quicvarint.Append(b, uint64(f.StreamID)) if hasOffset { - quicvarint.Write(b, uint64(f.Offset)) + b = quicvarint.Append(b, uint64(f.Offset)) } if f.DataLenPresent { - quicvarint.Write(b, uint64(f.DataLen())) + b = quicvarint.Append(b, uint64(f.DataLen())) } - b.Write(f.Data) - return nil + b = append(b, f.Data...) 
+ return b, nil } // Length returns the total length of the STREAM frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/streams_blocked_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/streams_blocked_frame.go similarity index 60% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/streams_blocked_frame.go rename to vendor/github.com/quic-go/quic-go/internal/wire/streams_blocked_frame.go index f4066071fce..4a5951c625b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/streams_blocked_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/streams_blocked_frame.go @@ -4,8 +4,8 @@ import ( "bytes" "fmt" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/quicvarint" ) // A StreamsBlockedFrame is a STREAMS_BLOCKED frame @@ -14,17 +14,12 @@ type StreamsBlockedFrame struct { StreamLimit protocol.StreamNum } -func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) { - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } - +func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) { f := &StreamsBlockedFrame{} - switch typeByte { - case 0x16: + switch typ { + case bidiStreamBlockedFrameType: f.Type = protocol.StreamTypeBidi - case 0x17: + case uniStreamBlockedFrameType: f.Type = protocol.StreamTypeUni } streamLimit, err := quicvarint.Read(r) @@ -38,15 +33,15 @@ func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*Strea return f, nil } -func (f *StreamsBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error { +func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: - b.WriteByte(0x16) + b = append(b, bidiStreamBlockedFrameType) case protocol.StreamTypeUni: - b.WriteByte(0x17) + b = append(b, uniStreamBlockedFrameType) } - quicvarint.Write(b, uint64(f.StreamLimit)) - return nil + b = quicvarint.Append(b, uint64(f.StreamLimit)) + return b, nil } // Length of a written frame diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/transport_parameters.go b/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go similarity index 74% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/transport_parameters.go rename to vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go index e1f83cd6564..8eb4cf46abd 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/transport_parameters.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go @@ -2,24 +2,37 @@ package wire import ( "bytes" + "encoding/binary" "errors" "fmt" "io" "math/rand" "net" "sort" + "sync" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/quicvarint" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/quicvarint" ) +// AdditionalTransportParametersClient are additional transport parameters that will be added +// to the client's transport parameters. 
+// This is not intended for production use, but _only_ to increase the size of the ClientHello beyond +// the usual size of less than 1 MTU. +var AdditionalTransportParametersClient map[uint64][]byte + const transportParameterMarshalingVersion = 1 +var ( + randomMutex sync.Mutex + random rand.Rand +) + func init() { - rand.Seed(time.Now().UTC().UnixNano()) + random = *rand.New(rand.NewSource(time.Now().UnixNano())) } type transportParameterID uint64 @@ -285,7 +298,7 @@ func (p *TransportParameters) readNumericTransportParameter( return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) } case maxIdleTimeoutParameterID: - p.MaxIdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) + p.MaxIdleTimeout = utils.Max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) case maxUDPPayloadSizeParameterID: if val < 1200 { return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val) @@ -302,6 +315,9 @@ func (p *TransportParameters) readNumericTransportParameter( } p.MaxAckDelay = time.Duration(val) * time.Millisecond case activeConnectionIDLimitParameterID: + if val < 2 { + return fmt.Errorf("invalid value for active_connection_id_limit: %d (minimum 2)", val) + } p.ActiveConnectionIDLimit = val case maxDatagramFrameSizeParameterID: p.MaxDatagramFrameSize = protocol.ByteCount(val) @@ -313,94 +329,109 @@ func (p *TransportParameters) readNumericTransportParameter( // Marshal the transport parameters func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { - b := &bytes.Buffer{} + // Typical Transport Parameters consume around 110 bytes, depending on the exact values, + // especially the lengths of the Connection IDs. + // Allocate 256 bytes, so we won't have to grow the slice in any case. 
+ b := make([]byte, 0, 256) // add a greased value - quicvarint.Write(b, uint64(27+31*rand.Intn(100))) - length := rand.Intn(16) - randomData := make([]byte, length) - rand.Read(randomData) - quicvarint.Write(b, uint64(length)) - b.Write(randomData) + b = quicvarint.Append(b, uint64(27+31*rand.Intn(100))) + randomMutex.Lock() + length := random.Intn(16) + b = quicvarint.Append(b, uint64(length)) + b = b[:len(b)+length] + random.Read(b[len(b)-length:]) + randomMutex.Unlock() // initial_max_stream_data_bidi_local - p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) // initial_max_stream_data_bidi_remote - p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) // initial_max_stream_data_uni - p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) // initial_max_data - p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) // initial_max_bidi_streams - p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) // initial_max_uni_streams - p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) // idle_timeout - p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) + b = p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) // max_packet_size - p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) + b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) // max_ack_delay // Only send it if is different from the default value. if p.MaxAckDelay != protocol.DefaultMaxAckDelay { - p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) + b = p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) } // ack_delay_exponent // Only send it if is different from the default value. if p.AckDelayExponent != protocol.DefaultAckDelayExponent { - p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) + b = p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) } // disable_active_migration if p.DisableActiveMigration { - quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) - quicvarint.Write(b, 0) + b = quicvarint.Append(b, uint64(disableActiveMigrationParameterID)) + b = quicvarint.Append(b, 0) } if pers == protocol.PerspectiveServer { // stateless_reset_token if p.StatelessResetToken != nil { - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, 16) - b.Write(p.StatelessResetToken[:]) + b = quicvarint.Append(b, uint64(statelessResetTokenParameterID)) + b = quicvarint.Append(b, 16) + b = append(b, p.StatelessResetToken[:]...) 
} // original_destination_connection_id - quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.OriginalDestinationConnectionID.Len())) - b.Write(p.OriginalDestinationConnectionID.Bytes()) + b = quicvarint.Append(b, uint64(originalDestinationConnectionIDParameterID)) + b = quicvarint.Append(b, uint64(p.OriginalDestinationConnectionID.Len())) + b = append(b, p.OriginalDestinationConnectionID.Bytes()...) // preferred_address if p.PreferredAddress != nil { - quicvarint.Write(b, uint64(preferredAddressParameterID)) - quicvarint.Write(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) + b = quicvarint.Append(b, uint64(preferredAddressParameterID)) + b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) ipv4 := p.PreferredAddress.IPv4 - b.Write(ipv4[len(ipv4)-4:]) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) - b.Write(p.PreferredAddress.IPv6) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) - b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) - b.Write(p.PreferredAddress.ConnectionID.Bytes()) - b.Write(p.PreferredAddress.StatelessResetToken[:]) + b = append(b, ipv4[len(ipv4)-4:]...) + b = append(b, []byte{0, 0}...) + binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port) + b = append(b, p.PreferredAddress.IPv6...) + b = append(b, []byte{0, 0}...) + binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port) + b = append(b, uint8(p.PreferredAddress.ConnectionID.Len())) + b = append(b, p.PreferredAddress.ConnectionID.Bytes()...) + b = append(b, p.PreferredAddress.StatelessResetToken[:]...) } } // active_connection_id_limit - p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) + b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) // initial_source_connection_id - quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.InitialSourceConnectionID.Len())) - b.Write(p.InitialSourceConnectionID.Bytes()) + b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID)) + b = quicvarint.Append(b, uint64(p.InitialSourceConnectionID.Len())) + b = append(b, p.InitialSourceConnectionID.Bytes()...) // retry_source_connection_id if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil { - quicvarint.Write(b, uint64(retrySourceConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.RetrySourceConnectionID.Len())) - b.Write(p.RetrySourceConnectionID.Bytes()) + b = quicvarint.Append(b, uint64(retrySourceConnectionIDParameterID)) + b = quicvarint.Append(b, uint64(p.RetrySourceConnectionID.Len())) + b = append(b, p.RetrySourceConnectionID.Bytes()...) } if p.MaxDatagramFrameSize != protocol.InvalidByteCount { - p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) + b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) } - return b.Bytes() + + if pers == protocol.PerspectiveClient && len(AdditionalTransportParametersClient) > 0 { + for k, v := range AdditionalTransportParametersClient { + b = quicvarint.Append(b, k) + b = quicvarint.Append(b, uint64(len(v))) + b = append(b, v...) 
+ } + } + + return b } -func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) { - quicvarint.Write(b, uint64(id)) - quicvarint.Write(b, uint64(quicvarint.Len(val))) - quicvarint.Write(b, val) +func (p *TransportParameters) marshalVarintParam(b []byte, id transportParameterID, val uint64) []byte { + b = quicvarint.Append(b, uint64(id)) + b = quicvarint.Append(b, uint64(quicvarint.Len(val))) + return quicvarint.Append(b, val) } // MarshalForSessionTicket marshals the transport parameters we save in the session ticket. @@ -411,23 +442,23 @@ func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportPa // if the transport parameters changed. // Since the session ticket is encrypted, the serialization format is defined by the server. // For convenience, we use the same format that we also use for sending the transport parameters. -func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) { - quicvarint.Write(b, transportParameterMarshalingVersion) +func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte { + b = quicvarint.Append(b, transportParameterMarshalingVersion) // initial_max_stream_data_bidi_local - p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) // initial_max_stream_data_bidi_remote - p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) // initial_max_stream_data_uni - p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) // initial_max_data - p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) // initial_max_bidi_streams - p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) // initial_max_uni_streams - p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) // active_connection_id_limit - p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) + return p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) } // UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. 
diff --git a/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go b/vendor/github.com/quic-go/quic-go/internal/wire/version_negotiation.go similarity index 50% rename from vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go rename to vendor/github.com/quic-go/quic-go/internal/wire/version_negotiation.go index 196853e0fc4..3dc621135b3 100644 --- a/vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/version_negotiation.go @@ -3,39 +3,38 @@ package wire import ( "bytes" "crypto/rand" + "encoding/binary" "errors" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) // ParseVersionNegotiationPacket parses a Version Negotiation packet. -func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []protocol.VersionNumber, error) { - hdr, err := parseHeader(b, 0) +func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber, _ error) { + n, dest, src, err := ParseArbitraryLenConnectionIDs(b) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - if b.Len() == 0 { + b = b[n:] + if len(b) == 0 { //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has empty version list") + return nil, nil, nil, errors.New("Version Negotiation packet has empty version list") } - if b.Len()%4 != 0 { + if len(b)%4 != 0 { //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") + return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") } - versions := make([]protocol.VersionNumber, b.Len()/4) - for i := 0; b.Len() > 0; i++ { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, nil, err - } - versions[i] = protocol.VersionNumber(v) + versions := make([]protocol.VersionNumber, len(b)/4) + for i := 0; len(b) > 0; i++ { + versions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(b[:4])) + b = b[4:] } - return hdr, versions, nil + return dest, src, versions, nil } // ComposeVersionNegotiation composes a Version Negotiation -func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.VersionNumber) []byte { greasedVersions := protocol.GetGreasedVersions(versions) expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) @@ -44,9 +43,9 @@ func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, vers buf.WriteByte(r[0] | 0x80) utils.BigEndian.WriteUint32(buf, 0) // version 0 buf.WriteByte(uint8(destConnID.Len())) - buf.Write(destConnID) + buf.Write(destConnID.Bytes()) buf.WriteByte(uint8(srcConnID.Len())) - buf.Write(srcConnID) + buf.Write(srcConnID.Bytes()) for _, v := range greasedVersions { utils.BigEndian.WriteUint32(buf, uint32(v)) } diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/frame.go b/vendor/github.com/quic-go/quic-go/logging/frame.go similarity index 97% rename from 
vendor/github.com/lucas-clemente/quic-go/logging/frame.go rename to vendor/github.com/quic-go/quic-go/logging/frame.go index 75705092e20..9a055db3595 100644 --- a/vendor/github.com/lucas-clemente/quic-go/logging/frame.go +++ b/vendor/github.com/quic-go/quic-go/logging/frame.go @@ -1,6 +1,6 @@ package logging -import "github.com/lucas-clemente/quic-go/internal/wire" +import "github.com/quic-go/quic-go/internal/wire" // A Frame is a QUIC frame type Frame interface{} diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/interface.go b/vendor/github.com/quic-go/quic-go/logging/interface.go similarity index 77% rename from vendor/github.com/lucas-clemente/quic-go/logging/interface.go rename to vendor/github.com/quic-go/quic-go/logging/interface.go index f71d68f7d60..2ce8582ecb9 100644 --- a/vendor/github.com/lucas-clemente/quic-go/logging/interface.go +++ b/vendor/github.com/quic-go/quic-go/logging/interface.go @@ -3,15 +3,13 @@ package logging import ( - "context" "net" "time" - "github.com/lucas-clemente/quic-go/internal/utils" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" ) type ( @@ -19,6 +17,8 @@ type ( ByteCount = protocol.ByteCount // A ConnectionID is a QUIC Connection ID. ConnectionID = protocol.ConnectionID + // An ArbitraryLenConnectionID is a QUIC Connection ID that can be up to 255 bytes long. + ArbitraryLenConnectionID = protocol.ArbitraryLenConnectionID // The EncryptionLevel is the encryption level of a packet. EncryptionLevel = protocol.EncryptionLevel // The KeyPhase is the key phase of the 1-RTT keys. @@ -42,7 +42,7 @@ type ( // The Header is the QUIC packet header, before removing header protection. Header = wire.Header - // The ExtendedHeader is the QUIC packet header, after removing header protection. + // The ExtendedHeader is the QUIC Long Header packet header, after removing header protection. ExtendedHeader = wire.ExtendedHeader // The TransportParameters are QUIC transport parameters. TransportParameters = wire.TransportParameters @@ -90,15 +90,18 @@ const ( StreamTypeBidi = protocol.StreamTypeBidi ) +// The ShortHeader is the QUIC Short Header packet header, after removing header protection. +type ShortHeader struct { + DestConnectionID ConnectionID + PacketNumber PacketNumber + PacketNumberLen protocol.PacketNumberLen + KeyPhase KeyPhaseBit +} + // A Tracer traces events. type Tracer interface { - // TracerForConnection requests a new tracer for a connection. - // The ODCID is the original destination connection ID: - // The destination connection ID that the client used on the first Initial packet it sent on this connection. - // If nil is returned, tracing will be disabled for this connection. 
- TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer - SentPacket(net.Addr, *Header, ByteCount, []Frame) + SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) } @@ -110,11 +113,13 @@ type ConnectionTracer interface { SentTransportParameters(*TransportParameters) ReceivedTransportParameters(*TransportParameters) RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT - SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) - ReceivedVersionNegotiationPacket(*Header, []VersionNumber) + SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) + SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) + ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) ReceivedRetry(*Header) - ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) - BufferedPacket(PacketType) + ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) + ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) + BufferedPacket(PacketType, ByteCount) DroppedPacket(PacketType, ByteCount, PacketDropReason) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) AcknowledgedPacket(EncryptionLevel, PacketNumber) diff --git a/vendor/github.com/quic-go/quic-go/logging/mockgen.go b/vendor/github.com/quic-go/quic-go/logging/mockgen.go new file mode 100644 index 00000000000..d5091679967 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/logging/mockgen.go @@ -0,0 +1,4 @@ +package logging + +//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/quic-go/quic-go/logging -destination mock_connection_tracer_test.go github.com/quic-go/quic-go/logging ConnectionTracer" +//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/quic-go/quic-go/logging -destination mock_tracer_test.go github.com/quic-go/quic-go/logging Tracer" diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/multiplex.go b/vendor/github.com/quic-go/quic-go/logging/multiplex.go similarity index 80% rename from vendor/github.com/lucas-clemente/quic-go/logging/multiplex.go rename to vendor/github.com/quic-go/quic-go/logging/multiplex.go index 8280e8cdf49..672a5cdbd86 100644 --- a/vendor/github.com/lucas-clemente/quic-go/logging/multiplex.go +++ b/vendor/github.com/quic-go/quic-go/logging/multiplex.go @@ -1,7 +1,6 @@ package logging import ( - "context" "net" "time" ) @@ -23,19 +22,15 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer { return &tracerMultiplexer{tracers} } -func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { - var connTracers []ConnectionTracer +func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { for _, t := range m.tracers { - if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { - connTracers = append(connTracers, ct) - } + t.SentPacket(remote, hdr, size, frames) } - return NewMultiplexedConnectionTracer(connTracers...) 
} -func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { +func (m *tracerMultiplexer) SentVersionNegotiationPacket(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { for _, t := range m.tracers { - t.SentPacket(remote, hdr, size, frames) + t.SentVersionNegotiationPacket(remote, dest, src, versions) } } @@ -98,15 +93,21 @@ func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParamet } } -func (m *connTracerMultiplexer) SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { +func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) { + for _, t := range m.tracers { + t.SentLongHeaderPacket(hdr, size, ack, frames) + } +} + +func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) { for _, t := range m.tracers { - t.SentPacket(hdr, size, ack, frames) + t.SentShortHeaderPacket(hdr, size, ack, frames) } } -func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(hdr *Header, versions []VersionNumber) { +func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { for _, t := range m.tracers { - t.ReceivedVersionNegotiationPacket(hdr, versions) + t.ReceivedVersionNegotiationPacket(dest, src, versions) } } @@ -116,15 +117,21 @@ func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { } } -func (m *connTracerMultiplexer) ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { +func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) { + for _, t := range m.tracers { + t.ReceivedLongHeaderPacket(hdr, size, frames) + } +} + +func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) { for _, t := range m.tracers { - t.ReceivedPacket(hdr, size, frames) + t.ReceivedShortHeaderPacket(hdr, size, frames) } } -func (m *connTracerMultiplexer) BufferedPacket(typ PacketType) { +func (m *connTracerMultiplexer) BufferedPacket(typ PacketType, size ByteCount) { for _, t := range m.tracers { - t.BufferedPacket(typ) + t.BufferedPacket(typ, size) } } diff --git a/vendor/github.com/quic-go/quic-go/logging/null_tracer.go b/vendor/github.com/quic-go/quic-go/logging/null_tracer.go new file mode 100644 index 00000000000..de970385796 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/logging/null_tracer.go @@ -0,0 +1,58 @@ +package logging + +import ( + "net" + "time" +) + +// The NullTracer is a Tracer that does nothing. +// It is useful for embedding. +type NullTracer struct{} + +var _ Tracer = &NullTracer{} + +func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} +func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { +} +func (n NullTracer) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) {} + +// The NullConnectionTracer is a ConnectionTracer that does nothing. +// It is useful for embedding. 
+type NullConnectionTracer struct{} + +var _ ConnectionTracer = &NullConnectionTracer{} + +func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { +} + +func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { +} +func (n NullConnectionTracer) ClosedConnection(err error) {} +func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {} +func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {} +func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {} +func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {} +func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, *AckFrame, []Frame) {} +func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) { +} +func (n NullConnectionTracer) ReceivedRetry(*Header) {} +func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, []Frame) {} +func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, []Frame) {} +func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {} +func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {} + +func (n NullConnectionTracer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { +} +func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {} +func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {} +func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {} +func (n NullConnectionTracer) UpdatedPTOCount(uint32) {} +func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {} +func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {} +func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {} +func (n NullConnectionTracer) DroppedKey(KeyPhase) {} +func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {} +func (n NullConnectionTracer) LossTimerExpired(timerType TimerType, level EncryptionLevel) {} +func (n NullConnectionTracer) LossTimerCanceled() {} +func (n NullConnectionTracer) Close() {} +func (n NullConnectionTracer) Debug(name, msg string) {} diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/packet_header.go b/vendor/github.com/quic-go/quic-go/logging/packet_header.go similarity index 83% rename from vendor/github.com/lucas-clemente/quic-go/logging/packet_header.go rename to vendor/github.com/quic-go/quic-go/logging/packet_header.go index ea4282fe372..6b8df58d8ac 100644 --- a/vendor/github.com/lucas-clemente/quic-go/logging/packet_header.go +++ b/vendor/github.com/quic-go/quic-go/logging/packet_header.go @@ -1,14 +1,11 @@ package logging import ( - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) // PacketTypeFromHeader determines the packet type from a *wire.Header. 
func PacketTypeFromHeader(hdr *Header) PacketType { - if !hdr.IsLongHeader { - return PacketType1RTT - } if hdr.Version == 0 { return PacketTypeVersionNegotiation } diff --git a/vendor/github.com/lucas-clemente/quic-go/logging/types.go b/vendor/github.com/quic-go/quic-go/logging/types.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/logging/types.go rename to vendor/github.com/quic-go/quic-go/logging/types.go diff --git a/vendor/github.com/quic-go/quic-go/mockgen.go b/vendor/github.com/quic-go/quic-go/mockgen.go new file mode 100644 index 00000000000..eb700864ac0 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/mockgen.go @@ -0,0 +1,74 @@ +//go:build gomock || generate + +package quic + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn" +type SendConn = sendConn + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender" +type Sender = sender + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI" +type StreamI = streamI + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream" +type CryptoStream = cryptoStream + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI" +type ReceiveStreamI = receiveStreamI + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI" +type SendStreamI = sendStreamI + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter" +type StreamGetter = streamGetter + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender" +type StreamSender = streamSender + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler" +type CryptoDataHandler = cryptoDataHandler + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource" +type FrameSource = frameSource + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go 
AckFrameSource" +type AckFrameSource = ackFrameSource + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager" +type StreamManager = streamManager + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager" +type SealingManager = sealingManager + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker" +type Unpacker = unpacker + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer" +type Packer = packer + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer" +type MTUDiscoverer = mtuDiscoverer + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner" +type ConnRunner = connRunner + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_quic_conn_test.go github.com/quic-go/quic-go QUICConn" +type QUICConn = quicConn + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler" +type PacketHandler = packetHandler + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unknown_packet_handler_test.go github.com/quic-go/quic-go UnknownPacketHandler" +type UnknownPacketHandler = unknownPacketHandler + +//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" +type PacketHandlerManager = packetHandlerManager + +// Need to use source mode for the batchConn, since reflect mode follows type aliases. +// See https://github.com/golang/mock/issues/244 for details. 
+// +//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" + +//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore" +//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn" diff --git a/vendor/github.com/lucas-clemente/quic-go/mtu_discoverer.go b/vendor/github.com/quic-go/quic-go/mtu_discoverer.go similarity index 89% rename from vendor/github.com/lucas-clemente/quic-go/mtu_discoverer.go rename to vendor/github.com/quic-go/quic-go/mtu_discoverer.go index bf38eaac7d6..5a8484c76b6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/mtu_discoverer.go +++ b/vendor/github.com/quic-go/quic-go/mtu_discoverer.go @@ -3,10 +3,10 @@ package quic import ( "time" - "github.com/lucas-clemente/quic-go/internal/ackhandler" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" ) type mtuDiscoverer interface { diff --git a/vendor/github.com/quic-go/quic-go/multiplexer.go b/vendor/github.com/quic-go/quic-go/multiplexer.go new file mode 100644 index 00000000000..85f7f4034e3 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/multiplexer.go @@ -0,0 +1,75 @@ +package quic + +import ( + "fmt" + "net" + "sync" + + "github.com/quic-go/quic-go/internal/utils" +) + +var ( + connMuxerOnce sync.Once + connMuxer multiplexer +) + +type indexableConn interface{ LocalAddr() net.Addr } + +type multiplexer interface { + AddConn(conn indexableConn) + RemoveConn(indexableConn) error +} + +// The connMultiplexer listens on multiple net.PacketConns and dispatches +// incoming packets to the connection handler. +type connMultiplexer struct { + mutex sync.Mutex + + conns map[string] /* LocalAddr().String() */ indexableConn + logger utils.Logger +} + +var _ multiplexer = &connMultiplexer{} + +func getMultiplexer() multiplexer { + connMuxerOnce.Do(func() { + connMuxer = &connMultiplexer{ + conns: make(map[string]indexableConn), + logger: utils.DefaultLogger.WithPrefix("muxer"), + } + }) + return connMuxer +} + +func (m *connMultiplexer) index(addr net.Addr) string { + return addr.Network() + " " + addr.String() +} + +func (m *connMultiplexer) AddConn(c indexableConn) { + m.mutex.Lock() + defer m.mutex.Unlock() + + connIndex := m.index(c.LocalAddr()) + p, ok := m.conns[connIndex] + if ok { + // Panics if we're already listening on this connection. + // This is a safeguard because we're introducing a breaking API change, see + // https://github.com/quic-go/quic-go/issues/3727 for details. + // We'll remove this at a later time, when most users of the library have made the switch. 
+ panic("connection already exists") // TODO: write a nice message + } + m.conns[connIndex] = p +} + +func (m *connMultiplexer) RemoveConn(c indexableConn) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + connIndex := m.index(c.LocalAddr()) + if _, ok := m.conns[connIndex]; !ok { + return fmt.Errorf("cannote remove connection, connection is unknown") + } + + delete(m.conns, connIndex) + return nil +} diff --git a/vendor/github.com/quic-go/quic-go/packet_handler_map.go b/vendor/github.com/quic-go/quic-go/packet_handler_map.go new file mode 100644 index 00000000000..83caa192036 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/packet_handler_map.go @@ -0,0 +1,271 @@ +package quic + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "errors" + "hash" + "io" + "net" + "sync" + "time" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" +) + +// rawConn is a connection that allow reading of a receivedPackeh. +type rawConn interface { + ReadPacket() (*receivedPacket, error) + WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + LocalAddr() net.Addr + SetReadDeadline(time.Time) error + io.Closer +} + +type closePacket struct { + payload []byte + addr net.Addr + info *packetInfo +} + +type unknownPacketHandler interface { + handlePacket(*receivedPacket) + setCloseError(error) +} + +var errListenerAlreadySet = errors.New("listener already set") + +type packetHandlerMap struct { + mutex sync.Mutex + handlers map[protocol.ConnectionID]packetHandler + resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler + + closed bool + closeChan chan struct{} + + enqueueClosePacket func(closePacket) + + deleteRetiredConnsAfter time.Duration + + statelessResetMutex sync.Mutex + statelessResetHasher hash.Hash + + logger utils.Logger +} + +var _ packetHandlerManager = &packetHandlerMap{} + +func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap { + h := &packetHandlerMap{ + closeChan: make(chan struct{}), + handlers: make(map[protocol.ConnectionID]packetHandler), + resetTokens: make(map[protocol.StatelessResetToken]packetHandler), + deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, + enqueueClosePacket: enqueueClosePacket, + logger: logger, + } + if key != nil { + h.statelessResetHasher = hmac.New(sha256.New, key[:]) + } + if h.logger.Debug() { + go h.logUsage() + } + return h +} + +func (h *packetHandlerMap) logUsage() { + ticker := time.NewTicker(2 * time.Second) + var printedZero bool + for { + select { + case <-h.closeChan: + return + case <-ticker.C: + } + + h.mutex.Lock() + numHandlers := len(h.handlers) + numTokens := len(h.resetTokens) + h.mutex.Unlock() + // If the number tracked handlers and tokens is zero, only print it a single time. 
+ hasZero := numHandlers == 0 && numTokens == 0 + if !hasZero || (hasZero && !printedZero) { + h.logger.Debugf("Tracking %d connection IDs and %d reset tokens.\n", numHandlers, numTokens) + printedZero = false + if hasZero { + printedZero = true + } + } + } +} + +func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) { + h.mutex.Lock() + defer h.mutex.Unlock() + + handler, ok := h.handlers[id] + return handler, ok +} + +func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { + h.mutex.Lock() + defer h.mutex.Unlock() + + if _, ok := h.handlers[id]; ok { + h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) + return false + } + h.handlers[id] = handler + h.logger.Debugf("Adding connection ID %s.", id) + return true +} + +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool { + h.mutex.Lock() + defer h.mutex.Unlock() + + if _, ok := h.handlers[clientDestConnID]; ok { + h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) + return false + } + conn, ok := fn() + if !ok { + return false + } + h.handlers[clientDestConnID] = conn + h.handlers[newConnID] = conn + h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) + return true +} + +func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { + h.mutex.Lock() + delete(h.handlers, id) + h.mutex.Unlock() + h.logger.Debugf("Removing connection ID %s.", id) +} + +func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { + h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter) + time.AfterFunc(h.deleteRetiredConnsAfter, func() { + h.mutex.Lock() + delete(h.handlers, id) + h.mutex.Unlock() + h.logger.Debugf("Removing connection ID %s after it has been retired.", id) + }) +} + +// ReplaceWithClosed is called when a connection is closed. 
+// Depending on which side closed the connection, we need to: +// * remote close: absorb delayed packets +// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost +func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) { + var handler packetHandler + if connClosePacket != nil { + handler = newClosedLocalConn( + func(addr net.Addr, info *packetInfo) { + h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) + }, + pers, + h.logger, + ) + } else { + handler = newClosedRemoteConn(pers) + } + + h.mutex.Lock() + for _, id := range ids { + h.handlers[id] = handler + } + h.mutex.Unlock() + h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids) + + time.AfterFunc(h.deleteRetiredConnsAfter, func() { + h.mutex.Lock() + handler.shutdown() + for _, id := range ids { + delete(h.handlers, id) + } + h.mutex.Unlock() + h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids) + }) +} + +func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { + h.mutex.Lock() + h.resetTokens[token] = handler + h.mutex.Unlock() +} + +func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) { + h.mutex.Lock() + delete(h.resetTokens, token) + h.mutex.Unlock() +} + +func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) { + h.mutex.Lock() + defer h.mutex.Unlock() + + handler, ok := h.resetTokens[token] + return handler, ok +} + +func (h *packetHandlerMap) CloseServer() { + h.mutex.Lock() + var wg sync.WaitGroup + for _, handler := range h.handlers { + if handler.getPerspective() == protocol.PerspectiveServer { + wg.Add(1) + go func(handler packetHandler) { + // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped + handler.shutdown() + wg.Done() + }(handler) + } + } + h.mutex.Unlock() + wg.Wait() +} + +func (h *packetHandlerMap) Close(e error) { + h.mutex.Lock() + + if h.closed { + h.mutex.Unlock() + return + } + + close(h.closeChan) + + var wg sync.WaitGroup + for _, handler := range h.handlers { + wg.Add(1) + go func(handler packetHandler) { + handler.destroy(e) + wg.Done() + }(handler) + } + h.closed = true + h.mutex.Unlock() + wg.Wait() +} + +func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { + var token protocol.StatelessResetToken + if h.statelessResetHasher == nil { + // Return a random stateless reset token. + // This token will be sent in the server's transport parameters. + // By using a random token, an off-path attacker won't be able to disrupt the connection. 
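When a StatelessResetKey is set, the keyed path just below derives the token by HMAC-SHA256 over the connection ID, truncated to 16 bytes. A self-contained sketch of that derivation; deriveToken and the sample inputs are illustrative only:

    package main

    import (
        "crypto/hmac"
        "crypto/sha256"
        "fmt"
    )

    // deriveToken mirrors the keyed path: HMAC-SHA256 over the connection ID,
    // truncated to the 16-byte stateless reset token length.
    func deriveToken(key, connID []byte) [16]byte {
        var token [16]byte
        mac := hmac.New(sha256.New, key)
        mac.Write(connID)
        copy(token[:], mac.Sum(nil))
        return token
    }

    func main() {
        key := []byte("example stateless reset key")
        connID := []byte{0xde, 0xad, 0xbe, 0xef}
        fmt.Printf("%x\n", deriveToken(key, connID))
    }

Deriving the token from a key presumably lets the endpoint recompute it from the connection ID alone, without keeping per-connection state.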
+ rand.Read(token[:]) + return token + } + h.statelessResetMutex.Lock() + h.statelessResetHasher.Write(connID.Bytes()) + copy(token[:], h.statelessResetHasher.Sum(nil)) + h.statelessResetHasher.Reset() + h.statelessResetMutex.Unlock() + return token +} diff --git a/vendor/github.com/quic-go/quic-go/packet_packer.go b/vendor/github.com/quic-go/quic-go/packet_packer.go new file mode 100644 index 00000000000..14befd460f5 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/packet_packer.go @@ -0,0 +1,968 @@ +package quic + +import ( + "errors" + "fmt" + "net" + "time" + + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/handshake" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" +) + +var errNothingToPack = errors.New("nothing to pack") + +type packer interface { + PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) + PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) + MaybePackProbePacket(protocol.EncryptionLevel, protocol.VersionNumber) (*coalescedPacket, error) + PackConnectionClose(*qerr.TransportError, protocol.VersionNumber) (*coalescedPacket, error) + PackApplicationClose(*qerr.ApplicationError, protocol.VersionNumber) (*coalescedPacket, error) + + SetMaxPacketSize(protocol.ByteCount) + PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) + + HandleTransportParameters(*wire.TransportParameters) + SetToken([]byte) +} + +type sealer interface { + handshake.LongHeaderSealer +} + +type payload struct { + frames []*ackhandler.Frame + ack *wire.AckFrame + length protocol.ByteCount +} + +type longHeaderPacket struct { + header *wire.ExtendedHeader + ack *wire.AckFrame + frames []*ackhandler.Frame + + length protocol.ByteCount + + isMTUProbePacket bool +} + +type shortHeaderPacket struct { + *ackhandler.Packet + // used for logging + DestConnID protocol.ConnectionID + Ack *wire.AckFrame + PacketNumberLen protocol.PacketNumberLen + KeyPhase protocol.KeyPhaseBit +} + +func (p *shortHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.Frames) } + +type coalescedPacket struct { + buffer *packetBuffer + longHdrPackets []*longHeaderPacket + shortHdrPacket *shortHeaderPacket +} + +func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel { + //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). + switch p.header.Type { + case protocol.PacketTypeInitial: + return protocol.EncryptionInitial + case protocol.PacketTypeHandshake: + return protocol.EncryptionHandshake + case protocol.PacketType0RTT: + return protocol.Encryption0RTT + default: + panic("can't determine encryption level") + } +} + +func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) } + +func (p *longHeaderPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { + largestAcked := protocol.InvalidPacketNumber + if p.ack != nil { + largestAcked = p.ack.LargestAcked() + } + encLevel := p.EncryptionLevel() + for i := range p.frames { + if p.frames[i].OnLost != nil { + continue + } + //nolint:exhaustive // Short header packets are handled separately. 
+ switch encLevel { + case protocol.EncryptionInitial: + p.frames[i].OnLost = q.AddInitial + case protocol.EncryptionHandshake: + p.frames[i].OnLost = q.AddHandshake + case protocol.Encryption0RTT: + p.frames[i].OnLost = q.AddAppData + } + } + + ap := ackhandler.GetPacket() + ap.PacketNumber = p.header.PacketNumber + ap.LargestAcked = largestAcked + ap.Frames = p.frames + ap.Length = p.length + ap.EncryptionLevel = encLevel + ap.SendTime = now + ap.IsPathMTUProbePacket = p.isMTUProbePacket + return ap +} + +func getMaxPacketSize(addr net.Addr) protocol.ByteCount { + maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := addr.(*net.UDPAddr); ok { + if utils.IsIPv4(udpAddr.IP) { + maxSize = protocol.InitialPacketSizeIPv4 + } else { + maxSize = protocol.InitialPacketSizeIPv6 + } + } + return maxSize +} + +type packetNumberManager interface { + PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) + PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber +} + +type sealingManager interface { + GetInitialSealer() (handshake.LongHeaderSealer, error) + GetHandshakeSealer() (handshake.LongHeaderSealer, error) + Get0RTTSealer() (handshake.LongHeaderSealer, error) + Get1RTTSealer() (handshake.ShortHeaderSealer, error) +} + +type frameSource interface { + HasData() bool + AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) + AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) +} + +type ackFrameSource interface { + GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame +} + +type packetPacker struct { + srcConnID protocol.ConnectionID + getDestConnID func() protocol.ConnectionID + + perspective protocol.Perspective + cryptoSetup sealingManager + + initialStream cryptoStream + handshakeStream cryptoStream + + token []byte + + pnManager packetNumberManager + framer frameSource + acks ackFrameSource + datagramQueue *datagramQueue + retransmissionQueue *retransmissionQueue + + maxPacketSize protocol.ByteCount + numNonAckElicitingAcks int +} + +var _ packer = &packetPacker{} + +func newPacketPacker(srcConnID protocol.ConnectionID, getDestConnID func() protocol.ConnectionID, initialStream cryptoStream, handshakeStream cryptoStream, packetNumberManager packetNumberManager, retransmissionQueue *retransmissionQueue, remoteAddr net.Addr, cryptoSetup sealingManager, framer frameSource, acks ackFrameSource, datagramQueue *datagramQueue, perspective protocol.Perspective) *packetPacker { + return &packetPacker{ + cryptoSetup: cryptoSetup, + getDestConnID: getDestConnID, + srcConnID: srcConnID, + initialStream: initialStream, + handshakeStream: handshakeStream, + retransmissionQueue: retransmissionQueue, + datagramQueue: datagramQueue, + perspective: perspective, + framer: framer, + acks: acks, + pnManager: packetNumberManager, + maxPacketSize: getMaxPacketSize(remoteAddr), + } +} + +// PackConnectionClose packs a packet that closes the connection with a transport error. 
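getMaxPacketSize above selects the initial size by address family and falls back to the minimum for non-UDP addresses. A small sketch of that selection; the numeric constants are assumptions standing in for the values in internal/protocol:

    package main

    import (
        "fmt"
        "net"
    )

    // Assumed stand-ins for protocol.MinInitialPacketSize and the per-family
    // initial packet sizes.
    const (
        minInitialPacketSize  = 1200
        initialPacketSizeIPv4 = 1252
        initialPacketSizeIPv6 = 1232
    )

    func maxPacketSize(addr net.Addr) int {
        udpAddr, ok := addr.(*net.UDPAddr)
        if !ok {
            // Not UDP: nothing is known about the MTU, so stay at the minimum.
            return minInitialPacketSize
        }
        if udpAddr.IP.To4() != nil {
            return initialPacketSizeIPv4
        }
        return initialPacketSizeIPv6
    }

    func main() {
        fmt.Println(maxPacketSize(&net.UDPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 443}))     // IPv4 size
        fmt.Println(maxPacketSize(&net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 443})) // IPv6 size
    }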
+func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, v protocol.VersionNumber) (*coalescedPacket, error) { + var reason string + // don't send details of crypto errors + if !e.ErrorCode.IsCryptoError() { + reason = e.ErrorMessage + } + return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, v) +} + +// PackApplicationClose packs a packet that closes the connection with an application error. +func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, v protocol.VersionNumber) (*coalescedPacket, error) { + return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, v) +} + +func (p *packetPacker) packConnectionClose( + isApplicationError bool, + errorCode uint64, + frameType uint64, + reason string, + v protocol.VersionNumber, +) (*coalescedPacket, error) { + var sealers [4]sealer + var hdrs [3]*wire.ExtendedHeader + var payloads [4]payload + var size protocol.ByteCount + var connID protocol.ConnectionID + var oneRTTPacketNumber protocol.PacketNumber + var oneRTTPacketNumberLen protocol.PacketNumberLen + var keyPhase protocol.KeyPhaseBit // only set for 1-RTT + var numLongHdrPackets uint8 + encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT} + for i, encLevel := range encLevels { + if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT { + continue + } + ccf := &wire.ConnectionCloseFrame{ + IsApplicationError: isApplicationError, + ErrorCode: errorCode, + FrameType: frameType, + ReasonPhrase: reason, + } + // don't send application errors in Initial or Handshake packets + if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) { + ccf.IsApplicationError = false + ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode) + ccf.ReasonPhrase = "" + } + pl := payload{ + frames: []*ackhandler.Frame{{Frame: ccf}}, + length: ccf.Length(v), + } + + var sealer sealer + var err error + switch encLevel { + case protocol.EncryptionInitial: + sealer, err = p.cryptoSetup.GetInitialSealer() + case protocol.EncryptionHandshake: + sealer, err = p.cryptoSetup.GetHandshakeSealer() + case protocol.Encryption0RTT: + sealer, err = p.cryptoSetup.Get0RTTSealer() + case protocol.Encryption1RTT: + var s handshake.ShortHeaderSealer + s, err = p.cryptoSetup.Get1RTTSealer() + if err == nil { + keyPhase = s.KeyPhase() + } + sealer = s + } + if err == handshake.ErrKeysNotYetAvailable || err == handshake.ErrKeysDropped { + continue + } + if err != nil { + return nil, err + } + sealers[i] = sealer + var hdr *wire.ExtendedHeader + if encLevel == protocol.Encryption1RTT { + connID = p.getDestConnID() + oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, pl) + } else { + hdr = p.getLongHeader(encLevel, v) + hdrs[i] = hdr + size += p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead()) + numLongHdrPackets++ + } + payloads[i] = pl + } + buffer := getPacketBuffer() + packet := &coalescedPacket{ + buffer: buffer, + longHdrPackets: make([]*longHeaderPacket, 0, numLongHdrPackets), + } + for i, encLevel := range encLevels { + if sealers[i] == nil { + continue + } + var paddingLen protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + paddingLen = p.initialPaddingLen(payloads[i].frames, size) + } + if encLevel == protocol.Encryption1RTT { + ap, ack, err 
:= p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false, v) + if err != nil { + return nil, err + } + packet.shortHdrPacket = &shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: oneRTTPacketNumberLen, + KeyPhase: keyPhase, + } + } else { + longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) + } + } + return packet, nil +} + +// longHeaderPacketLength calculates the length of a serialized long header packet. +// It takes into account that packets that have a tiny payload need to be padded, +// such that len(payload) + packet number len >= 4 + AEAD overhead +func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.VersionNumber) protocol.ByteCount { + var paddingLen protocol.ByteCount + pnLen := protocol.ByteCount(hdr.PacketNumberLen) + if pl.length < 4-pnLen { + paddingLen = 4 - pnLen - pl.length + } + return hdr.GetLength(v) + pl.length + paddingLen +} + +// shortHeaderPacketLength calculates the length of a serialized short header packet. +// It takes into account that packets that have a tiny payload need to be padded, +// such that len(payload) + packet number len >= 4 + AEAD overhead +func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnLen protocol.PacketNumberLen, pl payload) protocol.ByteCount { + var paddingLen protocol.ByteCount + if pl.length < 4-protocol.ByteCount(pnLen) { + paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length + } + return wire.ShortHeaderLen(connID, pnLen) + pl.length + paddingLen +} + +// size is the expected size of the packet, if no padding was applied. +func (p *packetPacker) initialPaddingLen(frames []*ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { + // For the server, only ack-eliciting Initial packets need to be padded. + if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { + return 0 + } + if size >= p.maxPacketSize { + return 0 + } + return p.maxPacketSize - size +} + +// PackCoalescedPacket packs a new packet. +// It packs an Initial / Handshake if there is data to send in these packet number spaces. +// It should only be called before the handshake is confirmed. +func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) { + maxPacketSize := p.maxPacketSize + if p.perspective == protocol.PerspectiveClient { + maxPacketSize = protocol.MinInitialPacketSize + } + var ( + initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader + initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload + oneRTTPacketNumber protocol.PacketNumber + oneRTTPacketNumberLen protocol.PacketNumberLen + ) + // Try packing an Initial packet. + initialSealer, err := p.cryptoSetup.GetInitialSealer() + if err != nil && err != handshake.ErrKeysDropped { + return nil, err + } + var size protocol.ByteCount + if initialSealer != nil { + initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true, v) + if initialPayload.length > 0 { + size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead()) + } + } + + // Add a Handshake packet. 
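Both length helpers above pad tiny payloads so that the encoded packet number plus payload reaches at least 4 bytes, which (together with the AEAD tag added later) leaves enough ciphertext for header-protection sampling. A short restatement with a hypothetical helper:

    package main

    import "fmt"

    // requiredPadding restates the check in longHeaderPacketLength and
    // shortHeaderPacketLength: packet number length plus payload length
    // must reach at least 4 bytes.
    func requiredPadding(payloadLen, pnLen int) int {
        if payloadLen < 4-pnLen {
            return 4 - pnLen - payloadLen
        }
        return 0
    }

    func main() {
        fmt.Println(requiredPadding(1, 1)) // 2: a 1-byte frame with a 1-byte packet number
        fmt.Println(requiredPadding(0, 2)) // 2
        fmt.Println(requiredPadding(5, 1)) // 0: already long enough
    }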
+ var handshakeSealer sealer + if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { + var err error + handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if handshakeSealer != nil { + handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0, v) + if handshakePayload.length > 0 { + s := p.longHeaderPacketLength(handshakeHdr, handshakePayload, v) + protocol.ByteCount(handshakeSealer.Overhead()) + size += s + } + } + } + + // Add a 0-RTT / 1-RTT packet. + var zeroRTTSealer sealer + var oneRTTSealer handshake.ShortHeaderSealer + var connID protocol.ConnectionID + var kp protocol.KeyPhaseBit + if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { + var err error + oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if err == nil { // 1-RTT + kp = oneRTTSealer.KeyPhase() + connID = p.getDestConnID() + oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen) + oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0, v) + if oneRTTPayload.length > 0 { + size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead()) + } + } else if p.perspective == protocol.PerspectiveClient { // 0-RTT + var err error + zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if zeroRTTSealer != nil { + zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size, v) + if zeroRTTPayload.length > 0 { + size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload, v) + protocol.ByteCount(zeroRTTSealer.Overhead()) + } + } + } + } + + if initialPayload.length == 0 && handshakePayload.length == 0 && zeroRTTPayload.length == 0 && oneRTTPayload.length == 0 { + return nil, nil + } + + buffer := getPacketBuffer() + packet := &coalescedPacket{ + buffer: buffer, + longHdrPackets: make([]*longHeaderPacket, 0, 3), + } + if initialPayload.length > 0 { + padding := p.initialPaddingLen(initialPayload.frames, size) + cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, cont) + } + if handshakePayload.length > 0 { + cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, cont) + } + if zeroRTTPayload.length > 0 { + longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) + } else if oneRTTPayload.length > 0 { + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, 
oneRTTPayload, 0, oneRTTSealer, false, v) + if err != nil { + return nil, err + } + packet.shortHdrPacket = &shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: oneRTTPacketNumberLen, + KeyPhase: kp, + } + } + return packet, nil +} + +// PackPacket packs a packet in the application data packet number space. +// It should be called after the handshake is confirmed. +func (p *packetPacker) PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { + sealer, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return shortHeaderPacket{}, nil, err + } + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + connID := p.getDestConnID() + hdrLen := wire.ShortHeaderLen(connID, pnLen) + pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, p.maxPacketSize, onlyAck, true, v) + if pl.length == 0 { + return shortHeaderPacket{}, nil, errNothingToPack + } + kp := sealer.KeyPhase() + buffer := getPacketBuffer() + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, sealer, false, v) + if err != nil { + return shortHeaderPacket{}, nil, err + } + return shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: pnLen, + KeyPhase: kp, + }, buffer, nil +} + +func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { + if onlyAck { + if ack := p.acks.GetAckFrame(encLevel, true); ack != nil { + return p.getLongHeader(encLevel, v), payload{ + ack: ack, + length: ack.Length(v), + } + } + return nil, payload{} + } + + var s cryptoStream + var hasRetransmission bool + //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. 
+ switch encLevel { + case protocol.EncryptionInitial: + s = p.initialStream + hasRetransmission = p.retransmissionQueue.HasInitialData() + case protocol.EncryptionHandshake: + s = p.handshakeStream + hasRetransmission = p.retransmissionQueue.HasHandshakeData() + } + + hasData := s.HasData() + var ack *wire.AckFrame + if ackAllowed { + ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) + } + if !hasData && !hasRetransmission && ack == nil { + // nothing to send + return nil, payload{} + } + + var pl payload + if ack != nil { + pl.ack = ack + pl.length = ack.Length(v) + maxPacketSize -= pl.length + } + hdr := p.getLongHeader(encLevel, v) + maxPacketSize -= hdr.GetLength(v) + if hasRetransmission { + for { + var f wire.Frame + //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s + switch encLevel { + case protocol.EncryptionInitial: + f = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v) + case protocol.EncryptionHandshake: + f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, v) + } + if f == nil { + break + } + af := ackhandler.GetFrame() + af.Frame = f + pl.frames = append(pl.frames, af) + frameLen := f.Length(v) + pl.length += frameLen + maxPacketSize -= frameLen + } + } else if s.HasData() { + cf := s.PopCryptoFrame(maxPacketSize) + pl.frames = []*ackhandler.Frame{{Frame: cf}} + pl.length += cf.Length(v) + } + return hdr, pl +} + +func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { + if p.perspective != protocol.PerspectiveClient { + return nil, payload{} + } + + hdr := p.getLongHeader(protocol.Encryption0RTT, v) + maxPayloadSize := maxPacketSize - hdr.GetLength(v) - protocol.ByteCount(sealer.Overhead()) + return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v) +} + +func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { + maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) + return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v) +} + +func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { + pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v) + + // check if we have anything to send + if len(pl.frames) == 0 { + if pl.ack == nil { + return payload{} + } + // the packet only contains an ACK + if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { + ping := &wire.PingFrame{} + // don't retransmit the PING frame when it is lost + af := ackhandler.GetFrame() + af.Frame = ping + af.OnLost = func(wire.Frame) {} + pl.frames = append(pl.frames, af) + pl.length += ping.Length(v) + p.numNonAckElicitingAcks = 0 + } else { + p.numNonAckElicitingAcks++ + } + } else { + p.numNonAckElicitingAcks = 0 + } + return pl +} + +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { + if onlyAck { + if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { + return payload{ + ack: ack, + length: ack.Length(v), + } + } + return payload{} + } + + pl := payload{frames: make([]*ackhandler.Frame, 0, 1)} + + hasData := p.framer.HasData() + hasRetransmission := p.retransmissionQueue.HasAppData() + + var hasAck bool + if ackAllowed { + if ack := 
p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil { + pl.ack = ack + pl.length += ack.Length(v) + hasAck = true + } + } + + if p.datagramQueue != nil { + if f := p.datagramQueue.Peek(); f != nil { + size := f.Length(v) + if size <= maxFrameSize-pl.length { + af := ackhandler.GetFrame() + af.Frame = f + // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. + af.OnLost = func(wire.Frame) {} + pl.frames = append(pl.frames, af) + pl.length += size + p.datagramQueue.Pop() + } + } + } + + if hasAck && !hasData && !hasRetransmission { + return pl + } + + if hasRetransmission { + for { + remainingLen := maxFrameSize - pl.length + if remainingLen < protocol.MinStreamFrameSize { + break + } + f := p.retransmissionQueue.GetAppDataFrame(remainingLen, v) + if f == nil { + break + } + af := ackhandler.GetFrame() + af.Frame = f + pl.frames = append(pl.frames, af) + pl.length += f.Length(v) + } + } + + if hasData { + var lengthAdded protocol.ByteCount + pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v) + pl.length += lengthAdded + + pl.frames, lengthAdded = p.framer.AppendStreamFrames(pl.frames, maxFrameSize-pl.length, v) + pl.length += lengthAdded + } + return pl +} + +func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (*coalescedPacket, error) { + if encLevel == protocol.Encryption1RTT { + s, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, err + } + kp := s.KeyPhase() + connID := p.getDestConnID() + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + hdrLen := wire.ShortHeaderLen(connID, pnLen) + pl := p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v) + if pl.length == 0 { + return nil, nil + } + buffer := getPacketBuffer() + packet := &coalescedPacket{buffer: buffer} + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, s, false, v) + if err != nil { + return nil, err + } + packet.shortHdrPacket = &shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: pnLen, + KeyPhase: kp, + } + return packet, nil + } + + var hdr *wire.ExtendedHeader + var pl payload + var sealer handshake.LongHeaderSealer + //nolint:exhaustive // Probe packets are never sent for 0-RTT. 
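maybeGetAppDataPacket above bundles a PING once too many consecutive packets would carry only an ACK, so that at least every Nth packet is ack-eliciting. A toy walk-through of that counter; the threshold of 19 is an assumption standing in for protocol.MaxNonAckElicitingAcks:

    package main

    import "fmt"

    // Assumed stand-in for protocol.MaxNonAckElicitingAcks.
    const maxNonAckElicitingAcks = 19

    func main() {
        nonAckElicitingAcks := 0
        for packet := 1; packet <= 21; packet++ {
            // Pretend every packet would otherwise contain only an ACK.
            if nonAckElicitingAcks >= maxNonAckElicitingAcks {
                nonAckElicitingAcks = 0 // the bundled PING makes this packet ack-eliciting
                fmt.Printf("packet %d: ACK + PING\n", packet)
            } else {
                nonAckElicitingAcks++
                fmt.Printf("packet %d: ACK only\n", packet)
            }
        }
    }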
+ switch encLevel { + case protocol.EncryptionInitial: + var err error + sealer, err = p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, err + } + hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) + case protocol.EncryptionHandshake: + var err error + sealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil { + return nil, err + } + hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v) + default: + panic("unknown encryption level") + } + + if pl.length == 0 { + return nil, nil + } + buffer := getPacketBuffer() + packet := &coalescedPacket{buffer: buffer} + size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead()) + var padding protocol.ByteCount + if encLevel == protocol.EncryptionInitial { + padding = p.initialPaddingLen(pl.frames, size) + } + + longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = []*longHeaderPacket{longHdrPacket} + return packet, nil +} + +func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { + pl := payload{ + frames: []*ackhandler.Frame{&ping}, + length: ping.Length(v), + } + buffer := getPacketBuffer() + s, err := p.cryptoSetup.Get1RTTSealer() + if err != nil { + return shortHeaderPacket{}, nil, err + } + connID := p.getDestConnID() + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead()) + kp := s.KeyPhase() + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, s, true, v) + if err != nil { + return shortHeaderPacket{}, nil, err + } + return shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: pnLen, + KeyPhase: kp, + }, buffer, nil +} + +func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader { + pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) + hdr := &wire.ExtendedHeader{ + PacketNumber: pn, + PacketNumberLen: pnLen, + } + hdr.Version = v + hdr.SrcConnectionID = p.srcConnID + hdr.DestConnectionID = p.getDestConnID() + + //nolint:exhaustive // 1-RTT packets are not long header packets. 
+ switch encLevel { + case protocol.EncryptionInitial: + hdr.Type = protocol.PacketTypeInitial + hdr.Token = p.token + case protocol.EncryptionHandshake: + hdr.Type = protocol.PacketTypeHandshake + case protocol.Encryption0RTT: + hdr.Type = protocol.PacketType0RTT + } + return hdr +} + +func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) { + var paddingLen protocol.ByteCount + pnLen := protocol.ByteCount(header.PacketNumberLen) + if pl.length < 4-pnLen { + paddingLen = 4 - pnLen - pl.length + } + paddingLen += padding + header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen + + startLen := len(buffer.Data) + raw := buffer.Data[startLen:] + raw, err := header.Append(raw, v) + if err != nil { + return nil, err + } + payloadOffset := protocol.ByteCount(len(raw)) + + pn := p.pnManager.PopPacketNumber(encLevel) + if pn != header.PacketNumber { + return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + } + + raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) + if err != nil { + return nil, err + } + raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen) + buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] + + return &longHeaderPacket{ + header: header, + ack: pl.ack, + frames: pl.frames, + length: protocol.ByteCount(len(raw)), + }, nil +} + +func (p *packetPacker) appendShortHeaderPacket( + buffer *packetBuffer, + connID protocol.ConnectionID, + pn protocol.PacketNumber, + pnLen protocol.PacketNumberLen, + kp protocol.KeyPhaseBit, + pl payload, + padding protocol.ByteCount, + sealer sealer, + isMTUProbePacket bool, + v protocol.VersionNumber, +) (*ackhandler.Packet, *wire.AckFrame, error) { + var paddingLen protocol.ByteCount + if pl.length < 4-protocol.ByteCount(pnLen) { + paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length + } + paddingLen += padding + + startLen := len(buffer.Data) + raw := buffer.Data[startLen:] + raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp) + if err != nil { + return nil, nil, err + } + payloadOffset := protocol.ByteCount(len(raw)) + + if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) { + return nil, nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + } + + raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) + if err != nil { + return nil, nil, err + } + if !isMTUProbePacket { + if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > p.maxPacketSize { + return nil, nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + } + } + raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen)) + buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] + + // create the ackhandler.Packet + largestAcked := protocol.InvalidPacketNumber + if pl.ack != nil { + largestAcked = pl.ack.LargestAcked() + } + for i := range pl.frames { + if pl.frames[i].OnLost != nil { + continue + } + pl.frames[i].OnLost = p.retransmissionQueue.AddAppData + } + + ap := ackhandler.GetPacket() + ap.PacketNumber = pn + ap.LargestAcked = largestAcked + ap.Frames = pl.frames + ap.Length = protocol.ByteCount(len(raw)) + ap.EncryptionLevel = protocol.Encryption1RTT + ap.SendTime = time.Now() + ap.IsPathMTUProbePacket = isMTUProbePacket + + return ap, pl.ack, nil +} + +func (p *packetPacker) 
appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { + payloadOffset := len(raw) + if pl.ack != nil { + var err error + raw, err = pl.ack.Append(raw, v) + if err != nil { + return nil, err + } + } + if paddingLen > 0 { + raw = append(raw, make([]byte, paddingLen)...) + } + for _, frame := range pl.frames { + var err error + raw, err = frame.Append(raw, v) + if err != nil { + return nil, err + } + } + + if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != pl.length { + return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", pl.length, payloadSize) + } + return raw, nil +} + +func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.PacketNumber, payloadOffset, pnLen protocol.ByteCount) []byte { + _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], pn, raw[:payloadOffset]) + raw = raw[:len(raw)+sealer.Overhead()] + // apply header protection + pnOffset := payloadOffset - pnLen + sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset]) + return raw +} + +func (p *packetPacker) SetToken(token []byte) { + p.token = token +} + +// When a higher MTU is discovered, use it. +func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) { + p.maxPacketSize = s +} + +// If the peer sets a max_packet_size that's smaller than the size we're currently using, +// we need to reduce the size of packets we send. +func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) { + if params.MaxUDPPayloadSize != 0 { + p.maxPacketSize = utils.Min(p.maxPacketSize, params.MaxUDPPayloadSize) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go b/vendor/github.com/quic-go/quic-go/packet_unpacker.go similarity index 55% rename from vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go rename to vendor/github.com/quic-go/quic-go/packet_unpacker.go index f70d8d075a0..103524c7dd6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go +++ b/vendor/github.com/quic-go/quic-go/packet_unpacker.go @@ -5,9 +5,10 @@ import ( "fmt" "time" - "github.com/lucas-clemente/quic-go/internal/handshake" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/handshake" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/wire" ) type headerDecryptor interface { @@ -27,7 +28,6 @@ func (e *headerParseError) Error() string { } type unpackedPacket struct { - packetNumber protocol.PacketNumber // the decoded packet number hdr *wire.ExtendedHeader encryptionLevel protocol.EncryptionLevel data []byte @@ -37,22 +37,23 @@ type unpackedPacket struct { type packetUnpacker struct { cs handshake.CryptoSetup - version protocol.VersionNumber + shortHdrConnIDLen int } var _ unpacker = &packetUnpacker{} -func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) unpacker { +func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker { return &packetUnpacker{ - cs: cs, - version: version, + cs: cs, + shortHdrConnIDLen: shortHdrConnIDLen, } } +// UnpackLongHeader unpacks a Long Header packet. // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. 
// If any other error occurred when parsing the header, the error is of type headerParseError. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD. -func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { +func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) { var encLevel protocol.EncryptionLevel var extHdr *wire.ExtendedHeader var decrypted []byte @@ -64,7 +65,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte if err != nil { return nil, err } - extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v) if err != nil { return nil, err } @@ -74,7 +75,7 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte if err != nil { return nil, err } - extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v) if err != nil { return nil, err } @@ -84,35 +85,48 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte if err != nil { return nil, err } - extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data) + extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v) if err != nil { return nil, err } default: - if hdr.IsLongHeader { - return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) - } - encLevel = protocol.Encryption1RTT - opener, err := u.cs.Get1RTTOpener() - if err != nil { - return nil, err - } - extHdr, decrypted, err = u.unpackShortHeaderPacket(opener, hdr, rcvTime, data) - if err != nil { - return nil, err + return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) + } + + if len(decrypted) == 0 { + return nil, &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "empty packet", } } return &unpackedPacket{ hdr: extHdr, - packetNumber: extHdr.PacketNumber, encryptionLevel: encLevel, data: decrypted, }, nil } -func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { - extHdr, parseErr := u.unpackHeader(opener, hdr, data) +func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { + opener, err := u.cs.Get1RTTOpener() + if err != nil { + return 0, 0, 0, nil, err + } + pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data) + if err != nil { + return 0, 0, 0, nil, err + } + if len(decrypted) == 0 { + return 0, 0, 0, nil, &qerr.TransportError{ + ErrorCode: qerr.ProtocolViolation, + ErrorMessage: "empty packet", + } + } + return pn, pnLen, kp, decrypted, nil +} + +func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) { + extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. 
@@ -131,41 +145,57 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene return extHdr, decrypted, nil } -func (u *packetUnpacker) unpackShortHeaderPacket( - opener handshake.ShortHeaderOpener, - hdr *wire.Header, - rcvTime time.Time, - data []byte, -) (*wire.ExtendedHeader, []byte, error) { - extHdr, parseErr := u.unpackHeader(opener, hdr, data) +func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { + l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { - return nil, nil, parseErr + return 0, 0, 0, nil, &headerParseError{parseErr} } - extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) - extHdrLen := extHdr.ParsedLen() - decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen]) + pn = opener.DecodePacketNumber(pn, pnLen) + decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l]) if err != nil { - return nil, nil, err + return 0, 0, 0, nil, err } - if parseErr != nil { - return nil, nil, parseErr + return pn, pnLen, kp, decrypted, parseErr +} + +func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) { + hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen + if len(data) < hdrLen+4+16 { + return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen) } - return extHdr, decrypted, nil + origPNBytes := make([]byte, 4) + copy(origPNBytes, data[hdrLen:hdrLen+4]) + // 2. decrypt the header, assuming a 4 byte packet number + hd.DecryptHeader( + data[hdrLen+4:hdrLen+4+16], + &data[0], + data[hdrLen:hdrLen+4], + ) + // 3. parse the header (and learn the actual length of the packet number) + l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen) + if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { + return l, pn, pnLen, kp, parseErr + } + // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier + if pnLen != protocol.PacketNumberLen4 { + copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):]) + } + return l, pn, pnLen, kp, parseErr } // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. 
-func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { - extHdr, err := unpackHeader(hd, hdr, data, u.version) +func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) { + extHdr, err := unpackLongHeader(hd, hdr, data, v) if err != nil && err != wire.ErrInvalidReservedBits { return nil, &headerParseError{err: err} } return extHdr, err } -func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version protocol.VersionNumber) (*wire.ExtendedHeader, error) { +func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) { r := bytes.NewReader(data) hdrLen := hdr.ParsedLen() @@ -184,7 +214,7 @@ func unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte, version pro data[hdrLen:hdrLen+4], ) // 3. parse the header (and learn the actual length of the packet number) - extHdr, parseErr := hdr.ParseExtended(r, version) + extHdr, parseErr := hdr.ParseExtended(r, v) if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { return nil, parseErr } diff --git a/vendor/github.com/lucas-clemente/quic-go/quicvarint/io.go b/vendor/github.com/quic-go/quic-go/quicvarint/io.go similarity index 100% rename from vendor/github.com/lucas-clemente/quic-go/quicvarint/io.go rename to vendor/github.com/quic-go/quic-go/quicvarint/io.go diff --git a/vendor/github.com/lucas-clemente/quic-go/quicvarint/varint.go b/vendor/github.com/quic-go/quic-go/quicvarint/varint.go similarity index 73% rename from vendor/github.com/lucas-clemente/quic-go/quicvarint/varint.go rename to vendor/github.com/quic-go/quic-go/quicvarint/varint.go index 66e0d39e659..cbebfe61fa0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/quicvarint/varint.go +++ b/vendor/github.com/quic-go/quic-go/quicvarint/varint.go @@ -4,7 +4,7 @@ import ( "fmt" "io" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/protocol" ) // taken from the QUIC draft @@ -71,6 +71,7 @@ func Read(r io.ByteReader) (uint64, error) { } // Write writes i in the QUIC varint format to w. +// Deprecated: use Append instead. func Write(w Writer, i uint64) { if i <= maxVarInt1 { w.WriteByte(uint8(i)) @@ -88,32 +89,52 @@ func Write(w Writer, i uint64) { } } -// WriteWithLen writes i in the QUIC varint format with the desired length to w. -func WriteWithLen(w Writer, i uint64, length protocol.ByteCount) { +// Append appends i in the QUIC varint format. +func Append(b []byte, i uint64) []byte { + if i <= maxVarInt1 { + return append(b, uint8(i)) + } + if i <= maxVarInt2 { + return append(b, []byte{uint8(i>>8) | 0x40, uint8(i)}...) + } + if i <= maxVarInt4 { + return append(b, []byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}...) + } + if i <= maxVarInt8 { + return append(b, []byte{ + uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), + uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), + }...) + } + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) +} + +// AppendWithLen append i in the QUIC varint format with the desired length. 
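Append above implements the RFC 9000 variable-length integer encoding: the top two bits of the first byte select a 1-, 2-, 4- or 8-byte encoding for values up to 2^62-1. A small usage sketch against the vendored package, reusing the worked values from RFC 9000 Appendix A:

    package main

    import (
        "fmt"

        "github.com/quic-go/quic-go/quicvarint"
    )

    func main() {
        // Each value gets the shortest encoding that fits it.
        for _, v := range []uint64{37, 15293, 494878333, 151288809941952652} {
            b := quicvarint.Append(nil, v)
            fmt.Printf("%d -> %x (%d bytes)\n", v, b, len(b))
        }
    }

AppendWithLen below does the same but forces a wider encoding when a fixed length is required.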
+func AppendWithLen(b []byte, i uint64, length protocol.ByteCount) []byte { if length != 1 && length != 2 && length != 4 && length != 8 { panic("invalid varint length") } l := Len(i) if l == length { - Write(w, i) - return + return Append(b, i) } if l > length { panic(fmt.Sprintf("cannot encode %d in %d bytes", i, length)) } if length == 2 { - w.WriteByte(0b01000000) + b = append(b, 0b01000000) } else if length == 4 { - w.WriteByte(0b10000000) + b = append(b, 0b10000000) } else if length == 8 { - w.WriteByte(0b11000000) + b = append(b, 0b11000000) } for j := protocol.ByteCount(1); j < length-l; j++ { - w.WriteByte(0) + b = append(b, 0) } for j := protocol.ByteCount(0); j < l; j++ { - w.WriteByte(uint8(i >> (8 * (l - 1 - j)))) + b = append(b, uint8(i>>(8*(l-1-j)))) } + return b } // Len determines the number of bytes that will be needed to write the number i. diff --git a/vendor/github.com/lucas-clemente/quic-go/receive_stream.go b/vendor/github.com/quic-go/quic-go/receive_stream.go similarity index 86% rename from vendor/github.com/lucas-clemente/quic-go/receive_stream.go rename to vendor/github.com/quic-go/quic-go/receive_stream.go index ae6a449b575..0a7e941675f 100644 --- a/vendor/github.com/lucas-clemente/quic-go/receive_stream.go +++ b/vendor/github.com/quic-go/quic-go/receive_stream.go @@ -6,11 +6,11 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/internal/flowcontrol" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/flowcontrol" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" ) type receiveStreamI interface { @@ -34,24 +34,19 @@ type receiveStream struct { currentFrame []byte currentFrameDone func() - currentFrameIsLast bool // is the currentFrame the last frame on this stream readPosInFrame int + currentFrameIsLast bool // is the currentFrame the last frame on this stream + finRead bool // set once we read a frame with a Fin closeForShutdownErr error cancelReadErr error resetRemotelyErr *StreamError - closedForShutdown bool // set when CloseForShutdown() is called - finRead bool // set once we read a frame with a Fin - canceledRead bool // set when CancelRead() is called - resetRemotely bool // set when HandleResetStreamFrame() is called - readChan chan struct{} readOnce chan struct{} // cap: 1, to protect against concurrent use of Read deadline time.Time flowController flowcontrol.StreamFlowController - version protocol.VersionNumber } var ( @@ -63,7 +58,6 @@ func newReceiveStream( streamID protocol.StreamID, sender streamSender, flowController flowcontrol.StreamFlowController, - version protocol.VersionNumber, ) *receiveStream { return &receiveStream{ streamID: streamID, @@ -73,7 +67,6 @@ func newReceiveStream( readChan: make(chan struct{}, 1), readOnce: make(chan struct{}, 1), finalOffset: protocol.MaxByteCount, - version: version, } } @@ -103,13 +96,13 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err if s.finRead { return false, 0, io.EOF } - if s.canceledRead { + if s.cancelReadErr != nil { return false, 0, s.cancelReadErr } - if s.resetRemotely { + if s.resetRemotelyErr != nil { return false, 0, s.resetRemotelyErr } - if s.closedForShutdown { + if s.closeForShutdownErr != nil { return false, 0, 
s.closeForShutdownErr } @@ -125,13 +118,13 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err for { // Stop waiting on errors - if s.closedForShutdown { + if s.closeForShutdownErr != nil { return false, bytesRead, s.closeForShutdownErr } - if s.canceledRead { + if s.cancelReadErr != nil { return false, bytesRead, s.cancelReadErr } - if s.resetRemotely { + if s.resetRemotelyErr != nil { return false, bytesRead, s.resetRemotelyErr } @@ -178,8 +171,9 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err s.readPosInFrame += m bytesRead += m - // when a RESET_STREAM was received, the was already informed about the final byteOffset for this stream - if !s.resetRemotely { + // when a RESET_STREAM was received, the flow controller was already + // informed about the final byteOffset for this stream + if s.resetRemotelyErr == nil { s.flowController.AddBytesRead(protocol.ByteCount(m)) } @@ -214,11 +208,10 @@ func (s *receiveStream) CancelRead(errorCode StreamErrorCode) { } func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ { - if s.finRead || s.canceledRead || s.resetRemotely { + if s.finRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil { return false } - s.canceledRead = true - s.cancelReadErr = fmt.Errorf("Read on stream %d canceled with error code %d", s.streamID, errorCode) + s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false} s.signalRead() s.sender.queueControlFrame(&wire.StopSendingFrame{ StreamID: s.streamID, @@ -250,7 +243,7 @@ func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount s.finalOffset = maxOffset } - if s.canceledRead { + if s.cancelReadErr != nil { return newlyRcvdFinalOffset, nil } if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil { @@ -273,7 +266,7 @@ func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) err } func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) { - if s.closedForShutdown { + if s.closeForShutdownErr != nil { return false, nil } if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil { @@ -283,13 +276,13 @@ func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) s.finalOffset = frame.FinalSize // ignore duplicate RESET_STREAM frames for this stream (after checking their final offset) - if s.resetRemotely { + if s.resetRemotelyErr != nil { return false, nil } - s.resetRemotely = true s.resetRemotelyErr = &StreamError{ StreamID: s.streamID, ErrorCode: frame.ErrorCode, + Remote: true, } s.signalRead() return newlyRcvdFinalOffset, nil @@ -312,7 +305,6 @@ func (s *receiveStream) SetReadDeadline(t time.Time) error { // The peer will NOT be informed about this: the stream is closed without sending a FIN or RESET. 
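This hunk drops the receive stream's separate boolean flags and treats a non-nil cancelReadErr / resetRemotelyErr / closeForShutdownErr as the flag itself. A minimal illustration of the pattern; the types and names are illustrative:

    package main

    import (
        "errors"
        "fmt"
    )

    // stream keeps no separate "canceled" boolean; a non-nil error is the flag,
    // mirroring the error fields used above.
    type stream struct {
        cancelReadErr error
    }

    func (s *stream) read() ([]byte, error) {
        if s.cancelReadErr != nil { // replaces the old canceledRead bool
            return nil, s.cancelReadErr
        }
        return []byte("data"), nil
    }

    func main() {
        s := &stream{}
        if _, err := s.read(); err == nil {
            fmt.Println("read ok")
        }
        s.cancelReadErr = errors.New("read canceled")
        if _, err := s.read(); err != nil {
            fmt.Println(err)
        }
    }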
func (s *receiveStream) closeForShutdown(err error) { s.mutex.Lock() - s.closedForShutdown = true s.closeForShutdownErr = err s.mutex.Unlock() s.signalRead() diff --git a/vendor/github.com/lucas-clemente/quic-go/retransmission_queue.go b/vendor/github.com/quic-go/quic-go/retransmission_queue.go similarity index 83% rename from vendor/github.com/lucas-clemente/quic-go/retransmission_queue.go rename to vendor/github.com/quic-go/quic-go/retransmission_queue.go index 0cfbbc4de41..2ce0b8931a4 100644 --- a/vendor/github.com/lucas-clemente/quic-go/retransmission_queue.go +++ b/vendor/github.com/quic-go/quic-go/retransmission_queue.go @@ -3,8 +3,8 @@ package quic import ( "fmt" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" ) type retransmissionQueue struct { @@ -15,12 +15,10 @@ type retransmissionQueue struct { handshakeCryptoData []*wire.CryptoFrame appData []wire.Frame - - version protocol.VersionNumber } -func newRetransmissionQueue(ver protocol.VersionNumber) *retransmissionQueue { - return &retransmissionQueue{version: ver} +func newRetransmissionQueue() *retransmissionQueue { + return &retransmissionQueue{} } func (q *retransmissionQueue) AddInitial(f wire.Frame) { @@ -58,10 +56,10 @@ func (q *retransmissionQueue) AddAppData(f wire.Frame) { q.appData = append(q.appData, f) } -func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Frame { +func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { if len(q.initialCryptoData) > 0 { f := q.initialCryptoData[0] - newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) + newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) if newFrame == nil && !needsSplit { // the whole frame fits q.initialCryptoData = q.initialCryptoData[1:] return f @@ -74,17 +72,17 @@ func (q *retransmissionQueue) GetInitialFrame(maxLen protocol.ByteCount) wire.Fr return nil } f := q.initial[0] - if f.Length(q.version) > maxLen { + if f.Length(v) > maxLen { return nil } q.initial = q.initial[1:] return f } -func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire.Frame { +func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { if len(q.handshakeCryptoData) > 0 { f := q.handshakeCryptoData[0] - newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, q.version) + newFrame, needsSplit := f.MaybeSplitOffFrame(maxLen, v) if newFrame == nil && !needsSplit { // the whole frame fits q.handshakeCryptoData = q.handshakeCryptoData[1:] return f @@ -97,19 +95,19 @@ func (q *retransmissionQueue) GetHandshakeFrame(maxLen protocol.ByteCount) wire. 
return nil } f := q.handshake[0] - if f.Length(q.version) > maxLen { + if f.Length(v) > maxLen { return nil } q.handshake = q.handshake[1:] return f } -func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount) wire.Frame { +func (q *retransmissionQueue) GetAppDataFrame(maxLen protocol.ByteCount, v protocol.VersionNumber) wire.Frame { if len(q.appData) == 0 { return nil } f := q.appData[0] - if f.Length(q.version) > maxLen { + if f.Length(v) > maxLen { return nil } q.appData = q.appData[1:] diff --git a/vendor/github.com/lucas-clemente/quic-go/send_conn.go b/vendor/github.com/quic-go/quic-go/send_conn.go similarity index 66% rename from vendor/github.com/lucas-clemente/quic-go/send_conn.go rename to vendor/github.com/quic-go/quic-go/send_conn.go index c53ebdfab1c..0ac270378e3 100644 --- a/vendor/github.com/lucas-clemente/quic-go/send_conn.go +++ b/vendor/github.com/quic-go/quic-go/send_conn.go @@ -22,7 +22,7 @@ type sconn struct { var _ sendConn = &sconn{} -func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn { +func newSendConn(c rawConn, remote net.Addr, info *packetInfo) *sconn { return &sconn{ rawConn: c, remoteAddr: remote, @@ -51,24 +51,3 @@ func (c *sconn) LocalAddr() net.Addr { } return addr } - -type spconn struct { - net.PacketConn - - remoteAddr net.Addr -} - -var _ sendConn = &spconn{} - -func newSendPconn(c net.PacketConn, remote net.Addr) sendConn { - return &spconn{PacketConn: c, remoteAddr: remote} -} - -func (c *spconn) Write(p []byte) error { - _, err := c.WriteTo(p, c.remoteAddr) - return err -} - -func (c *spconn) RemoteAddr() net.Addr { - return c.remoteAddr -} diff --git a/vendor/github.com/lucas-clemente/quic-go/send_queue.go b/vendor/github.com/quic-go/quic-go/send_queue.go similarity index 93% rename from vendor/github.com/lucas-clemente/quic-go/send_queue.go rename to vendor/github.com/quic-go/quic-go/send_queue.go index 1fc8c1bf893..9eafcd374b0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/send_queue.go +++ b/vendor/github.com/quic-go/quic-go/send_queue.go @@ -36,6 +36,13 @@ func newSendQueue(conn sendConn) sender { func (h *sendQueue) Send(p *packetBuffer) { select { case h.queue <- p: + // clear available channel if we've reached capacity + if len(h.queue) == sendQueueCapacity { + select { + case <-h.available: + default: + } + } case <-h.runStopped: default: panic("sendQueue.Send would have blocked") diff --git a/vendor/github.com/lucas-clemente/quic-go/send_stream.go b/vendor/github.com/quic-go/quic-go/send_stream.go similarity index 81% rename from vendor/github.com/lucas-clemente/quic-go/send_stream.go rename to vendor/github.com/quic-go/quic-go/send_stream.go index b23df00b377..cebe30ef08e 100644 --- a/vendor/github.com/lucas-clemente/quic-go/send_stream.go +++ b/vendor/github.com/quic-go/quic-go/send_stream.go @@ -6,19 +6,19 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/internal/ackhandler" - "github.com/lucas-clemente/quic-go/internal/flowcontrol" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/flowcontrol" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" ) type sendStreamI interface { SendStream 
handleStopSendingFrame(*wire.StopSendingFrame) hasData() bool - popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) + popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) closeForShutdown(error) updateSendWindow(protocol.ByteCount) } @@ -40,11 +40,9 @@ type sendStream struct { cancelWriteErr error closeForShutdownErr error - closedForShutdown bool // set when CloseForShutdown() is called - finishedWriting bool // set once Close() is called - canceledWrite bool // set when CancelWrite() is called, or a STOP_SENDING frame is received - finSent bool // set when a STREAM_FRAME with FIN bit has been sent - completed bool // set when this stream has been reported to the streamSender as completed + finishedWriting bool // set once Close() is called + finSent bool // set when a STREAM_FRAME with FIN bit has been sent + completed bool // set when this stream has been reported to the streamSender as completed dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out nextFrame *wire.StreamFrame @@ -54,8 +52,6 @@ type sendStream struct { deadline time.Time flowController flowcontrol.StreamFlowController - - version protocol.VersionNumber } var ( @@ -67,7 +63,6 @@ func newSendStream( streamID protocol.StreamID, sender streamSender, flowController flowcontrol.StreamFlowController, - version protocol.VersionNumber, ) *sendStream { s := &sendStream{ streamID: streamID, @@ -75,7 +70,6 @@ func newSendStream( flowController: flowController, writeChan: make(chan struct{}, 1), writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write - version: version, } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) return s @@ -98,7 +92,7 @@ func (s *sendStream) Write(p []byte) (int, error) { if s.finishedWriting { return 0, fmt.Errorf("write on closed stream %d", s.streamID) } - if s.canceledWrite { + if s.cancelWriteErr != nil { return 0, s.cancelWriteErr } if s.closeForShutdownErr != nil { @@ -122,7 +116,7 @@ func (s *sendStream) Write(p []byte) (int, error) { var copied bool var deadline time.Time // As soon as dataForWriting becomes smaller than a certain size x, we copy all the data to a STREAM frame (s.nextFrame), - // which can the be popped the next time we assemble a packet. + // which can then be popped the next time we assemble a packet. // This allows us to return Write() when all data but x bytes have been sent out. // When the user now calls Close(), this is much more likely to happen before we popped that last STREAM frame, // allowing us to set the FIN bit on that frame (instead of sending an empty STREAM frame with FIN). @@ -157,7 +151,7 @@ func (s *sendStream) Write(p []byte) (int, error) { } deadlineTimer.Reset(deadline) } - if s.dataForWriting == nil || s.canceledWrite || s.closedForShutdown { + if s.dataForWriting == nil || s.cancelWriteErr != nil || s.closeForShutdownErr != nil { break } } @@ -204,9 +198,9 @@ func (s *sendStream) canBufferStreamFrame() bool { // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. 
-func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool /* has more data to send */) { +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool /* has more data to send */) { s.mutex.Lock() - f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes) + f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) if f != nil { s.numOutstandingFrames++ } @@ -215,16 +209,20 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Fr if f == nil { return nil, hasMoreData } - return &ackhandler.Frame{Frame: f, OnLost: s.queueRetransmission, OnAcked: s.frameAcked}, hasMoreData + af := ackhandler.GetFrame() + af.Frame = f + af.OnLost = s.queueRetransmission + af.OnAcked = s.frameAcked + return af, hasMoreData } -func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more data to send */) { - if s.canceledWrite || s.closeForShutdownErr != nil { +func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) { + if s.cancelWriteErr != nil || s.closeForShutdownErr != nil { return nil, false } if len(s.retransmissionQueue) > 0 { - f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes) + f, hasMoreRetransmissions := s.maybeGetRetransmission(maxBytes, v) if f != nil || hasMoreRetransmissions { if f == nil { return nil, true @@ -260,7 +258,7 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun return nil, true } - f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow) + f, hasMoreData := s.popNewStreamFrame(maxBytes, sendWindow, v) if dataLen := f.DataLen(); dataLen > 0 { s.writeOffset += f.DataLen() s.flowController.AddBytesSent(f.DataLen()) @@ -272,12 +270,12 @@ func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCoun return f, hasMoreData } -func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) (*wire.StreamFrame, bool) { +func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool) { if s.nextFrame != nil { nextFrame := s.nextFrame s.nextFrame = nil - maxDataLen := utils.MinByteCount(sendWindow, nextFrame.MaxDataLen(maxBytes, s.version)) + maxDataLen := utils.Min(sendWindow, nextFrame.MaxDataLen(maxBytes, v)) if nextFrame.DataLen() > maxDataLen { s.nextFrame = wire.GetStreamFrame() s.nextFrame.StreamID = s.streamID @@ -299,7 +297,7 @@ func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) f.DataLenPresent = true f.Data = f.Data[:0] - hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow) + hasMoreData := s.popNewStreamFrameWithoutBuffer(f, maxBytes, sendWindow, v) if len(f.Data) == 0 && !f.Fin { f.PutBack() return nil, hasMoreData @@ -307,19 +305,19 @@ func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount) return f, hasMoreData } -func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount) bool { - maxDataLen := f.MaxDataLen(maxBytes, s.version) +func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxBytes, sendWindow protocol.ByteCount, v protocol.VersionNumber) bool { + maxDataLen := f.MaxDataLen(maxBytes, v) if maxDataLen == 0 { // a STREAM frame must have at least one byte of data return s.dataForWriting != 
nil || s.nextFrame != nil || s.finishedWriting } - s.getDataForWriting(f, utils.MinByteCount(maxDataLen, sendWindow)) + s.getDataForWriting(f, utils.Min(maxDataLen, sendWindow)) return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting } -func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount) (*wire.StreamFrame, bool /* has more retransmissions */) { +func (s *sendStream) maybeGetRetransmission(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more retransmissions */) { f := s.retransmissionQueue[0] - newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, s.version) + newFrame, needsSplit := f.MaybeSplitOffFrame(maxBytes, v) if needsSplit { return newFrame, true } @@ -354,7 +352,7 @@ func (s *sendStream) frameAcked(f wire.Frame) { f.(*wire.StreamFrame).PutBack() s.mutex.Lock() - if s.canceledWrite { + if s.cancelWriteErr != nil { s.mutex.Unlock() return } @@ -371,7 +369,7 @@ func (s *sendStream) frameAcked(f wire.Frame) { } func (s *sendStream) isNewlyCompleted() bool { - completed := (s.finSent || s.canceledWrite) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 + completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 if completed && !s.completed { s.completed = true return true @@ -383,7 +381,7 @@ func (s *sendStream) queueRetransmission(f wire.Frame) { sf := f.(*wire.StreamFrame) sf.DataLenPresent = true s.mutex.Lock() - if s.canceledWrite { + if s.cancelWriteErr != nil { s.mutex.Unlock() return } @@ -399,11 +397,11 @@ func (s *sendStream) queueRetransmission(f wire.Frame) { func (s *sendStream) Close() error { s.mutex.Lock() - if s.closedForShutdown { + if s.closeForShutdownErr != nil { s.mutex.Unlock() return nil } - if s.canceledWrite { + if s.cancelWriteErr != nil { s.mutex.Unlock() return fmt.Errorf("close called for canceled stream %d", s.streamID) } @@ -416,19 +414,18 @@ func (s *sendStream) Close() error { } func (s *sendStream) CancelWrite(errorCode StreamErrorCode) { - s.cancelWriteImpl(errorCode, fmt.Errorf("Write on stream %d canceled with error code %d", s.streamID, errorCode)) + s.cancelWriteImpl(errorCode, false) } // must be called after locking the mutex -func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, writeErr error) { +func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) { s.mutex.Lock() - if s.canceledWrite { + if s.cancelWriteErr != nil { s.mutex.Unlock() return } s.ctxCancel() - s.canceledWrite = true - s.cancelWriteErr = writeErr + s.cancelWriteErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: remote} s.numOutstandingFrames = 0 s.retransmissionQueue = nil newlyCompleted := s.isNewlyCompleted() @@ -457,10 +454,7 @@ func (s *sendStream) updateSendWindow(limit protocol.ByteCount) { } func (s *sendStream) handleStopSendingFrame(frame *wire.StopSendingFrame) { - s.cancelWriteImpl(frame.ErrorCode, &StreamError{ - StreamID: s.streamID, - ErrorCode: frame.ErrorCode, - }) + s.cancelWriteImpl(frame.ErrorCode, true) } func (s *sendStream) Context() context.Context { @@ -481,7 +475,6 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error { func (s *sendStream) closeForShutdown(err error) { s.mutex.Lock() s.ctxCancel() - s.closedForShutdown = true s.closeForShutdownErr = err s.mutex.Unlock() s.signalWrite() diff --git a/vendor/github.com/quic-go/quic-go/server.go b/vendor/github.com/quic-go/quic-go/server.go new file mode 100644 index 
00000000000..f5b04549e83 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/server.go @@ -0,0 +1,820 @@ +package quic + +import ( + "context" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/quic-go/quic-go/internal/handshake" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" +) + +// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. +var ErrServerClosed = errors.New("quic: server closed") + +// packetHandler handles packets +type packetHandler interface { + handlePacket(*receivedPacket) + shutdown() + destroy(error) + getPerspective() protocol.Perspective +} + +type packetHandlerManager interface { + Get(protocol.ConnectionID) (packetHandler, bool) + GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) + AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool + Close(error) + CloseServer() + connRunner +} + +type quicConn interface { + EarlyConnection + earlyConnReady() <-chan struct{} + handlePacket(*receivedPacket) + GetVersion() protocol.VersionNumber + getPerspective() protocol.Perspective + run() error + destroy(error) + shutdown() +} + +type zeroRTTQueue struct { + packets []*receivedPacket + expiration time.Time +} + +// A Listener of QUIC +type baseServer struct { + mutex sync.Mutex + + acceptEarlyConns bool + + tlsConf *tls.Config + config *Config + + conn rawConn + + tokenGenerator *handshake.TokenGenerator + + connIDGenerator ConnectionIDGenerator + connHandler packetHandlerManager + onClose func() + + receivedPackets chan *receivedPacket + + nextZeroRTTCleanup time.Time + zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true + + // set as a member, so they can be set in the tests + newConn func( + sendConn, + connRunner, + protocol.ConnectionID, /* original dest connection ID */ + *protocol.ConnectionID, /* retry src connection ID */ + protocol.ConnectionID, /* client dest connection ID */ + protocol.ConnectionID, /* destination connection ID */ + protocol.ConnectionID, /* source connection ID */ + ConnectionIDGenerator, + protocol.StatelessResetToken, + *Config, + *tls.Config, + *handshake.TokenGenerator, + bool, /* client address validated by an address validation token */ + logging.ConnectionTracer, + uint64, + utils.Logger, + protocol.VersionNumber, + ) quicConn + + serverError error + errorChan chan struct{} + closed bool + running chan struct{} // closed as soon as run() returns + + connQueue chan quicConn + connQueueLen int32 // to be used as an atomic + + tracer logging.Tracer + + logger utils.Logger +} + +// A Listener listens for incoming QUIC connections. +// It returns connections once the handshake has completed. +type Listener struct { + baseServer *baseServer +} + +// Accept returns new connections. It should be called in a loop. +func (l *Listener) Accept(ctx context.Context) (Connection, error) { + return l.baseServer.Accept(ctx) +} + +// Close the server. All active connections will be closed. +func (l *Listener) Close() error { + return l.baseServer.Close() +} + +// Addr returns the local network address that the server is listening on. 
+func (l *Listener) Addr() net.Addr { + return l.baseServer.Addr() +} + +// An EarlyListener listens for incoming QUIC connections, and returns them before the handshake completes. +// For connections that don't use 0-RTT, this allows the server to send 0.5-RTT data. +// This data is encrypted with forward-secure keys, however, the client's identity has not yet been verified. +// For connection using 0-RTT, this allows the server to accept and respond to streams that the client opened in the +// 0-RTT data it sent. Note that at this point during the handshake, the live-ness of the +// client has not yet been confirmed, and the 0-RTT data could have been replayed by an attacker. +type EarlyListener struct { + baseServer *baseServer +} + +// Accept returns a new connections. It should be called in a loop. +func (l *EarlyListener) Accept(ctx context.Context) (EarlyConnection, error) { + return l.baseServer.accept(ctx) +} + +// Close the server. All active connections will be closed. +func (l *EarlyListener) Close() error { + return l.baseServer.Close() +} + +// Addr returns the local network addr that the server is listening on. +func (l *EarlyListener) Addr() net.Addr { + return l.baseServer.Addr() +} + +// ListenAddr creates a QUIC server listening on a given address. +// The tls.Config must not be nil and must contain a certificate configuration. +// The quic.Config may be nil, in that case the default values will be used. +func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) { + conn, err := listenUDP(addr) + if err != nil { + return nil, err + } + return (&Transport{ + Conn: conn, + createdConn: true, + isSingleUse: true, + }).Listen(tlsConf, config) +} + +// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. +func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { + conn, err := listenUDP(addr) + if err != nil { + return nil, err + } + return (&Transport{ + Conn: conn, + createdConn: true, + isSingleUse: true, + }).ListenEarly(tlsConf, config) +} + +func listenUDP(addr string) (*net.UDPConn, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return net.ListenUDP("udp", udpAddr) +} + +// Listen listens for QUIC connections on a given net.PacketConn. If the +// PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn +// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP +// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write +// packets. A single net.PacketConn only be used for a single call to Listen. +// The PacketConn can be used for simultaneous calls to Dial. QUIC connection +// IDs are used for demultiplexing the different connections. +// The tls.Config must not be nil and must contain a certificate configuration. +// Furthermore, it must define an application control (using NextProtos). +// The quic.Config may be nil, in that case the default values will be used. +func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) { + tr := &Transport{Conn: conn, isSingleUse: true} + return tr.Listen(tlsConf, config) +} + +// ListenEarly works like Listen, but it returns connections before the handshake completes. 
+func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { + tr := &Transport{Conn: conn, isSingleUse: true} + return tr.ListenEarly(tlsConf, config) +} + +func newServer( + conn rawConn, + connHandler packetHandlerManager, + connIDGenerator ConnectionIDGenerator, + tlsConf *tls.Config, + config *Config, + tracer logging.Tracer, + onClose func(), + acceptEarly bool, +) (*baseServer, error) { + tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) + if err != nil { + return nil, err + } + s := &baseServer{ + conn: conn, + tlsConf: tlsConf, + config: config, + tokenGenerator: tokenGenerator, + connIDGenerator: connIDGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), + newConn: newConnection, + tracer: tracer, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + onClose: onClose, + } + if acceptEarly { + s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} + } + go s.run() + s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) + return s, nil +} + +func (s *baseServer) run() { + defer close(s.running) + for { + select { + case <-s.errorChan: + return + default: + } + select { + case <-s.errorChan: + return + case p := <-s.receivedPackets: + if bufferStillInUse := s.handlePacketImpl(p); !bufferStillInUse { + p.buffer.Release() + } + } + } +} + +// Accept returns connections that already completed the handshake. +// It is only valid if acceptEarlyConns is false. +func (s *baseServer) Accept(ctx context.Context) (Connection, error) { + return s.accept(ctx) +} + +func (s *baseServer) accept(ctx context.Context) (quicConn, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case conn := <-s.connQueue: + atomic.AddInt32(&s.connQueueLen, -1) + return conn, nil + case <-s.errorChan: + return nil, s.serverError + } +} + +// Close the server +func (s *baseServer) Close() error { + s.mutex.Lock() + if s.closed { + s.mutex.Unlock() + return nil + } + if s.serverError == nil { + s.serverError = ErrServerClosed + } + s.closed = true + close(s.errorChan) + s.mutex.Unlock() + + <-s.running + s.onClose() + return nil +} + +func (s *baseServer) setCloseError(e error) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.closed { + return + } + s.closed = true + s.serverError = e + close(s.errorChan) +} + +// Addr returns the server's network address +func (s *baseServer) Addr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *baseServer) handlePacket(p *receivedPacket) { + select { + case s.receivedPackets <- p: + default: + s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + } + } +} + +func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? 
*/ { + if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) { + defer s.cleanupZeroRTTQueues(p.rcvTime) + } + + if wire.IsVersionNegotiationPacket(p.data) { + s.logger.Debugf("Dropping Version Negotiation packet.") + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + // Short header packets should never end up here in the first place + if !wire.IsLongHeaderPacket(p.data[0]) { + panic(fmt.Sprintf("misrouted packet: %#v", p.data)) + } + v, err := wire.ParseVersion(p.data) + // send a Version Negotiation Packet if the client is speaking a different protocol version + if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) { + if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize { + s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) + if err != nil { // should never happen + s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + if !s.config.DisableVersionNegotiationPackets { + go s.sendVersionNegotiationPacket(p.remoteAddr, src, dest, p.info.OOB(), v) + } + return false + } + + if wire.Is0RTTPacket(p.data) { + if !s.acceptEarlyConns { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + return s.handle0RTTPacket(p) + } + + // If we're creating a new connection, the packet will be passed to the connection. + // The header will then be parsed again. + hdr, _, _, err := wire.ParsePacket(p.data) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Error parsing packet: %s", err) + return false + } + if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { + s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + + if hdr.Type != protocol.PacketTypeInitial { + // Drop long header packets. + // There's little point in sending a Stateless Reset, since the client + // might not have received the token yet. + s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + + s.logger.Debugf("<- Received Initial packet.") + + if err := s.handleInitialImpl(p, hdr); err != nil { + s.logger.Errorf("Error occurred handling initial packet: %s", err) + } + // Don't put the packet buffer back. + // handleInitialImpl deals with the buffer. 
+ return true +} + +func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { + connID, err := wire.ParseConnectionID(p.data, 0) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) + } + return false + } + + // check again if we might have a connection now + if handler, ok := s.connHandler.Get(connID); ok { + handler.handlePacket(p) + return true + } + + if q, ok := s.zeroRTTQueues[connID]; ok { + if len(q.packets) >= protocol.Max0RTTQueueLen { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + return false + } + q.packets = append(q.packets, p) + return true + } + + if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + return false + } + queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)} + queue.packets[0] = p + expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration) + queue.expiration = expiration + if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) { + s.nextZeroRTTCleanup = expiration + } + s.zeroRTTQueues[connID] = queue + return true +} + +func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { + // Iterate over all queues to find those that are expired. + // This is ok since we're placing a pretty low limit on the number of queues. + var nextCleanup time.Time + for connID, q := range s.zeroRTTQueues { + if q.expiration.After(now) { + if nextCleanup.IsZero() || nextCleanup.After(q.expiration) { + nextCleanup = q.expiration + } + continue + } + for _, p := range q.packets { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + p.buffer.Release() + } + delete(s.zeroRTTQueues, connID) + if s.logger.Debug() { + s.logger.Debugf("Removing 0-RTT queue for %s.", connID) + } + } + s.nextZeroRTTCleanup = nextCleanup +} + +// validateToken returns false if: +// - address is invalid +// - token is expired +// - token is null +func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { + if token == nil { + return false + } + if !token.ValidateRemoteAddr(addr) { + return false + } + if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge { + return false + } + if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge { + return false + } + return true +} + +func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { + if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { + p.buffer.Release() + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + } + return errors.New("too short connection ID") + } + + // The server queues packets for a while, and we might already have established a connection by now. + // This results in a second check in the connection map. + // That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets). 
+ if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok { + handler.handlePacket(p) + return nil + } + + var ( + token *handshake.Token + retrySrcConnID *protocol.ConnectionID + ) + origDestConnID := hdr.DestConnectionID + if len(hdr.Token) > 0 { + tok, err := s.tokenGenerator.DecodeToken(hdr.Token) + if err == nil { + if tok.IsRetryToken { + origDestConnID = tok.OriginalDestConnectionID + retrySrcConnID = &tok.RetrySrcConnectionID + } + token = tok + } + } + + clientAddrIsValid := s.validateToken(token, p.remoteAddr) + if token != nil && !clientAddrIsValid { + // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. + // We just ignore them, and act as if there was no token on this packet at all. + // This also means we might send a Retry later. + if !token.IsRetryToken { + token = nil + } else { + // For Retry tokens, we send an INVALID_ERROR if + // * the token is too old, or + // * the token is invalid, in case of a retry token. + go func() { + defer p.buffer.Release() + if err := s.maybeSendInvalidToken(p, hdr); err != nil { + s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) + } + }() + return nil + } + } + if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { + // Retry invalidates all 0-RTT packets sent. + delete(s.zeroRTTQueues, hdr.DestConnectionID) + go func() { + defer p.buffer.Release() + if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error sending Retry: %s", err) + } + }() + return nil + } + + if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize { + s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) + go func() { + defer p.buffer.Release() + if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error rejecting connection: %s", err) + } + }() + return nil + } + + connID, err := s.connIDGenerator.GenerateConnectionID() + if err != nil { + return err + } + s.logger.Debugf("Changing connection ID to %s.", connID) + var conn quicConn + tracingID := nextConnTracingID() + if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) { + config := s.config + if s.config.GetConfigForClient != nil { + conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr}) + if err != nil { + s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback") + return nil, false + } + config = populateConfig(conf) + } + var tracer logging.ConnectionTracer + if config.Tracer != nil { + // Use the same connection ID that is passed to the client's GetLogWriter callback. 
+ connID := hdr.DestConnectionID + if origDestConnID.Len() > 0 { + connID = origDestConnID + } + tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) + } + conn = s.newConn( + newSendConn(s.conn, p.remoteAddr, p.info), + s.connHandler, + origDestConnID, + retrySrcConnID, + hdr.DestConnectionID, + hdr.SrcConnectionID, + connID, + s.connIDGenerator, + s.connHandler.GetStatelessResetToken(connID), + config, + s.tlsConf, + s.tokenGenerator, + clientAddrIsValid, + tracer, + tracingID, + s.logger, + hdr.Version, + ) + conn.handlePacket(p) + + if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok { + for _, p := range q.packets { + conn.handlePacket(p) + } + delete(s.zeroRTTQueues, hdr.DestConnectionID) + } + + return conn, true + }); !added { + go func() { + defer p.buffer.Release() + if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error rejecting connection: %s", err) + } + }() + return nil + } + go conn.run() + go s.handleNewConn(conn) + if conn == nil { + p.buffer.Release() + return nil + } + return nil +} + +func (s *baseServer) handleNewConn(conn quicConn) { + connCtx := conn.Context() + if s.acceptEarlyConns { + // wait until the early connection is ready (or the handshake fails) + select { + case <-conn.earlyConnReady(): + case <-connCtx.Done(): + return + } + } else { + // wait until the handshake is complete (or fails) + select { + case <-conn.HandshakeComplete(): + case <-connCtx.Done(): + return + } + } + + atomic.AddInt32(&s.connQueueLen, 1) + select { + case s.connQueue <- conn: + // blocks until the connection is accepted + case <-connCtx.Done(): + atomic.AddInt32(&s.connQueueLen, -1) + // don't pass connections that were already closed to Accept() + } +} + +func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { + // Log the Initial packet now. + // If no Retry is sent, the packet will be logged by the connection. + (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) + srcConnID, err := s.connIDGenerator.GenerateConnectionID() + if err != nil { + return err + } + token, err := s.tokenGenerator.NewRetryToken(remoteAddr, hdr.DestConnectionID, srcConnID) + if err != nil { + return err + } + replyHdr := &wire.ExtendedHeader{} + replyHdr.Type = protocol.PacketTypeRetry + replyHdr.Version = hdr.Version + replyHdr.SrcConnectionID = srcConnID + replyHdr.DestConnectionID = hdr.SrcConnectionID + replyHdr.Token = token + if s.logger.Debug() { + s.logger.Debugf("Changing connection ID to %s.", srcConnID) + s.logger.Debugf("-> Sending Retry") + replyHdr.Log(s.logger) + } + + buf := getPacketBuffer() + defer buf.Release() + buf.Data, err = replyHdr.Append(buf.Data, hdr.Version) + if err != nil { + return err + } + // append the Retry integrity tag + tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) + buf.Data = append(buf.Data, tag[:]...) + if s.tracer != nil { + s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) + } + _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) + return err +} + +func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error { + // Only send INVALID_TOKEN if we can unprotect the packet. + // This makes sure that we won't send it for packets that were corrupted. 
+ sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) + data := p.data[:hdr.ParsedLen()+hdr.Length] + extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) + } + // don't return the error here. Just drop the packet. + return nil + } + hdrLen := extHdr.ParsedLen() + if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { + // don't return the error here. Just drop the packet. + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) + } + return nil + } + if s.logger.Debug() { + s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) + } + return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info) +} + +func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { + sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) + return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) +} + +// sendError sends the error as a response to the packet received with header hdr +func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { + b := getPacketBuffer() + defer b.Release() + + ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)} + + replyHdr := &wire.ExtendedHeader{} + replyHdr.Type = protocol.PacketTypeInitial + replyHdr.Version = hdr.Version + replyHdr.SrcConnectionID = hdr.DestConnectionID + replyHdr.DestConnectionID = hdr.SrcConnectionID + replyHdr.PacketNumberLen = protocol.PacketNumberLen4 + replyHdr.Length = 4 /* packet number len */ + ccf.Length(hdr.Version) + protocol.ByteCount(sealer.Overhead()) + var err error + b.Data, err = replyHdr.Append(b.Data, hdr.Version) + if err != nil { + return err + } + payloadOffset := len(b.Data) + + b.Data, err = ccf.Append(b.Data, hdr.Version) + if err != nil { + return err + } + + _ = sealer.Seal(b.Data[payloadOffset:payloadOffset], b.Data[payloadOffset:], replyHdr.PacketNumber, b.Data[:payloadOffset]) + b.Data = b.Data[0 : len(b.Data)+sealer.Overhead()] + + pnOffset := payloadOffset - int(replyHdr.PacketNumberLen) + sealer.EncryptHeader( + b.Data[pnOffset+4:pnOffset+4+16], + &b.Data[0], + b.Data[pnOffset:payloadOffset], + ) + + replyHdr.Log(s.logger) + wire.LogFrame(s.logger, ccf, true) + if s.tracer != nil { + s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) + } + _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) + return err +} + +func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest protocol.ArbitraryLenConnectionID, oob []byte, v protocol.VersionNumber) { + s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) + + data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) + if s.tracer != nil { + s.tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) + } + if _, err := s.conn.WritePacket(data, remote, oob); err != nil { + s.logger.Debugf("Error sending Version Negotiation: %s", err) + } +} diff --git a/vendor/github.com/lucas-clemente/quic-go/stream.go 
b/vendor/github.com/quic-go/quic-go/stream.go similarity index 89% rename from vendor/github.com/lucas-clemente/quic-go/stream.go rename to vendor/github.com/quic-go/quic-go/stream.go index 95bbcb35651..98d2fc6e47b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/stream.go +++ b/vendor/github.com/quic-go/quic-go/stream.go @@ -6,10 +6,10 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go/internal/ackhandler" - "github.com/lucas-clemente/quic-go/internal/flowcontrol" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/flowcontrol" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" ) type deadlineError struct{} @@ -60,7 +60,7 @@ type streamI interface { // for sending hasData() bool handleStopSendingFrame(*wire.StopSendingFrame) - popStreamFrame(maxBytes protocol.ByteCount) (*ackhandler.Frame, bool) + popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) updateSendWindow(protocol.ByteCount) } @@ -80,8 +80,6 @@ type stream struct { sender streamSender receiveStreamCompleted bool sendStreamCompleted bool - - version protocol.VersionNumber } var _ Stream = &stream{} @@ -90,9 +88,8 @@ var _ Stream = &stream{} func newStream(streamID protocol.StreamID, sender streamSender, flowController flowcontrol.StreamFlowController, - version protocol.VersionNumber, ) *stream { - s := &stream{sender: sender, version: version} + s := &stream{sender: sender} senderForSendStream := &uniStreamSender{ streamSender: sender, onStreamCompletedImpl: func() { @@ -102,7 +99,7 @@ func newStream(streamID protocol.StreamID, s.completedMutex.Unlock() }, } - s.sendStream = *newSendStream(streamID, senderForSendStream, flowController, version) + s.sendStream = *newSendStream(streamID, senderForSendStream, flowController) senderForReceiveStream := &uniStreamSender{ streamSender: sender, onStreamCompletedImpl: func() { @@ -112,7 +109,7 @@ func newStream(streamID protocol.StreamID, s.completedMutex.Unlock() }, } - s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController, version) + s.receiveStream = *newReceiveStream(streamID, senderForReceiveStream, flowController) return s } diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map.go b/vendor/github.com/quic-go/quic-go/streams_map.go similarity index 90% rename from vendor/github.com/lucas-clemente/quic-go/streams_map.go rename to vendor/github.com/quic-go/quic-go/streams_map.go index 79c1ee91a86..b1a80eb36fa 100644 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map.go +++ b/vendor/github.com/quic-go/quic-go/streams_map.go @@ -7,10 +7,10 @@ import ( "net" "sync" - "github.com/lucas-clemente/quic-go/internal/flowcontrol" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/qerr" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/flowcontrol" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/wire" ) type streamError struct { @@ -46,7 +46,6 @@ var errTooManyOpenStreams = errors.New("too many open streams") type streamsMap struct { perspective protocol.Perspective - version protocol.VersionNumber maxIncomingBidiStreams uint64 maxIncomingUniStreams uint64 @@ -55,10 +54,10 @@ type streamsMap struct { newFlowController 
func(protocol.StreamID) flowcontrol.StreamFlowController mutex sync.Mutex - outgoingBidiStreams *outgoingBidiStreamsMap - outgoingUniStreams *outgoingUniStreamsMap - incomingBidiStreams *incomingBidiStreamsMap - incomingUniStreams *incomingUniStreamsMap + outgoingBidiStreams *outgoingStreamsMap[streamI] + outgoingUniStreams *outgoingStreamsMap[sendStreamI] + incomingBidiStreams *incomingStreamsMap[streamI] + incomingUniStreams *incomingStreamsMap[receiveStreamI] reset bool } @@ -70,7 +69,6 @@ func newStreamsMap( maxIncomingBidiStreams uint64, maxIncomingUniStreams uint64, perspective protocol.Perspective, - version protocol.VersionNumber, ) streamManager { m := &streamsMap{ perspective: perspective, @@ -78,39 +76,42 @@ func newStreamsMap( maxIncomingBidiStreams: maxIncomingBidiStreams, maxIncomingUniStreams: maxIncomingUniStreams, sender: sender, - version: version, } m.initMaps() return m } func (m *streamsMap) initMaps() { - m.outgoingBidiStreams = newOutgoingBidiStreamsMap( + m.outgoingBidiStreams = newOutgoingStreamsMap( + protocol.StreamTypeBidi, func(num protocol.StreamNum) streamI { id := num.StreamID(protocol.StreamTypeBidi, m.perspective) - return newStream(id, m.sender, m.newFlowController(id), m.version) + return newStream(id, m.sender, m.newFlowController(id)) }, m.sender.queueControlFrame, ) - m.incomingBidiStreams = newIncomingBidiStreamsMap( + m.incomingBidiStreams = newIncomingStreamsMap( + protocol.StreamTypeBidi, func(num protocol.StreamNum) streamI { id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite()) - return newStream(id, m.sender, m.newFlowController(id), m.version) + return newStream(id, m.sender, m.newFlowController(id)) }, m.maxIncomingBidiStreams, m.sender.queueControlFrame, ) - m.outgoingUniStreams = newOutgoingUniStreamsMap( + m.outgoingUniStreams = newOutgoingStreamsMap( + protocol.StreamTypeUni, func(num protocol.StreamNum) sendStreamI { id := num.StreamID(protocol.StreamTypeUni, m.perspective) - return newSendStream(id, m.sender, m.newFlowController(id), m.version) + return newSendStream(id, m.sender, m.newFlowController(id)) }, m.sender.queueControlFrame, ) - m.incomingUniStreams = newIncomingUniStreamsMap( + m.incomingUniStreams = newIncomingStreamsMap( + protocol.StreamTypeUni, func(num protocol.StreamNum) receiveStreamI { id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite()) - return newReceiveStream(id, m.sender, m.newFlowController(id), m.version) + return newReceiveStream(id, m.sender, m.newFlowController(id)) }, m.maxIncomingUniStreams, m.sender.queueControlFrame, diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go b/vendor/github.com/quic-go/quic-go/streams_map_incoming.go similarity index 76% rename from vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go rename to vendor/github.com/quic-go/quic-go/streams_map_incoming.go index 5bddec00b23..18ec6f998b0 100644 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map_incoming_uni.go +++ b/vendor/github.com/quic-go/quic-go/streams_map_incoming.go @@ -1,49 +1,52 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. 
-// see https://github.com/cheekybits/genny - package quic import ( "context" "sync" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" ) +type incomingStream interface { + closeForShutdown(error) +} + // When a stream is deleted before it was accepted, we can't delete it from the map immediately. // We need to wait until the application accepts it, and delete it then. -type receiveStreamIEntry struct { - stream receiveStreamI +type incomingStreamEntry[T incomingStream] struct { + stream T shouldDelete bool } -type incomingUniStreamsMap struct { +type incomingStreamsMap[T incomingStream] struct { mutex sync.RWMutex newStreamChan chan struct{} - streams map[protocol.StreamNum]receiveStreamIEntry + streamType protocol.StreamType + streams map[protocol.StreamNum]incomingStreamEntry[T] nextStreamToAccept protocol.StreamNum // the next stream that will be returned by AcceptStream() nextStreamToOpen protocol.StreamNum // the highest stream that the peer opened maxStream protocol.StreamNum // the highest stream that the peer is allowed to open maxNumStreams uint64 // maximum number of streams - newStream func(protocol.StreamNum) receiveStreamI + newStream func(protocol.StreamNum) T queueMaxStreamID func(*wire.MaxStreamsFrame) closeErr error } -func newIncomingUniStreamsMap( - newStream func(protocol.StreamNum) receiveStreamI, +func newIncomingStreamsMap[T incomingStream]( + streamType protocol.StreamType, + newStream func(protocol.StreamNum) T, maxStreams uint64, queueControlFrame func(wire.Frame), -) *incomingUniStreamsMap { - return &incomingUniStreamsMap{ +) *incomingStreamsMap[T] { + return &incomingStreamsMap[T]{ newStreamChan: make(chan struct{}, 1), - streams: make(map[protocol.StreamNum]receiveStreamIEntry), + streamType: streamType, + streams: make(map[protocol.StreamNum]incomingStreamEntry[T]), maxStream: protocol.StreamNum(maxStreams), maxNumStreams: maxStreams, newStream: newStream, @@ -53,7 +56,7 @@ func newIncomingUniStreamsMap( } } -func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStreamI, error) { +func (m *incomingStreamsMap[T]) AcceptStream(ctx context.Context) (T, error) { // drain the newStreamChan, so we don't check the map twice if the stream doesn't exist select { case <-m.newStreamChan: @@ -63,12 +66,12 @@ func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStream m.mutex.Lock() var num protocol.StreamNum - var entry receiveStreamIEntry + var entry incomingStreamEntry[T] for { num = m.nextStreamToAccept if m.closeErr != nil { m.mutex.Unlock() - return nil, m.closeErr + return *new(T), m.closeErr } var ok bool entry, ok = m.streams[num] @@ -78,7 +81,7 @@ func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStream m.mutex.Unlock() select { case <-ctx.Done(): - return nil, ctx.Err() + return *new(T), ctx.Err() case <-m.newStreamChan: } m.mutex.Lock() @@ -88,18 +91,18 @@ func (m *incomingUniStreamsMap) AcceptStream(ctx context.Context) (receiveStream if entry.shouldDelete { if err := m.deleteStream(num); err != nil { m.mutex.Unlock() - return nil, err + return *new(T), err } } m.mutex.Unlock() return entry.stream, nil } -func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receiveStreamI, error) { +func (m *incomingStreamsMap[T]) GetOrOpenStream(num protocol.StreamNum) (T, error) { m.mutex.RLock() if num > m.maxStream { m.mutex.RUnlock() - 
return nil, streamError{ + return *new(T), streamError{ message: "peer tried to open stream %d (current limit: %d)", nums: []protocol.StreamNum{num, m.maxStream}, } @@ -108,7 +111,7 @@ func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receive // * this stream exists in the map, and we can return it, or // * this stream was already closed, then we can return the nil if num < m.nextStreamToOpen { - var s receiveStreamI + var s T // If the stream was already queued for deletion, and is just waiting to be accepted, don't return it. if entry, ok := m.streams[num]; ok && !entry.shouldDelete { s = entry.stream @@ -123,7 +126,7 @@ func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receive // * maxStream can only increase, so if the id was valid before, it definitely is valid now // * highestStream is only modified by this function for newNum := m.nextStreamToOpen; newNum <= num; newNum++ { - m.streams[newNum] = receiveStreamIEntry{stream: m.newStream(newNum)} + m.streams[newNum] = incomingStreamEntry[T]{stream: m.newStream(newNum)} select { case m.newStreamChan <- struct{}{}: default: @@ -135,14 +138,14 @@ func (m *incomingUniStreamsMap) GetOrOpenStream(num protocol.StreamNum) (receive return entry.stream, nil } -func (m *incomingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { +func (m *incomingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error { m.mutex.Lock() defer m.mutex.Unlock() return m.deleteStream(num) } -func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { +func (m *incomingStreamsMap[T]) deleteStream(num protocol.StreamNum) error { if _, ok := m.streams[num]; !ok { return streamError{ message: "tried to delete unknown incoming stream %d", @@ -173,7 +176,7 @@ func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { if maxStream <= protocol.MaxStreamCount { m.maxStream = maxStream m.queueMaxStreamID(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeUni, + Type: m.streamType, MaxStreamNum: m.maxStream, }) } @@ -181,7 +184,7 @@ func (m *incomingUniStreamsMap) deleteStream(num protocol.StreamNum) error { return nil } -func (m *incomingUniStreamsMap) CloseWithError(err error) { +func (m *incomingStreamsMap[T]) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err for _, entry := range m.streams { diff --git a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go b/vendor/github.com/quic-go/quic-go/streams_map_outgoing.go similarity index 73% rename from vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go rename to vendor/github.com/quic-go/quic-go/streams_map_outgoing.go index 8782364a54a..fd45f4e7cf1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/streams_map_outgoing_uni.go +++ b/vendor/github.com/quic-go/quic-go/streams_map_outgoing.go @@ -1,21 +1,23 @@ -// This file was automatically generated by genny. -// Any changes will be lost if this file is regenerated. 
-// see https://github.com/cheekybits/genny - package quic import ( "context" "sync" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" ) -type outgoingUniStreamsMap struct { +type outgoingStream interface { + updateSendWindow(protocol.ByteCount) + closeForShutdown(error) +} + +type outgoingStreamsMap[T outgoingStream] struct { mutex sync.RWMutex - streams map[protocol.StreamNum]sendStreamI + streamType protocol.StreamType + streams map[protocol.StreamNum]T openQueue map[uint64]chan struct{} lowestInQueue uint64 @@ -25,18 +27,20 @@ type outgoingUniStreamsMap struct { maxStream protocol.StreamNum // the maximum stream ID we're allowed to open blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream - newStream func(protocol.StreamNum) sendStreamI + newStream func(protocol.StreamNum) T queueStreamIDBlocked func(*wire.StreamsBlockedFrame) closeErr error } -func newOutgoingUniStreamsMap( - newStream func(protocol.StreamNum) sendStreamI, +func newOutgoingStreamsMap[T outgoingStream]( + streamType protocol.StreamType, + newStream func(protocol.StreamNum) T, queueControlFrame func(wire.Frame), -) *outgoingUniStreamsMap { - return &outgoingUniStreamsMap{ - streams: make(map[protocol.StreamNum]sendStreamI), +) *outgoingStreamsMap[T] { + return &outgoingStreamsMap[T]{ + streamType: streamType, + streams: make(map[protocol.StreamNum]T), openQueue: make(map[uint64]chan struct{}), maxStream: protocol.InvalidStreamNum, nextStream: 1, @@ -45,32 +49,32 @@ func newOutgoingUniStreamsMap( } } -func (m *outgoingUniStreamsMap) OpenStream() (sendStreamI, error) { +func (m *outgoingStreamsMap[T]) OpenStream() (T, error) { m.mutex.Lock() defer m.mutex.Unlock() if m.closeErr != nil { - return nil, m.closeErr + return *new(T), m.closeErr } // if there are OpenStreamSync calls waiting, return an error here if len(m.openQueue) > 0 || m.nextStream > m.maxStream { m.maybeSendBlockedFrame() - return nil, streamOpenErr{errTooManyOpenStreams} + return *new(T), streamOpenErr{errTooManyOpenStreams} } return m.openStream(), nil } -func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI, error) { +func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) { m.mutex.Lock() defer m.mutex.Unlock() if m.closeErr != nil { - return nil, m.closeErr + return *new(T), m.closeErr } if err := ctx.Err(); err != nil { - return nil, err + return *new(T), err } if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { @@ -92,13 +96,13 @@ func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI case <-ctx.Done(): m.mutex.Lock() delete(m.openQueue, queuePos) - return nil, ctx.Err() + return *new(T), ctx.Err() case <-waitChan: } m.mutex.Lock() if m.closeErr != nil { - return nil, m.closeErr + return *new(T), m.closeErr } if m.nextStream > m.maxStream { // no stream available. 
Continue waiting @@ -112,7 +116,7 @@ func (m *outgoingUniStreamsMap) OpenStreamSync(ctx context.Context) (sendStreamI } } -func (m *outgoingUniStreamsMap) openStream() sendStreamI { +func (m *outgoingStreamsMap[T]) openStream() T { s := m.newStream(m.nextStream) m.streams[m.nextStream] = s m.nextStream++ @@ -121,7 +125,7 @@ func (m *outgoingUniStreamsMap) openStream() sendStreamI { // maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, // if we haven't sent one for this offset yet -func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() { +func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() { if m.blockedSent { return } @@ -131,17 +135,17 @@ func (m *outgoingUniStreamsMap) maybeSendBlockedFrame() { streamNum = m.maxStream } m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ - Type: protocol.StreamTypeUni, + Type: m.streamType, StreamLimit: streamNum, }) m.blockedSent = true } -func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, error) { +func (m *outgoingStreamsMap[T]) GetStream(num protocol.StreamNum) (T, error) { m.mutex.RLock() if num >= m.nextStream { m.mutex.RUnlock() - return nil, streamError{ + return *new(T), streamError{ message: "peer attempted to open stream %d", nums: []protocol.StreamNum{num}, } @@ -151,7 +155,7 @@ func (m *outgoingUniStreamsMap) GetStream(num protocol.StreamNum) (sendStreamI, return s, nil } -func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { +func (m *outgoingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -165,7 +169,7 @@ func (m *outgoingUniStreamsMap) DeleteStream(num protocol.StreamNum) error { return nil } -func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) { +func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) { m.mutex.Lock() defer m.mutex.Unlock() @@ -183,7 +187,7 @@ func (m *outgoingUniStreamsMap) SetMaxStream(num protocol.StreamNum) { // UpdateSendWindow is called when the peer's transport parameters are received. // Only in the case of a 0-RTT handshake will we have open streams at this point. // We might need to update the send window, in case the server increased it. 
-func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { +func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) { m.mutex.Lock() for _, str := range m.streams { str.updateSendWindow(limit) @@ -192,7 +196,7 @@ func (m *outgoingUniStreamsMap) UpdateSendWindow(limit protocol.ByteCount) { } // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream -func (m *outgoingUniStreamsMap) unblockOpenSync() { +func (m *outgoingStreamsMap[T]) unblockOpenSync() { if len(m.openQueue) == 0 { return } @@ -211,7 +215,7 @@ func (m *outgoingUniStreamsMap) unblockOpenSync() { } } -func (m *outgoingUniStreamsMap) CloseWithError(err error) { +func (m *outgoingStreamsMap[T]) CloseWithError(err error) { m.mutex.Lock() m.closeErr = err for _, str := range m.streams { diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn.go b/vendor/github.com/quic-go/quic-go/sys_conn.go similarity index 95% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn.go rename to vendor/github.com/quic-go/quic-go/sys_conn.go index 7cc05465808..d6c1d616453 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn.go @@ -5,8 +5,8 @@ import ( "syscall" "time" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) // OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header. diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_df.go b/vendor/github.com/quic-go/quic-go/sys_conn_df.go similarity index 90% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn_df.go rename to vendor/github.com/quic-go/quic-go/sys_conn_df.go index ae9274d97fa..ef9f981ac63 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn_df.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_df.go @@ -1,5 +1,4 @@ //go:build !linux && !windows -// +build !linux,!windows package quic diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_linux.go b/vendor/github.com/quic-go/quic-go/sys_conn_df_linux.go similarity index 88% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn_df_linux.go rename to vendor/github.com/quic-go/quic-go/sys_conn_df_linux.go index 69f8fc93202..98542b4102b 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_linux.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_df_linux.go @@ -1,5 +1,4 @@ //go:build linux -// +build linux package quic @@ -9,7 +8,7 @@ import ( "golang.org/x/sys/unix" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/utils" ) func setDF(rawConn syscall.RawConn) error { @@ -30,7 +29,7 @@ func setDF(rawConn syscall.RawConn) error { case errDFIPv4 != nil && errDFIPv6 == nil: utils.DefaultLogger.Debugf("Setting DF for IPv6.") case errDFIPv4 != nil && errDFIPv6 != nil: - utils.DefaultLogger.Errorf("setting DF failed for both IPv4 and IPv6") + return errors.New("setting DF failed for both IPv4 and IPv6") } return nil } diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_windows.go b/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go similarity index 95% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn_df_windows.go rename to vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go index 4649f6463d2..9855e8de8d6 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn_df_windows.go +++ 
b/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go @@ -1,5 +1,4 @@ //go:build windows -// +build windows package quic @@ -7,8 +6,9 @@ import ( "errors" "syscall" - "github.com/lucas-clemente/quic-go/internal/utils" "golang.org/x/sys/windows" + + "github.com/quic-go/quic-go/internal/utils" ) const ( diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_darwin.go b/vendor/github.com/quic-go/quic-go/sys_conn_helper_darwin.go similarity index 95% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_darwin.go rename to vendor/github.com/quic-go/quic-go/sys_conn_helper_darwin.go index eabf489f109..7ad5f3af163 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_darwin.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_helper_darwin.go @@ -1,5 +1,4 @@ //go:build darwin -// +build darwin package quic diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_freebsd.go b/vendor/github.com/quic-go/quic-go/sys_conn_helper_freebsd.go similarity index 93% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_freebsd.go rename to vendor/github.com/quic-go/quic-go/sys_conn_helper_freebsd.go index 0b3e8434b8a..8d16d0b9103 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_freebsd.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_helper_freebsd.go @@ -1,5 +1,4 @@ //go:build freebsd -// +build freebsd package quic diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_linux.go b/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go similarity index 96% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_linux.go rename to vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go index 51bec900242..61c3f54ba0c 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn_helper_linux.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go @@ -1,5 +1,4 @@ //go:build linux -// +build linux package quic diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_no_oob.go b/vendor/github.com/quic-go/quic-go/sys_conn_no_oob.go similarity index 87% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn_no_oob.go rename to vendor/github.com/quic-go/quic-go/sys_conn_no_oob.go index e3b0d11f685..7ab5040aa19 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn_no_oob.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_no_oob.go @@ -1,5 +1,4 @@ //go:build !darwin && !linux && !freebsd && !windows -// +build !darwin,!linux,!freebsd,!windows package quic diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_oob.go b/vendor/github.com/quic-go/quic-go/sys_conn_oob.go similarity index 87% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn_oob.go rename to vendor/github.com/quic-go/quic-go/sys_conn_oob.go index acd74d023c1..806dfb81a32 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn_oob.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_oob.go @@ -1,5 +1,4 @@ //go:build darwin || linux || freebsd -// +build darwin linux freebsd package quic @@ -15,8 +14,8 @@ import ( "golang.org/x/net/ipv6" "golang.org/x/sys/unix" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" ) const ( @@ -123,10 +122,15 @@ func newConn(c OOBCapablePacketConn) (*oobConn, error) { bc = ipv4.NewPacketConn(c) } + msgs := make([]ipv4.Message, batchSize) + for i := range msgs { + // preallocate the 
[][]byte + msgs[i].Buffers = make([][]byte, 1) + } oobConn := &oobConn{ OOBCapablePacketConn: c, batchConn: bc, - messages: make([]ipv4.Message, batchSize), + messages: msgs, readPos: batchSize, } for i := 0; i < batchSize; i++ { @@ -143,7 +147,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { buffer := getPacketBuffer() buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] c.buffers[i] = buffer - c.messages[i].Buffers = [][]byte{c.buffers[i].Data} + c.messages[i].Buffers[0] = c.buffers[i].Data } c.readPos = 0 @@ -157,18 +161,20 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { msg := c.messages[c.readPos] buffer := c.buffers[c.readPos] c.readPos++ - ctrlMsgs, err := unix.ParseSocketControlMessage(msg.OOB[:msg.NN]) - if err != nil { - return nil, err - } + + data := msg.OOB[:msg.NN] var ecn protocol.ECN var destIP net.IP var ifIndex uint32 - for _, ctrlMsg := range ctrlMsgs { - if ctrlMsg.Header.Level == unix.IPPROTO_IP { - switch ctrlMsg.Header.Type { + for len(data) > 0 { + hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data) + if err != nil { + return nil, err + } + if hdr.Level == unix.IPPROTO_IP { + switch hdr.Type { case msgTypeIPTOS: - ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) + ecn = protocol.ECN(body[0] & ecnMask) case msgTypeIPv4PKTINFO: // struct in_pktinfo { // unsigned int ipi_ifindex; /* Interface index */ @@ -177,33 +183,34 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { // address */ // }; ip := make([]byte, 4) - if len(ctrlMsg.Data) == 12 { - ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data) - copy(ip, ctrlMsg.Data[8:12]) - } else if len(ctrlMsg.Data) == 4 { + if len(body) == 12 { + ifIndex = binary.LittleEndian.Uint32(body) + copy(ip, body[8:12]) + } else if len(body) == 4 { // FreeBSD - copy(ip, ctrlMsg.Data) + copy(ip, body) } destIP = net.IP(ip) } } - if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 { - switch ctrlMsg.Header.Type { + if hdr.Level == unix.IPPROTO_IPV6 { + switch hdr.Type { case unix.IPV6_TCLASS: - ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask) + ecn = protocol.ECN(body[0] & ecnMask) case msgTypeIPv6PKTINFO: // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // unsigned int ipi6_ifindex; /* send/recv interface index */ // }; - if len(ctrlMsg.Data) == 20 { + if len(body) == 20 { ip := make([]byte, 16) - copy(ip, ctrlMsg.Data[:16]) + copy(ip, body[:16]) destIP = net.IP(ip) - ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data[16:]) + ifIndex = binary.LittleEndian.Uint32(body[16:]) } } } + data = remainder } var info *packetInfo if destIP != nil { diff --git a/vendor/github.com/lucas-clemente/quic-go/sys_conn_windows.go b/vendor/github.com/quic-go/quic-go/sys_conn_windows.go similarity index 97% rename from vendor/github.com/lucas-clemente/quic-go/sys_conn_windows.go rename to vendor/github.com/quic-go/quic-go/sys_conn_windows.go index f2cc22ab7c4..b003fe94afe 100644 --- a/vendor/github.com/lucas-clemente/quic-go/sys_conn_windows.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_windows.go @@ -1,5 +1,4 @@ //go:build windows -// +build windows package quic diff --git a/vendor/github.com/lucas-clemente/quic-go/token_store.go b/vendor/github.com/quic-go/quic-go/token_store.go similarity index 84% rename from vendor/github.com/lucas-clemente/quic-go/token_store.go rename to vendor/github.com/quic-go/quic-go/token_store.go index 9641dc5a7e6..00460e50285 100644 --- a/vendor/github.com/lucas-clemente/quic-go/token_store.go +++ b/vendor/github.com/quic-go/quic-go/token_store.go @@ 
-1,10 +1,10 @@ package quic import ( - "container/list" "sync" - "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/utils" + list "github.com/quic-go/quic-go/internal/utils/linkedlist" ) type singleOriginTokenStore struct { @@ -48,8 +48,8 @@ type lruTokenStoreEntry struct { type lruTokenStore struct { mutex sync.Mutex - m map[string]*list.Element - q *list.List + m map[string]*list.Element[*lruTokenStoreEntry] + q *list.List[*lruTokenStoreEntry] capacity int singleOriginSize int } @@ -61,8 +61,8 @@ var _ TokenStore = &lruTokenStore{} // tokensPerOrigin specifies the maximum number of tokens per origin. func NewLRUTokenStore(maxOrigins, tokensPerOrigin int) TokenStore { return &lruTokenStore{ - m: make(map[string]*list.Element), - q: list.New(), + m: make(map[string]*list.Element[*lruTokenStoreEntry]), + q: list.New[*lruTokenStoreEntry](), capacity: maxOrigins, singleOriginSize: tokensPerOrigin, } @@ -73,7 +73,7 @@ func (s *lruTokenStore) Put(key string, token *ClientToken) { defer s.mutex.Unlock() if el, ok := s.m[key]; ok { - entry := el.Value.(*lruTokenStoreEntry) + entry := el.Value entry.cache.Add(token) s.q.MoveToFront(el) return @@ -90,7 +90,7 @@ func (s *lruTokenStore) Put(key string, token *ClientToken) { } elem := s.q.Back() - entry := elem.Value.(*lruTokenStoreEntry) + entry := elem.Value delete(s.m, entry.key) entry.key = key entry.cache = newSingleOriginTokenStore(s.singleOriginSize) @@ -106,7 +106,7 @@ func (s *lruTokenStore) Pop(key string) *ClientToken { var token *ClientToken if el, ok := s.m[key]; ok { s.q.MoveToFront(el) - cache := el.Value.(*lruTokenStoreEntry).cache + cache := el.Value.cache token = cache.Pop() if cache.Len() == 0 { s.q.Remove(el) diff --git a/vendor/github.com/quic-go/quic-go/tools.go b/vendor/github.com/quic-go/quic-go/tools.go new file mode 100644 index 00000000000..e848317f15d --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/tools.go @@ -0,0 +1,8 @@ +//go:build tools + +package quic + +import ( + _ "github.com/golang/mock/mockgen" + _ "github.com/onsi/ginkgo/v2/ginkgo" +) diff --git a/vendor/github.com/quic-go/quic-go/transport.go b/vendor/github.com/quic-go/quic-go/transport.go new file mode 100644 index 00000000000..153675da6df --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/transport.go @@ -0,0 +1,414 @@ +package quic + +import ( + "context" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "log" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/quic-go/quic-go/internal/wire" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +type Transport struct { + // A single net.PacketConn can only be handled by one Transport. + // Bad things will happen if passed to multiple Transports. + // + // If the connection satisfies the OOBCapablePacketConn interface + // (as a net.UDPConn does), ECN and packet info support will be enabled. + // In this case, optimized syscalls might be used, skipping the + // ReadFrom and WriteTo calls to read / write packets. + Conn net.PacketConn + + // The length of the connection ID in bytes. + // It can be 0, or any value between 4 and 18. + // If unset, a 4 byte connection ID will be used. + ConnectionIDLength int + + // Use for generating new connection IDs. + // This allows the application to control of the connection IDs used, + // which allows routing / load balancing based on connection IDs. 
+ // All Connection IDs returned by the ConnectionIDGenerator MUST + // have the same length. + ConnectionIDGenerator ConnectionIDGenerator + + // The StatelessResetKey is used to generate stateless reset tokens. + // If no key is configured, sending of stateless resets is disabled. + StatelessResetKey *StatelessResetKey + + // A Tracer traces events that don't belong to a single QUIC connection. + Tracer logging.Tracer + + handlerMap packetHandlerManager + + mutex sync.Mutex + initOnce sync.Once + initErr error + + // Set in init. + // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. + connIDLen int + // Set in init. + // If no ConnectionIDGenerator is set, this is set to a default. + connIDGenerator ConnectionIDGenerator + + server unknownPacketHandler + + conn rawConn + + closeQueue chan closePacket + + listening chan struct{} // is closed when listen returns + closed bool + createdConn bool + isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial + + logger utils.Logger +} + +// Listen starts listening for incoming QUIC connections. +// There can only be a single listener on any net.PacketConn. +// Listen may only be called again after the current Listener was closed. +func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(conf); err != nil { + return nil, err + } + + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.server != nil { + return nil, errListenerAlreadySet + } + conf = populateServerConfig(conf) + if err := t.init(true); err != nil { + return nil, err + } + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false) + if err != nil { + return nil, err + } + t.server = s + return &Listener{baseServer: s}, nil +} + +// ListenEarly starts listening for incoming QUIC connections. +// There can only be a single listener on any net.PacketConn. +// Listen may only be called again after the current Listener was closed. +func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(conf); err != nil { + return nil, err + } + + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.server != nil { + return nil, errListenerAlreadySet + } + conf = populateServerConfig(conf) + if err := t.init(true); err != nil { + return nil, err + } + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) + if err != nil { + return nil, err + } + t.server = s + return &EarlyListener{baseServer: s}, nil +} + +// Dial dials a new connection to a remote host (not using 0-RTT). +func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateConfig(conf) + if err := t.init(false); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) +} + +// DialEarly dials a new connection, attempting to use 0-RTT if possible. 
+func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateConfig(conf) + if err := t.init(false); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + return dial(ctx, newSendConn(t.conn, addr, nil), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) +} + +func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { + conn, ok := c.(interface{ SetReadBuffer(int) error }) + if !ok { + return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") + } + size, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if size >= protocol.DesiredReceiveBufferSize { + logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) + return nil + } + if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { + return fmt.Errorf("failed to increase receive buffer size: %w", err) + } + newSize, err := inspectReadBuffer(c) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if newSize == size { + return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + if newSize < protocol.DesiredReceiveBufferSize { + return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) + return nil +} + +// only print warnings about the UDP receive buffer size once +var receiveBufferWarningOnce sync.Once + +func (t *Transport) init(isServer bool) error { + t.initOnce.Do(func() { + getMultiplexer().AddConn(t.Conn) + + conn, err := wrapConn(t.Conn) + if err != nil { + t.initErr = err + return + } + + t.logger = utils.DefaultLogger // TODO: make this configurable + t.conn = conn + t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) + t.listening = make(chan struct{}) + + t.closeQueue = make(chan closePacket, 4) + + if t.ConnectionIDGenerator != nil { + t.connIDGenerator = t.ConnectionIDGenerator + t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen() + } else { + connIDLen := t.ConnectionIDLength + if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) { + connIDLen = protocol.DefaultConnectionIDLength + } + t.connIDLen = connIDLen + t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} + } + + go t.listen(conn) + go t.runCloseQueue() + }) + return t.initErr +} + +func (t *Transport) enqueueClosePacket(p closePacket) { + select { + case t.closeQueue <- p: + default: + // Oops, we're backlogged. + // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. + } +} + +func (t *Transport) runCloseQueue() { + for { + select { + case <-t.listening: + return + case p := <-t.closeQueue: + t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) + } + } +} + +// Close closes the underlying connection and waits until listen has returned. +// It is invalid to start new listeners or connections after that. 
+func (t *Transport) Close() error { + t.close(errors.New("closing")) + if t.createdConn { + if err := t.conn.Close(); err != nil { + return err + } + } else { + t.conn.SetReadDeadline(time.Now()) + defer func() { t.conn.SetReadDeadline(time.Time{}) }() + } + <-t.listening // wait until listening returns + return nil +} + +func (t *Transport) closeServer() { + t.handlerMap.CloseServer() + t.mutex.Lock() + t.server = nil + if t.isSingleUse { + t.closed = true + } + t.mutex.Unlock() + if t.createdConn { + t.Conn.Close() + } + if t.isSingleUse { + t.conn.SetReadDeadline(time.Now()) + defer func() { t.conn.SetReadDeadline(time.Time{}) }() + <-t.listening // wait until listening returns + } +} + +func (t *Transport) close(e error) { + t.mutex.Lock() + defer t.mutex.Unlock() + if t.closed { + return + } + + t.handlerMap.Close(e) + if t.server != nil { + t.server.setCloseError(e) + } + t.closed = true +} + +func (t *Transport) listen(conn rawConn) { + defer close(t.listening) + defer getMultiplexer().RemoveConn(t.Conn) + + if err := setReceiveBuffer(t.Conn, t.logger); err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { + receiveBufferWarningOnce.Do(func() { + if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { + return + } + log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) + }) + } + } + + for { + p, err := conn.ReadPacket() + //nolint:staticcheck // SA1019 ignore this! + // TODO: This code is used to ignore wsa errors on Windows. + // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. + // See https://github.com/quic-go/quic-go/issues/1737 for details. + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + t.mutex.Lock() + closed := t.closed + t.mutex.Unlock() + if closed { + return + } + t.logger.Debugf("Temporary error reading from conn: %w", err) + continue + } + if err != nil { + t.close(err) + return + } + t.handlePacket(p) + } +} + +func (t *Transport) handlePacket(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, t.connIDLen) + if err != nil { + t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) + if t.Tracer != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + p.buffer.MaybeRelease() + return + } + + if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { + return + } + if handler, ok := t.handlerMap.Get(connID); ok { + handler.handlePacket(p) + return + } + if !wire.IsLongHeaderPacket(p.data[0]) { + go t.maybeSendStatelessReset(p, connID) + return + } + + t.mutex.Lock() + defer t.mutex.Unlock() + if t.server == nil { // no server set + t.logger.Debugf("received a packet with an unexpected connection ID %s", connID) + return + } + t.server.handlePacket(p) +} + +func (t *Transport) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { + defer p.buffer.Release() + if t.StatelessResetKey == nil { + return + } + // Don't send a stateless reset in response to very small packets. + // This includes packets that could be stateless resets. + if len(p.data) <= protocol.MinStatelessResetSize { + return + } + token := t.handlerMap.GetStatelessResetToken(connID) + t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). 
Token: %#x", p.remoteAddr, connID, token) + data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) + rand.Read(data) + data[0] = (data[0] & 0x7f) | 0x40 + data = append(data, token[:]...) + if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + t.logger.Debugf("Error sending Stateless Reset: %s", err) + } +} + +func (t *Transport) maybeHandleStatelessReset(data []byte) bool { + // stateless resets are always short header packets + if wire.IsLongHeaderPacket(data[0]) { + return false + } + if len(data) < 17 /* type byte + 16 bytes for the reset token */ { + return false + } + + token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) + if conn, ok := t.handlerMap.GetByResetToken(token); ok { + t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) + go conn.destroy(&StatelessResetError{Token: token}) + return true + } + return false +} diff --git a/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go b/vendor/github.com/quic-go/quic-go/window_update_queue.go similarity index 91% rename from vendor/github.com/lucas-clemente/quic-go/window_update_queue.go rename to vendor/github.com/quic-go/quic-go/window_update_queue.go index 2abcf67390a..9ed121430e1 100644 --- a/vendor/github.com/lucas-clemente/quic-go/window_update_queue.go +++ b/vendor/github.com/quic-go/quic-go/window_update_queue.go @@ -3,9 +3,9 @@ package quic import ( "sync" - "github.com/lucas-clemente/quic-go/internal/flowcontrol" - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/quic-go/quic-go/internal/flowcontrol" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" ) type windowUpdateQueue struct { diff --git a/vendor/github.com/marten-seemann/qtls-go1-18/LICENSE b/vendor/golang.org/x/exp/LICENSE similarity index 100% rename from vendor/github.com/marten-seemann/qtls-go1-18/LICENSE rename to vendor/golang.org/x/exp/LICENSE diff --git a/vendor/golang.org/x/exp/PATENTS b/vendor/golang.org/x/exp/PATENTS new file mode 100644 index 00000000000..733099041f8 --- /dev/null +++ b/vendor/golang.org/x/exp/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. 
If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/exp/constraints/constraints.go b/vendor/golang.org/x/exp/constraints/constraints.go new file mode 100644 index 00000000000..2c033dff47e --- /dev/null +++ b/vendor/golang.org/x/exp/constraints/constraints.go @@ -0,0 +1,50 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package constraints defines a set of useful constraints to be used +// with type parameters. +package constraints + +// Signed is a constraint that permits any signed integer type. +// If future releases of Go add new predeclared signed integer types, +// this constraint will be modified to include them. +type Signed interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +// Unsigned is a constraint that permits any unsigned integer type. +// If future releases of Go add new predeclared unsigned integer types, +// this constraint will be modified to include them. +type Unsigned interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +// Integer is a constraint that permits any integer type. +// If future releases of Go add new predeclared integer types, +// this constraint will be modified to include them. +type Integer interface { + Signed | Unsigned +} + +// Float is a constraint that permits any floating-point type. +// If future releases of Go add new predeclared floating-point types, +// this constraint will be modified to include them. +type Float interface { + ~float32 | ~float64 +} + +// Complex is a constraint that permits any complex numeric type. +// If future releases of Go add new predeclared complex numeric types, +// this constraint will be modified to include them. +type Complex interface { + ~complex64 | ~complex128 +} + +// Ordered is a constraint that permits any ordered type: any type +// that supports the operators < <= >= >. +// If future releases of Go add new ordered types, +// this constraint will be modified to include them. +type Ordered interface { + Integer | Float | ~string +} diff --git a/vendor/golang.org/x/mod/modfile/print.go b/vendor/golang.org/x/mod/modfile/print.go new file mode 100644 index 00000000000..524f93022ac --- /dev/null +++ b/vendor/golang.org/x/mod/modfile/print.go @@ -0,0 +1,174 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Module file printer. + +package modfile + +import ( + "bytes" + "fmt" + "strings" +) + +// Format returns a go.mod file as a byte slice, formatted in standard style. +func Format(f *FileSyntax) []byte { + pr := &printer{} + pr.file(f) + return pr.Bytes() +} + +// A printer collects the state during printing of a file or expression. +type printer struct { + bytes.Buffer // output buffer + comment []Comment // pending end-of-line comments + margin int // left margin (indent), a number of tabs +} + +// printf prints to the buffer. 
+func (p *printer) printf(format string, args ...interface{}) { + fmt.Fprintf(p, format, args...) +} + +// indent returns the position on the current line, in bytes, 0-indexed. +func (p *printer) indent() int { + b := p.Bytes() + n := 0 + for n < len(b) && b[len(b)-1-n] != '\n' { + n++ + } + return n +} + +// newline ends the current line, flushing end-of-line comments. +func (p *printer) newline() { + if len(p.comment) > 0 { + p.printf(" ") + for i, com := range p.comment { + if i > 0 { + p.trim() + p.printf("\n") + for i := 0; i < p.margin; i++ { + p.printf("\t") + } + } + p.printf("%s", strings.TrimSpace(com.Token)) + } + p.comment = p.comment[:0] + } + + p.trim() + p.printf("\n") + for i := 0; i < p.margin; i++ { + p.printf("\t") + } +} + +// trim removes trailing spaces and tabs from the current line. +func (p *printer) trim() { + // Remove trailing spaces and tabs from line we're about to end. + b := p.Bytes() + n := len(b) + for n > 0 && (b[n-1] == '\t' || b[n-1] == ' ') { + n-- + } + p.Truncate(n) +} + +// file formats the given file into the print buffer. +func (p *printer) file(f *FileSyntax) { + for _, com := range f.Before { + p.printf("%s", strings.TrimSpace(com.Token)) + p.newline() + } + + for i, stmt := range f.Stmt { + switch x := stmt.(type) { + case *CommentBlock: + // comments already handled + p.expr(x) + + default: + p.expr(x) + p.newline() + } + + for _, com := range stmt.Comment().After { + p.printf("%s", strings.TrimSpace(com.Token)) + p.newline() + } + + if i+1 < len(f.Stmt) { + p.newline() + } + } +} + +func (p *printer) expr(x Expr) { + // Emit line-comments preceding this expression. + if before := x.Comment().Before; len(before) > 0 { + // Want to print a line comment. + // Line comments must be at the current margin. + p.trim() + if p.indent() > 0 { + // There's other text on the line. Start a new line. + p.printf("\n") + } + // Re-indent to margin. + for i := 0; i < p.margin; i++ { + p.printf("\t") + } + for _, com := range before { + p.printf("%s", strings.TrimSpace(com.Token)) + p.newline() + } + } + + switch x := x.(type) { + default: + panic(fmt.Errorf("printer: unexpected type %T", x)) + + case *CommentBlock: + // done + + case *LParen: + p.printf("(") + case *RParen: + p.printf(")") + + case *Line: + p.tokens(x.Token) + + case *LineBlock: + p.tokens(x.Token) + p.printf(" ") + p.expr(&x.LParen) + p.margin++ + for _, l := range x.Line { + p.newline() + p.expr(l) + } + p.margin-- + p.newline() + p.expr(&x.RParen) + } + + // Queue end-of-line comments for printing when we + // reach the end of the line. + p.comment = append(p.comment, x.Comment().Suffix...) +} + +func (p *printer) tokens(tokens []string) { + sep := "" + for _, t := range tokens { + if t == "," || t == ")" || t == "]" || t == "}" { + sep = "" + } + p.printf("%s%s", sep, t) + sep = " " + if t == "(" || t == "[" || t == "{" { + sep = "" + } + } +} diff --git a/vendor/golang.org/x/mod/modfile/read.go b/vendor/golang.org/x/mod/modfile/read.go new file mode 100644 index 00000000000..a503bc2105d --- /dev/null +++ b/vendor/golang.org/x/mod/modfile/read.go @@ -0,0 +1,958 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package modfile + +import ( + "bytes" + "errors" + "fmt" + "os" + "strconv" + "strings" + "unicode" + "unicode/utf8" +) + +// A Position describes an arbitrary source position in a file, including the +// file, line, column, and byte offset. 
+type Position struct { + Line int // line in input (starting at 1) + LineRune int // rune in line (starting at 1) + Byte int // byte in input (starting at 0) +} + +// add returns the position at the end of s, assuming it starts at p. +func (p Position) add(s string) Position { + p.Byte += len(s) + if n := strings.Count(s, "\n"); n > 0 { + p.Line += n + s = s[strings.LastIndex(s, "\n")+1:] + p.LineRune = 1 + } + p.LineRune += utf8.RuneCountInString(s) + return p +} + +// An Expr represents an input element. +type Expr interface { + // Span returns the start and end position of the expression, + // excluding leading or trailing comments. + Span() (start, end Position) + + // Comment returns the comments attached to the expression. + // This method would normally be named 'Comments' but that + // would interfere with embedding a type of the same name. + Comment() *Comments +} + +// A Comment represents a single // comment. +type Comment struct { + Start Position + Token string // without trailing newline + Suffix bool // an end of line (not whole line) comment +} + +// Comments collects the comments associated with an expression. +type Comments struct { + Before []Comment // whole-line comments before this expression + Suffix []Comment // end-of-line comments after this expression + + // For top-level expressions only, After lists whole-line + // comments following the expression. + After []Comment +} + +// Comment returns the receiver. This isn't useful by itself, but +// a Comments struct is embedded into all the expression +// implementation types, and this gives each of those a Comment +// method to satisfy the Expr interface. +func (c *Comments) Comment() *Comments { + return c +} + +// A FileSyntax represents an entire go.mod file. +type FileSyntax struct { + Name string // file path + Comments + Stmt []Expr +} + +func (x *FileSyntax) Span() (start, end Position) { + if len(x.Stmt) == 0 { + return + } + start, _ = x.Stmt[0].Span() + _, end = x.Stmt[len(x.Stmt)-1].Span() + return start, end +} + +// addLine adds a line containing the given tokens to the file. +// +// If the first token of the hint matches the first token of the +// line, the new line is added at the end of the block containing hint, +// extracting hint into a new block if it is not yet in one. +// +// If the hint is non-nil buts its first token does not match, +// the new line is added after the block containing hint +// (or hint itself, if not in a block). +// +// If no hint is provided, addLine appends the line to the end of +// the last block with a matching first token, +// or to the end of the file if no such block exists. +func (x *FileSyntax) addLine(hint Expr, tokens ...string) *Line { + if hint == nil { + // If no hint given, add to the last statement of the given type. 
+ Loop: + for i := len(x.Stmt) - 1; i >= 0; i-- { + stmt := x.Stmt[i] + switch stmt := stmt.(type) { + case *Line: + if stmt.Token != nil && stmt.Token[0] == tokens[0] { + hint = stmt + break Loop + } + case *LineBlock: + if stmt.Token[0] == tokens[0] { + hint = stmt + break Loop + } + } + } + } + + newLineAfter := func(i int) *Line { + new := &Line{Token: tokens} + if i == len(x.Stmt) { + x.Stmt = append(x.Stmt, new) + } else { + x.Stmt = append(x.Stmt, nil) + copy(x.Stmt[i+2:], x.Stmt[i+1:]) + x.Stmt[i+1] = new + } + return new + } + + if hint != nil { + for i, stmt := range x.Stmt { + switch stmt := stmt.(type) { + case *Line: + if stmt == hint { + if stmt.Token == nil || stmt.Token[0] != tokens[0] { + return newLineAfter(i) + } + + // Convert line to line block. + stmt.InBlock = true + block := &LineBlock{Token: stmt.Token[:1], Line: []*Line{stmt}} + stmt.Token = stmt.Token[1:] + x.Stmt[i] = block + new := &Line{Token: tokens[1:], InBlock: true} + block.Line = append(block.Line, new) + return new + } + + case *LineBlock: + if stmt == hint { + if stmt.Token[0] != tokens[0] { + return newLineAfter(i) + } + + new := &Line{Token: tokens[1:], InBlock: true} + stmt.Line = append(stmt.Line, new) + return new + } + + for j, line := range stmt.Line { + if line == hint { + if stmt.Token[0] != tokens[0] { + return newLineAfter(i) + } + + // Add new line after hint within the block. + stmt.Line = append(stmt.Line, nil) + copy(stmt.Line[j+2:], stmt.Line[j+1:]) + new := &Line{Token: tokens[1:], InBlock: true} + stmt.Line[j+1] = new + return new + } + } + } + } + } + + new := &Line{Token: tokens} + x.Stmt = append(x.Stmt, new) + return new +} + +func (x *FileSyntax) updateLine(line *Line, tokens ...string) { + if line.InBlock { + tokens = tokens[1:] + } + line.Token = tokens +} + +// markRemoved modifies line so that it (and its end-of-line comment, if any) +// will be dropped by (*FileSyntax).Cleanup. +func (line *Line) markRemoved() { + line.Token = nil + line.Comments.Suffix = nil +} + +// Cleanup cleans up the file syntax x after any edit operations. +// To avoid quadratic behavior, (*Line).markRemoved marks the line as dead +// by setting line.Token = nil but does not remove it from the slice +// in which it appears. After edits have all been indicated, +// calling Cleanup cleans out the dead lines. +func (x *FileSyntax) Cleanup() { + w := 0 + for _, stmt := range x.Stmt { + switch stmt := stmt.(type) { + case *Line: + if stmt.Token == nil { + continue + } + case *LineBlock: + ww := 0 + for _, line := range stmt.Line { + if line.Token != nil { + stmt.Line[ww] = line + ww++ + } + } + if ww == 0 { + continue + } + if ww == 1 { + // Collapse block into single line. + line := &Line{ + Comments: Comments{ + Before: commentsAdd(stmt.Before, stmt.Line[0].Before), + Suffix: commentsAdd(stmt.Line[0].Suffix, stmt.Suffix), + After: commentsAdd(stmt.Line[0].After, stmt.After), + }, + Token: stringsAdd(stmt.Token, stmt.Line[0].Token), + } + x.Stmt[w] = line + w++ + continue + } + stmt.Line = stmt.Line[:ww] + } + x.Stmt[w] = stmt + w++ + } + x.Stmt = x.Stmt[:w] +} + +func commentsAdd(x, y []Comment) []Comment { + return append(x[:len(x):len(x)], y...) +} + +func stringsAdd(x, y []string) []string { + return append(x[:len(x):len(x)], y...) +} + +// A CommentBlock represents a top-level block of comments separate +// from any rule. 
+type CommentBlock struct { + Comments + Start Position +} + +func (x *CommentBlock) Span() (start, end Position) { + return x.Start, x.Start +} + +// A Line is a single line of tokens. +type Line struct { + Comments + Start Position + Token []string + InBlock bool + End Position +} + +func (x *Line) Span() (start, end Position) { + return x.Start, x.End +} + +// A LineBlock is a factored block of lines, like +// +// require ( +// "x" +// "y" +// ) +type LineBlock struct { + Comments + Start Position + LParen LParen + Token []string + Line []*Line + RParen RParen +} + +func (x *LineBlock) Span() (start, end Position) { + return x.Start, x.RParen.Pos.add(")") +} + +// An LParen represents the beginning of a parenthesized line block. +// It is a place to store suffix comments. +type LParen struct { + Comments + Pos Position +} + +func (x *LParen) Span() (start, end Position) { + return x.Pos, x.Pos.add(")") +} + +// An RParen represents the end of a parenthesized line block. +// It is a place to store whole-line (before) comments. +type RParen struct { + Comments + Pos Position +} + +func (x *RParen) Span() (start, end Position) { + return x.Pos, x.Pos.add(")") +} + +// An input represents a single input file being parsed. +type input struct { + // Lexing state. + filename string // name of input file, for errors + complete []byte // entire input + remaining []byte // remaining input + tokenStart []byte // token being scanned to end of input + token token // next token to be returned by lex, peek + pos Position // current input position + comments []Comment // accumulated comments + + // Parser state. + file *FileSyntax // returned top-level syntax tree + parseErrors ErrorList // errors encountered during parsing + + // Comment assignment state. + pre []Expr // all expressions, in preorder traversal + post []Expr // all expressions, in postorder traversal +} + +func newInput(filename string, data []byte) *input { + return &input{ + filename: filename, + complete: data, + remaining: data, + pos: Position{Line: 1, LineRune: 1, Byte: 0}, + } +} + +// parse parses the input file. +func parse(file string, data []byte) (f *FileSyntax, err error) { + // The parser panics for both routine errors like syntax errors + // and for programmer bugs like array index errors. + // Turn both into error returns. Catching bug panics is + // especially important when processing many files. + in := newInput(file, data) + defer func() { + if e := recover(); e != nil && e != &in.parseErrors { + in.parseErrors = append(in.parseErrors, Error{ + Filename: in.filename, + Pos: in.pos, + Err: fmt.Errorf("internal error: %v", e), + }) + } + if err == nil && len(in.parseErrors) > 0 { + err = in.parseErrors + } + }() + + // Prime the lexer by reading in the first token. It will be available + // in the next peek() or lex() call. + in.readToken() + + // Invoke the parser. + in.parseFile() + if len(in.parseErrors) > 0 { + return nil, in.parseErrors + } + in.file.Name = in.filename + + // Assign comments to nearby syntax. + in.assignComments() + + return in.file, nil +} + +// Error is called to report an error. +// Error does not return: it panics. +func (in *input) Error(s string) { + in.parseErrors = append(in.parseErrors, Error{ + Filename: in.filename, + Pos: in.pos, + Err: errors.New(s), + }) + panic(&in.parseErrors) +} + +// eof reports whether the input has reached end of file. +func (in *input) eof() bool { + return len(in.remaining) == 0 +} + +// peekRune returns the next rune in the input without consuming it. 
+func (in *input) peekRune() int { + if len(in.remaining) == 0 { + return 0 + } + r, _ := utf8.DecodeRune(in.remaining) + return int(r) +} + +// peekPrefix reports whether the remaining input begins with the given prefix. +func (in *input) peekPrefix(prefix string) bool { + // This is like bytes.HasPrefix(in.remaining, []byte(prefix)) + // but without the allocation of the []byte copy of prefix. + for i := 0; i < len(prefix); i++ { + if i >= len(in.remaining) || in.remaining[i] != prefix[i] { + return false + } + } + return true +} + +// readRune consumes and returns the next rune in the input. +func (in *input) readRune() int { + if len(in.remaining) == 0 { + in.Error("internal lexer error: readRune at EOF") + } + r, size := utf8.DecodeRune(in.remaining) + in.remaining = in.remaining[size:] + if r == '\n' { + in.pos.Line++ + in.pos.LineRune = 1 + } else { + in.pos.LineRune++ + } + in.pos.Byte += size + return int(r) +} + +type token struct { + kind tokenKind + pos Position + endPos Position + text string +} + +type tokenKind int + +const ( + _EOF tokenKind = -(iota + 1) + _EOLCOMMENT + _IDENT + _STRING + _COMMENT + + // newlines and punctuation tokens are allowed as ASCII codes. +) + +func (k tokenKind) isComment() bool { + return k == _COMMENT || k == _EOLCOMMENT +} + +// isEOL returns whether a token terminates a line. +func (k tokenKind) isEOL() bool { + return k == _EOF || k == _EOLCOMMENT || k == '\n' +} + +// startToken marks the beginning of the next input token. +// It must be followed by a call to endToken, once the token's text has +// been consumed using readRune. +func (in *input) startToken() { + in.tokenStart = in.remaining + in.token.text = "" + in.token.pos = in.pos +} + +// endToken marks the end of an input token. +// It records the actual token string in tok.text. +// A single trailing newline (LF or CRLF) will be removed from comment tokens. +func (in *input) endToken(kind tokenKind) { + in.token.kind = kind + text := string(in.tokenStart[:len(in.tokenStart)-len(in.remaining)]) + if kind.isComment() { + if strings.HasSuffix(text, "\r\n") { + text = text[:len(text)-2] + } else { + text = strings.TrimSuffix(text, "\n") + } + } + in.token.text = text + in.token.endPos = in.pos +} + +// peek returns the kind of the next token returned by lex. +func (in *input) peek() tokenKind { + return in.token.kind +} + +// lex is called from the parser to obtain the next input token. +func (in *input) lex() token { + tok := in.token + in.readToken() + return tok +} + +// readToken lexes the next token from the text and stores it in in.token. +func (in *input) readToken() { + // Skip past spaces, stopping at non-space or EOF. + for !in.eof() { + c := in.peekRune() + if c == ' ' || c == '\t' || c == '\r' { + in.readRune() + continue + } + + // Comment runs to end of line. + if in.peekPrefix("//") { + in.startToken() + + // Is this comment the only thing on its line? + // Find the last \n before this // and see if it's all + // spaces from there to here. + i := bytes.LastIndex(in.complete[:in.pos.Byte], []byte("\n")) + suffix := len(bytes.TrimSpace(in.complete[i+1:in.pos.Byte])) > 0 + in.readRune() + in.readRune() + + // Consume comment. + for len(in.remaining) > 0 && in.readRune() != '\n' { + } + + // If we are at top level (not in a statement), hand the comment to + // the parser as a _COMMENT token. The grammar is written + // to handle top-level comments itself. + if !suffix { + in.endToken(_COMMENT) + return + } + + // Otherwise, save comment for later attachment to syntax tree. 
+ in.endToken(_EOLCOMMENT) + in.comments = append(in.comments, Comment{in.token.pos, in.token.text, suffix}) + return + } + + if in.peekPrefix("/*") { + in.Error("mod files must use // comments (not /* */ comments)") + } + + // Found non-space non-comment. + break + } + + // Found the beginning of the next token. + in.startToken() + + // End of file. + if in.eof() { + in.endToken(_EOF) + return + } + + // Punctuation tokens. + switch c := in.peekRune(); c { + case '\n', '(', ')', '[', ']', '{', '}', ',': + in.readRune() + in.endToken(tokenKind(c)) + return + + case '"', '`': // quoted string + quote := c + in.readRune() + for { + if in.eof() { + in.pos = in.token.pos + in.Error("unexpected EOF in string") + } + if in.peekRune() == '\n' { + in.Error("unexpected newline in string") + } + c := in.readRune() + if c == quote { + break + } + if c == '\\' && quote != '`' { + if in.eof() { + in.pos = in.token.pos + in.Error("unexpected EOF in string") + } + in.readRune() + } + } + in.endToken(_STRING) + return + } + + // Checked all punctuation. Must be identifier token. + if c := in.peekRune(); !isIdent(c) { + in.Error(fmt.Sprintf("unexpected input character %#q", c)) + } + + // Scan over identifier. + for isIdent(in.peekRune()) { + if in.peekPrefix("//") { + break + } + if in.peekPrefix("/*") { + in.Error("mod files must use // comments (not /* */ comments)") + } + in.readRune() + } + in.endToken(_IDENT) +} + +// isIdent reports whether c is an identifier rune. +// We treat most printable runes as identifier runes, except for a handful of +// ASCII punctuation characters. +func isIdent(c int) bool { + switch r := rune(c); r { + case ' ', '(', ')', '[', ']', '{', '}', ',': + return false + default: + return !unicode.IsSpace(r) && unicode.IsPrint(r) + } +} + +// Comment assignment. +// We build two lists of all subexpressions, preorder and postorder. +// The preorder list is ordered by start location, with outer expressions first. +// The postorder list is ordered by end location, with outer expressions last. +// We use the preorder list to assign each whole-line comment to the syntax +// immediately following it, and we use the postorder list to assign each +// end-of-line comment to the syntax immediately preceding it. + +// order walks the expression adding it and its subexpressions to the +// preorder and postorder lists. +func (in *input) order(x Expr) { + if x != nil { + in.pre = append(in.pre, x) + } + switch x := x.(type) { + default: + panic(fmt.Errorf("order: unexpected type %T", x)) + case nil: + // nothing + case *LParen, *RParen: + // nothing + case *CommentBlock: + // nothing + case *Line: + // nothing + case *FileSyntax: + for _, stmt := range x.Stmt { + in.order(stmt) + } + case *LineBlock: + in.order(&x.LParen) + for _, l := range x.Line { + in.order(l) + } + in.order(&x.RParen) + } + if x != nil { + in.post = append(in.post, x) + } +} + +// assignComments attaches comments to nearby syntax. +func (in *input) assignComments() { + const debug = false + + // Generate preorder and postorder lists. + in.order(in.file) + + // Split into whole-line comments and suffix comments. + var line, suffix []Comment + for _, com := range in.comments { + if com.Suffix { + suffix = append(suffix, com) + } else { + line = append(line, com) + } + } + + if debug { + for _, c := range line { + fmt.Fprintf(os.Stderr, "LINE %q :%d:%d #%d\n", c.Token, c.Start.Line, c.Start.LineRune, c.Start.Byte) + } + } + + // Assign line comments to syntax immediately following. 
+ for _, x := range in.pre { + start, _ := x.Span() + if debug { + fmt.Fprintf(os.Stderr, "pre %T :%d:%d #%d\n", x, start.Line, start.LineRune, start.Byte) + } + xcom := x.Comment() + for len(line) > 0 && start.Byte >= line[0].Start.Byte { + if debug { + fmt.Fprintf(os.Stderr, "ASSIGN LINE %q #%d\n", line[0].Token, line[0].Start.Byte) + } + xcom.Before = append(xcom.Before, line[0]) + line = line[1:] + } + } + + // Remaining line comments go at end of file. + in.file.After = append(in.file.After, line...) + + if debug { + for _, c := range suffix { + fmt.Fprintf(os.Stderr, "SUFFIX %q :%d:%d #%d\n", c.Token, c.Start.Line, c.Start.LineRune, c.Start.Byte) + } + } + + // Assign suffix comments to syntax immediately before. + for i := len(in.post) - 1; i >= 0; i-- { + x := in.post[i] + + start, end := x.Span() + if debug { + fmt.Fprintf(os.Stderr, "post %T :%d:%d #%d :%d:%d #%d\n", x, start.Line, start.LineRune, start.Byte, end.Line, end.LineRune, end.Byte) + } + + // Do not assign suffix comments to end of line block or whole file. + // Instead assign them to the last element inside. + switch x.(type) { + case *FileSyntax: + continue + } + + // Do not assign suffix comments to something that starts + // on an earlier line, so that in + // + // x ( y + // z ) // comment + // + // we assign the comment to z and not to x ( ... ). + if start.Line != end.Line { + continue + } + xcom := x.Comment() + for len(suffix) > 0 && end.Byte <= suffix[len(suffix)-1].Start.Byte { + if debug { + fmt.Fprintf(os.Stderr, "ASSIGN SUFFIX %q #%d\n", suffix[len(suffix)-1].Token, suffix[len(suffix)-1].Start.Byte) + } + xcom.Suffix = append(xcom.Suffix, suffix[len(suffix)-1]) + suffix = suffix[:len(suffix)-1] + } + } + + // We assigned suffix comments in reverse. + // If multiple suffix comments were appended to the same + // expression node, they are now in reverse. Fix that. + for _, x := range in.post { + reverseComments(x.Comment().Suffix) + } + + // Remaining suffix comments go at beginning of file. + in.file.Before = append(in.file.Before, suffix...) +} + +// reverseComments reverses the []Comment list. +func reverseComments(list []Comment) { + for i, j := 0, len(list)-1; i < j; i, j = i+1, j-1 { + list[i], list[j] = list[j], list[i] + } +} + +func (in *input) parseFile() { + in.file = new(FileSyntax) + var cb *CommentBlock + for { + switch in.peek() { + case '\n': + in.lex() + if cb != nil { + in.file.Stmt = append(in.file.Stmt, cb) + cb = nil + } + case _COMMENT: + tok := in.lex() + if cb == nil { + cb = &CommentBlock{Start: tok.pos} + } + com := cb.Comment() + com.Before = append(com.Before, Comment{Start: tok.pos, Token: tok.text}) + case _EOF: + if cb != nil { + in.file.Stmt = append(in.file.Stmt, cb) + } + return + default: + in.parseStmt() + if cb != nil { + in.file.Stmt[len(in.file.Stmt)-1].Comment().Before = cb.Before + cb = nil + } + } + } +} + +func (in *input) parseStmt() { + tok := in.lex() + start := tok.pos + end := tok.endPos + tokens := []string{tok.text} + for { + tok := in.lex() + switch { + case tok.kind.isEOL(): + in.file.Stmt = append(in.file.Stmt, &Line{ + Start: start, + Token: tokens, + End: end, + }) + return + + case tok.kind == '(': + if next := in.peek(); next.isEOL() { + // Start of block: no more tokens on this line. + in.file.Stmt = append(in.file.Stmt, in.parseLineBlock(start, tokens, tok)) + return + } else if next == ')' { + rparen := in.lex() + if in.peek().isEOL() { + // Empty block. 
+ in.lex() + in.file.Stmt = append(in.file.Stmt, &LineBlock{ + Start: start, + Token: tokens, + LParen: LParen{Pos: tok.pos}, + RParen: RParen{Pos: rparen.pos}, + }) + return + } + // '( )' in the middle of the line, not a block. + tokens = append(tokens, tok.text, rparen.text) + } else { + // '(' in the middle of the line, not a block. + tokens = append(tokens, tok.text) + } + + default: + tokens = append(tokens, tok.text) + end = tok.endPos + } + } +} + +func (in *input) parseLineBlock(start Position, token []string, lparen token) *LineBlock { + x := &LineBlock{ + Start: start, + Token: token, + LParen: LParen{Pos: lparen.pos}, + } + var comments []Comment + for { + switch in.peek() { + case _EOLCOMMENT: + // Suffix comment, will be attached later by assignComments. + in.lex() + case '\n': + // Blank line. Add an empty comment to preserve it. + in.lex() + if len(comments) == 0 && len(x.Line) > 0 || len(comments) > 0 && comments[len(comments)-1].Token != "" { + comments = append(comments, Comment{}) + } + case _COMMENT: + tok := in.lex() + comments = append(comments, Comment{Start: tok.pos, Token: tok.text}) + case _EOF: + in.Error(fmt.Sprintf("syntax error (unterminated block started at %s:%d:%d)", in.filename, x.Start.Line, x.Start.LineRune)) + case ')': + rparen := in.lex() + x.RParen.Before = comments + x.RParen.Pos = rparen.pos + if !in.peek().isEOL() { + in.Error("syntax error (expected newline after closing paren)") + } + in.lex() + return x + default: + l := in.parseLine() + x.Line = append(x.Line, l) + l.Comment().Before = comments + comments = nil + } + } +} + +func (in *input) parseLine() *Line { + tok := in.lex() + if tok.kind.isEOL() { + in.Error("internal parse error: parseLine at end of line") + } + start := tok.pos + end := tok.endPos + tokens := []string{tok.text} + for { + tok := in.lex() + if tok.kind.isEOL() { + return &Line{ + Start: start, + Token: tokens, + End: end, + InBlock: true, + } + } + tokens = append(tokens, tok.text) + end = tok.endPos + } +} + +var ( + slashSlash = []byte("//") + moduleStr = []byte("module") +) + +// ModulePath returns the module path from the gomod file text. +// If it cannot find a module path, it returns an empty string. +// It is tolerant of unrelated problems in the go.mod file. +func ModulePath(mod []byte) string { + for len(mod) > 0 { + line := mod + mod = nil + if i := bytes.IndexByte(line, '\n'); i >= 0 { + line, mod = line[:i], line[i+1:] + } + if i := bytes.Index(line, slashSlash); i >= 0 { + line = line[:i] + } + line = bytes.TrimSpace(line) + if !bytes.HasPrefix(line, moduleStr) { + continue + } + line = line[len(moduleStr):] + n := len(line) + line = bytes.TrimSpace(line) + if len(line) == n || len(line) == 0 { + continue + } + + if line[0] == '"' || line[0] == '`' { + p, err := strconv.Unquote(string(line)) + if err != nil { + return "" // malformed quoted string or multiline module path + } + return p + } + + return string(line) + } + return "" // missing module path +} diff --git a/vendor/golang.org/x/mod/modfile/rule.go b/vendor/golang.org/x/mod/modfile/rule.go new file mode 100644 index 00000000000..6bcde8fabe3 --- /dev/null +++ b/vendor/golang.org/x/mod/modfile/rule.go @@ -0,0 +1,1559 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package modfile implements a parser and formatter for go.mod files. +// +// The go.mod syntax is described in +// https://golang.org/cmd/go/#hdr-The_go_mod_file. 
+// +// The Parse and ParseLax functions both parse a go.mod file and return an +// abstract syntax tree. ParseLax ignores unknown statements and may be used to +// parse go.mod files that may have been developed with newer versions of Go. +// +// The File struct returned by Parse and ParseLax represent an abstract +// go.mod file. File has several methods like AddNewRequire and DropReplace +// that can be used to programmatically edit a file. +// +// The Format function formats a File back to a byte slice which can be +// written to a file. +package modfile + +import ( + "errors" + "fmt" + "path/filepath" + "sort" + "strconv" + "strings" + "unicode" + + "golang.org/x/mod/internal/lazyregexp" + "golang.org/x/mod/module" + "golang.org/x/mod/semver" +) + +// A File is the parsed, interpreted form of a go.mod file. +type File struct { + Module *Module + Go *Go + Require []*Require + Exclude []*Exclude + Replace []*Replace + Retract []*Retract + + Syntax *FileSyntax +} + +// A Module is the module statement. +type Module struct { + Mod module.Version + Deprecated string + Syntax *Line +} + +// A Go is the go statement. +type Go struct { + Version string // "1.23" + Syntax *Line +} + +// An Exclude is a single exclude statement. +type Exclude struct { + Mod module.Version + Syntax *Line +} + +// A Replace is a single replace statement. +type Replace struct { + Old module.Version + New module.Version + Syntax *Line +} + +// A Retract is a single retract statement. +type Retract struct { + VersionInterval + Rationale string + Syntax *Line +} + +// A VersionInterval represents a range of versions with upper and lower bounds. +// Intervals are closed: both bounds are included. When Low is equal to High, +// the interval may refer to a single version ('v1.2.3') or an interval +// ('[v1.2.3, v1.2.3]'); both have the same representation. +type VersionInterval struct { + Low, High string +} + +// A Require is a single require statement. +type Require struct { + Mod module.Version + Indirect bool // has "// indirect" comment + Syntax *Line +} + +func (r *Require) markRemoved() { + r.Syntax.markRemoved() + *r = Require{} +} + +func (r *Require) setVersion(v string) { + r.Mod.Version = v + + if line := r.Syntax; len(line.Token) > 0 { + if line.InBlock { + // If the line is preceded by an empty line, remove it; see + // https://golang.org/issue/33779. + if len(line.Comments.Before) == 1 && len(line.Comments.Before[0].Token) == 0 { + line.Comments.Before = line.Comments.Before[:0] + } + if len(line.Token) >= 2 { // example.com v1.2.3 + line.Token[1] = v + } + } else { + if len(line.Token) >= 3 { // require example.com v1.2.3 + line.Token[2] = v + } + } + } +} + +// setIndirect sets line to have (or not have) a "// indirect" comment. +func (r *Require) setIndirect(indirect bool) { + r.Indirect = indirect + line := r.Syntax + if isIndirect(line) == indirect { + return + } + if indirect { + // Adding comment. + if len(line.Suffix) == 0 { + // New comment. + line.Suffix = []Comment{{Token: "// indirect", Suffix: true}} + return + } + + com := &line.Suffix[0] + text := strings.TrimSpace(strings.TrimPrefix(com.Token, string(slashSlash))) + if text == "" { + // Empty comment. + com.Token = "// indirect" + return + } + + // Insert at beginning of existing comment. + com.Token = "// indirect; " + text + return + } + + // Removing comment. + f := strings.TrimSpace(strings.TrimPrefix(line.Suffix[0].Token, string(slashSlash))) + if f == "indirect" { + // Remove whole comment. 
+ line.Suffix = nil + return + } + + // Remove comment prefix. + com := &line.Suffix[0] + i := strings.Index(com.Token, "indirect;") + com.Token = "//" + com.Token[i+len("indirect;"):] +} + +// isIndirect reports whether line has a "// indirect" comment, +// meaning it is in go.mod only for its effect on indirect dependencies, +// so that it can be dropped entirely once the effective version of the +// indirect dependency reaches the given minimum version. +func isIndirect(line *Line) bool { + if len(line.Suffix) == 0 { + return false + } + f := strings.Fields(strings.TrimPrefix(line.Suffix[0].Token, string(slashSlash))) + return (len(f) == 1 && f[0] == "indirect" || len(f) > 1 && f[0] == "indirect;") +} + +func (f *File) AddModuleStmt(path string) error { + if f.Syntax == nil { + f.Syntax = new(FileSyntax) + } + if f.Module == nil { + f.Module = &Module{ + Mod: module.Version{Path: path}, + Syntax: f.Syntax.addLine(nil, "module", AutoQuote(path)), + } + } else { + f.Module.Mod.Path = path + f.Syntax.updateLine(f.Module.Syntax, "module", AutoQuote(path)) + } + return nil +} + +func (f *File) AddComment(text string) { + if f.Syntax == nil { + f.Syntax = new(FileSyntax) + } + f.Syntax.Stmt = append(f.Syntax.Stmt, &CommentBlock{ + Comments: Comments{ + Before: []Comment{ + { + Token: text, + }, + }, + }, + }) +} + +type VersionFixer func(path, version string) (string, error) + +// errDontFix is returned by a VersionFixer to indicate the version should be +// left alone, even if it's not canonical. +var dontFixRetract VersionFixer = func(_, vers string) (string, error) { + return vers, nil +} + +// Parse parses and returns a go.mod file. +// +// file is the name of the file, used in positions and errors. +// +// data is the content of the file. +// +// fix is an optional function that canonicalizes module versions. +// If fix is nil, all module versions must be canonical (module.CanonicalVersion +// must return the same string). +func Parse(file string, data []byte, fix VersionFixer) (*File, error) { + return parseToFile(file, data, fix, true) +} + +// ParseLax is like Parse but ignores unknown statements. +// It is used when parsing go.mod files other than the main module, +// under the theory that most statement types we add in the future will +// only apply in the main module, like exclude and replace, +// and so we get better gradual deployments if old go commands +// simply ignore those statements when found in go.mod files +// in dependencies. +func ParseLax(file string, data []byte, fix VersionFixer) (*File, error) { + return parseToFile(file, data, fix, false) +} + +func parseToFile(file string, data []byte, fix VersionFixer, strict bool) (parsed *File, err error) { + fs, err := parse(file, data) + if err != nil { + return nil, err + } + f := &File{ + Syntax: fs, + } + var errs ErrorList + + // fix versions in retract directives after the file is parsed. + // We need the module path to fix versions, and it might be at the end. 
+ defer func() { + oldLen := len(errs) + f.fixRetract(fix, &errs) + if len(errs) > oldLen { + parsed, err = nil, errs + } + }() + + for _, x := range fs.Stmt { + switch x := x.(type) { + case *Line: + f.add(&errs, nil, x, x.Token[0], x.Token[1:], fix, strict) + + case *LineBlock: + if len(x.Token) > 1 { + if strict { + errs = append(errs, Error{ + Filename: file, + Pos: x.Start, + Err: fmt.Errorf("unknown block type: %s", strings.Join(x.Token, " ")), + }) + } + continue + } + switch x.Token[0] { + default: + if strict { + errs = append(errs, Error{ + Filename: file, + Pos: x.Start, + Err: fmt.Errorf("unknown block type: %s", strings.Join(x.Token, " ")), + }) + } + continue + case "module", "require", "exclude", "replace", "retract": + for _, l := range x.Line { + f.add(&errs, x, l, x.Token[0], l.Token, fix, strict) + } + } + } + } + + if len(errs) > 0 { + return nil, errs + } + return f, nil +} + +var GoVersionRE = lazyregexp.New(`^([1-9][0-9]*)\.(0|[1-9][0-9]*)$`) +var laxGoVersionRE = lazyregexp.New(`^v?(([1-9][0-9]*)\.(0|[1-9][0-9]*))([^0-9].*)$`) + +func (f *File) add(errs *ErrorList, block *LineBlock, line *Line, verb string, args []string, fix VersionFixer, strict bool) { + // If strict is false, this module is a dependency. + // We ignore all unknown directives as well as main-module-only + // directives like replace and exclude. It will work better for + // forward compatibility if we can depend on modules that have unknown + // statements (presumed relevant only when acting as the main module) + // and simply ignore those statements. + if !strict { + switch verb { + case "go", "module", "retract", "require": + // want these even for dependency go.mods + default: + return + } + } + + wrapModPathError := func(modPath string, err error) { + *errs = append(*errs, Error{ + Filename: f.Syntax.Name, + Pos: line.Start, + ModPath: modPath, + Verb: verb, + Err: err, + }) + } + wrapError := func(err error) { + *errs = append(*errs, Error{ + Filename: f.Syntax.Name, + Pos: line.Start, + Err: err, + }) + } + errorf := func(format string, args ...interface{}) { + wrapError(fmt.Errorf(format, args...)) + } + + switch verb { + default: + errorf("unknown directive: %s", verb) + + case "go": + if f.Go != nil { + errorf("repeated go statement") + return + } + if len(args) != 1 { + errorf("go directive expects exactly one argument") + return + } else if !GoVersionRE.MatchString(args[0]) { + fixed := false + if !strict { + if m := laxGoVersionRE.FindStringSubmatch(args[0]); m != nil { + args[0] = m[1] + fixed = true + } + } + if !fixed { + errorf("invalid go version '%s': must match format 1.23", args[0]) + return + } + } + + f.Go = &Go{Syntax: line} + f.Go.Version = args[0] + + case "module": + if f.Module != nil { + errorf("repeated module statement") + return + } + deprecated := parseDeprecation(block, line) + f.Module = &Module{ + Syntax: line, + Deprecated: deprecated, + } + if len(args) != 1 { + errorf("usage: module module/path") + return + } + s, err := parseString(&args[0]) + if err != nil { + errorf("invalid quoted string: %v", err) + return + } + f.Module.Mod = module.Version{Path: s} + + case "require", "exclude": + if len(args) != 2 { + errorf("usage: %s module/path v1.2.3", verb) + return + } + s, err := parseString(&args[0]) + if err != nil { + errorf("invalid quoted string: %v", err) + return + } + v, err := parseVersion(verb, s, &args[1], fix) + if err != nil { + wrapError(err) + return + } + pathMajor, err := modulePathMajor(s) + if err != nil { + wrapError(err) + return + } + if err 
:= module.CheckPathMajor(v, pathMajor); err != nil { + wrapModPathError(s, err) + return + } + if verb == "require" { + f.Require = append(f.Require, &Require{ + Mod: module.Version{Path: s, Version: v}, + Syntax: line, + Indirect: isIndirect(line), + }) + } else { + f.Exclude = append(f.Exclude, &Exclude{ + Mod: module.Version{Path: s, Version: v}, + Syntax: line, + }) + } + + case "replace": + replace, wrappederr := parseReplace(f.Syntax.Name, line, verb, args, fix) + if wrappederr != nil { + *errs = append(*errs, *wrappederr) + return + } + f.Replace = append(f.Replace, replace) + + case "retract": + rationale := parseDirectiveComment(block, line) + vi, err := parseVersionInterval(verb, "", &args, dontFixRetract) + if err != nil { + if strict { + wrapError(err) + return + } else { + // Only report errors parsing intervals in the main module. We may + // support additional syntax in the future, such as open and half-open + // intervals. Those can't be supported now, because they break the + // go.mod parser, even in lax mode. + return + } + } + if len(args) > 0 && strict { + // In the future, there may be additional information after the version. + errorf("unexpected token after version: %q", args[0]) + return + } + retract := &Retract{ + VersionInterval: vi, + Rationale: rationale, + Syntax: line, + } + f.Retract = append(f.Retract, retract) + } +} + +func parseReplace(filename string, line *Line, verb string, args []string, fix VersionFixer) (*Replace, *Error) { + wrapModPathError := func(modPath string, err error) *Error { + return &Error{ + Filename: filename, + Pos: line.Start, + ModPath: modPath, + Verb: verb, + Err: err, + } + } + wrapError := func(err error) *Error { + return &Error{ + Filename: filename, + Pos: line.Start, + Err: err, + } + } + errorf := func(format string, args ...interface{}) *Error { + return wrapError(fmt.Errorf(format, args...)) + } + + arrow := 2 + if len(args) >= 2 && args[1] == "=>" { + arrow = 1 + } + if len(args) < arrow+2 || len(args) > arrow+3 || args[arrow] != "=>" { + return nil, errorf("usage: %s module/path [v1.2.3] => other/module v1.4\n\t or %s module/path [v1.2.3] => ../local/directory", verb, verb) + } + s, err := parseString(&args[0]) + if err != nil { + return nil, errorf("invalid quoted string: %v", err) + } + pathMajor, err := modulePathMajor(s) + if err != nil { + return nil, wrapModPathError(s, err) + + } + var v string + if arrow == 2 { + v, err = parseVersion(verb, s, &args[1], fix) + if err != nil { + return nil, wrapError(err) + } + if err := module.CheckPathMajor(v, pathMajor); err != nil { + return nil, wrapModPathError(s, err) + } + } + ns, err := parseString(&args[arrow+1]) + if err != nil { + return nil, errorf("invalid quoted string: %v", err) + } + nv := "" + if len(args) == arrow+2 { + if !IsDirectoryPath(ns) { + if strings.Contains(ns, "@") { + return nil, errorf("replacement module must match format 'path version', not 'path@version'") + } + return nil, errorf("replacement module without version must be directory path (rooted or starting with ./ or ../)") + } + if filepath.Separator == '/' && strings.Contains(ns, `\`) { + return nil, errorf("replacement directory appears to be Windows path (on a non-windows system)") + } + } + if len(args) == arrow+3 { + nv, err = parseVersion(verb, ns, &args[arrow+2], fix) + if err != nil { + return nil, wrapError(err) + } + if IsDirectoryPath(ns) { + return nil, errorf("replacement module directory path %q cannot have version", ns) + + } + } + return &Replace{ + Old: module.Version{Path: 
s, Version: v}, + New: module.Version{Path: ns, Version: nv}, + Syntax: line, + }, nil +} + +// fixRetract applies fix to each retract directive in f, appending any errors +// to errs. +// +// Most versions are fixed as we parse the file, but for retract directives, +// the relevant module path is the one specified with the module directive, +// and that might appear at the end of the file (or not at all). +func (f *File) fixRetract(fix VersionFixer, errs *ErrorList) { + if fix == nil { + return + } + path := "" + if f.Module != nil { + path = f.Module.Mod.Path + } + var r *Retract + wrapError := func(err error) { + *errs = append(*errs, Error{ + Filename: f.Syntax.Name, + Pos: r.Syntax.Start, + Err: err, + }) + } + + for _, r = range f.Retract { + if path == "" { + wrapError(errors.New("no module directive found, so retract cannot be used")) + return // only print the first one of these + } + + args := r.Syntax.Token + if args[0] == "retract" { + args = args[1:] + } + vi, err := parseVersionInterval("retract", path, &args, fix) + if err != nil { + wrapError(err) + } + r.VersionInterval = vi + } +} + +func (f *WorkFile) add(errs *ErrorList, line *Line, verb string, args []string, fix VersionFixer) { + wrapError := func(err error) { + *errs = append(*errs, Error{ + Filename: f.Syntax.Name, + Pos: line.Start, + Err: err, + }) + } + errorf := func(format string, args ...interface{}) { + wrapError(fmt.Errorf(format, args...)) + } + + switch verb { + default: + errorf("unknown directive: %s", verb) + + case "go": + if f.Go != nil { + errorf("repeated go statement") + return + } + if len(args) != 1 { + errorf("go directive expects exactly one argument") + return + } else if !GoVersionRE.MatchString(args[0]) { + errorf("invalid go version '%s': must match format 1.23", args[0]) + return + } + + f.Go = &Go{Syntax: line} + f.Go.Version = args[0] + + case "use": + if len(args) != 1 { + errorf("usage: %s local/dir", verb) + return + } + s, err := parseString(&args[0]) + if err != nil { + errorf("invalid quoted string: %v", err) + return + } + f.Use = append(f.Use, &Use{ + Path: s, + Syntax: line, + }) + + case "replace": + replace, wrappederr := parseReplace(f.Syntax.Name, line, verb, args, fix) + if wrappederr != nil { + *errs = append(*errs, *wrappederr) + return + } + f.Replace = append(f.Replace, replace) + } +} + +// IsDirectoryPath reports whether the given path should be interpreted +// as a directory path. Just like on the go command line, relative paths +// and rooted paths are directory paths; the rest are module paths. +func IsDirectoryPath(ns string) bool { + // Because go.mod files can move from one system to another, + // we check all known path syntaxes, both Unix and Windows. + return strings.HasPrefix(ns, "./") || strings.HasPrefix(ns, "../") || strings.HasPrefix(ns, "/") || + strings.HasPrefix(ns, `.\`) || strings.HasPrefix(ns, `..\`) || strings.HasPrefix(ns, `\`) || + len(ns) >= 2 && ('A' <= ns[0] && ns[0] <= 'Z' || 'a' <= ns[0] && ns[0] <= 'z') && ns[1] == ':' +} + +// MustQuote reports whether s must be quoted in order to appear as +// a single token in a go.mod line. 
+func MustQuote(s string) bool { + for _, r := range s { + switch r { + case ' ', '"', '\'', '`': + return true + + case '(', ')', '[', ']', '{', '}', ',': + if len(s) > 1 { + return true + } + + default: + if !unicode.IsPrint(r) { + return true + } + } + } + return s == "" || strings.Contains(s, "//") || strings.Contains(s, "/*") +} + +// AutoQuote returns s or, if quoting is required for s to appear in a go.mod, +// the quotation of s. +func AutoQuote(s string) string { + if MustQuote(s) { + return strconv.Quote(s) + } + return s +} + +func parseVersionInterval(verb string, path string, args *[]string, fix VersionFixer) (VersionInterval, error) { + toks := *args + if len(toks) == 0 || toks[0] == "(" { + return VersionInterval{}, fmt.Errorf("expected '[' or version") + } + if toks[0] != "[" { + v, err := parseVersion(verb, path, &toks[0], fix) + if err != nil { + return VersionInterval{}, err + } + *args = toks[1:] + return VersionInterval{Low: v, High: v}, nil + } + toks = toks[1:] + + if len(toks) == 0 { + return VersionInterval{}, fmt.Errorf("expected version after '['") + } + low, err := parseVersion(verb, path, &toks[0], fix) + if err != nil { + return VersionInterval{}, err + } + toks = toks[1:] + + if len(toks) == 0 || toks[0] != "," { + return VersionInterval{}, fmt.Errorf("expected ',' after version") + } + toks = toks[1:] + + if len(toks) == 0 { + return VersionInterval{}, fmt.Errorf("expected version after ','") + } + high, err := parseVersion(verb, path, &toks[0], fix) + if err != nil { + return VersionInterval{}, err + } + toks = toks[1:] + + if len(toks) == 0 || toks[0] != "]" { + return VersionInterval{}, fmt.Errorf("expected ']' after version") + } + toks = toks[1:] + + *args = toks + return VersionInterval{Low: low, High: high}, nil +} + +func parseString(s *string) (string, error) { + t := *s + if strings.HasPrefix(t, `"`) { + var err error + if t, err = strconv.Unquote(t); err != nil { + return "", err + } + } else if strings.ContainsAny(t, "\"'`") { + // Other quotes are reserved both for possible future expansion + // and to avoid confusion. For example if someone types 'x' + // we want that to be a syntax error and not a literal x in literal quotation marks. + return "", fmt.Errorf("unquoted string cannot contain quote") + } + *s = AutoQuote(t) + return t, nil +} + +var deprecatedRE = lazyregexp.New(`(?s)(?:^|\n\n)Deprecated: *(.*?)(?:$|\n\n)`) + +// parseDeprecation extracts the text of comments on a "module" directive and +// extracts a deprecation message from that. +// +// A deprecation message is contained in a paragraph within a block of comments +// that starts with "Deprecated:" (case sensitive). The message runs until the +// end of the paragraph and does not include the "Deprecated:" prefix. If the +// comment block has multiple paragraphs that start with "Deprecated:", +// parseDeprecation returns the message from the first. +func parseDeprecation(block *LineBlock, line *Line) string { + text := parseDirectiveComment(block, line) + m := deprecatedRE.FindStringSubmatch(text) + if m == nil { + return "" + } + return m[1] +} + +// parseDirectiveComment extracts the text of comments on a directive. +// If the directive's line does not have comments and is part of a block that +// does have comments, the block's comments are used. 
+func parseDirectiveComment(block *LineBlock, line *Line) string { + comments := line.Comment() + if block != nil && len(comments.Before) == 0 && len(comments.Suffix) == 0 { + comments = block.Comment() + } + groups := [][]Comment{comments.Before, comments.Suffix} + var lines []string + for _, g := range groups { + for _, c := range g { + if !strings.HasPrefix(c.Token, "//") { + continue // blank line + } + lines = append(lines, strings.TrimSpace(strings.TrimPrefix(c.Token, "//"))) + } + } + return strings.Join(lines, "\n") +} + +type ErrorList []Error + +func (e ErrorList) Error() string { + errStrs := make([]string, len(e)) + for i, err := range e { + errStrs[i] = err.Error() + } + return strings.Join(errStrs, "\n") +} + +type Error struct { + Filename string + Pos Position + Verb string + ModPath string + Err error +} + +func (e *Error) Error() string { + var pos string + if e.Pos.LineRune > 1 { + // Don't print LineRune if it's 1 (beginning of line). + // It's always 1 except in scanner errors, which are rare. + pos = fmt.Sprintf("%s:%d:%d: ", e.Filename, e.Pos.Line, e.Pos.LineRune) + } else if e.Pos.Line > 0 { + pos = fmt.Sprintf("%s:%d: ", e.Filename, e.Pos.Line) + } else if e.Filename != "" { + pos = fmt.Sprintf("%s: ", e.Filename) + } + + var directive string + if e.ModPath != "" { + directive = fmt.Sprintf("%s %s: ", e.Verb, e.ModPath) + } else if e.Verb != "" { + directive = fmt.Sprintf("%s: ", e.Verb) + } + + return pos + directive + e.Err.Error() +} + +func (e *Error) Unwrap() error { return e.Err } + +func parseVersion(verb string, path string, s *string, fix VersionFixer) (string, error) { + t, err := parseString(s) + if err != nil { + return "", &Error{ + Verb: verb, + ModPath: path, + Err: &module.InvalidVersionError{ + Version: *s, + Err: err, + }, + } + } + if fix != nil { + fixed, err := fix(path, t) + if err != nil { + if err, ok := err.(*module.ModuleError); ok { + return "", &Error{ + Verb: verb, + ModPath: path, + Err: err.Err, + } + } + return "", err + } + t = fixed + } else { + cv := module.CanonicalVersion(t) + if cv == "" { + return "", &Error{ + Verb: verb, + ModPath: path, + Err: &module.InvalidVersionError{ + Version: t, + Err: errors.New("must be of the form v1.2.3"), + }, + } + } + t = cv + } + *s = t + return *s, nil +} + +func modulePathMajor(path string) (string, error) { + _, major, ok := module.SplitPathVersion(path) + if !ok { + return "", fmt.Errorf("invalid module path") + } + return major, nil +} + +func (f *File) Format() ([]byte, error) { + return Format(f.Syntax), nil +} + +// Cleanup cleans up the file f after any edit operations. +// To avoid quadratic behavior, modifications like DropRequire +// clear the entry but do not remove it from the slice. +// Cleanup cleans out all the cleared entries. 
+func (f *File) Cleanup() { + w := 0 + for _, r := range f.Require { + if r.Mod.Path != "" { + f.Require[w] = r + w++ + } + } + f.Require = f.Require[:w] + + w = 0 + for _, x := range f.Exclude { + if x.Mod.Path != "" { + f.Exclude[w] = x + w++ + } + } + f.Exclude = f.Exclude[:w] + + w = 0 + for _, r := range f.Replace { + if r.Old.Path != "" { + f.Replace[w] = r + w++ + } + } + f.Replace = f.Replace[:w] + + w = 0 + for _, r := range f.Retract { + if r.Low != "" || r.High != "" { + f.Retract[w] = r + w++ + } + } + f.Retract = f.Retract[:w] + + f.Syntax.Cleanup() +} + +func (f *File) AddGoStmt(version string) error { + if !GoVersionRE.MatchString(version) { + return fmt.Errorf("invalid language version string %q", version) + } + if f.Go == nil { + var hint Expr + if f.Module != nil && f.Module.Syntax != nil { + hint = f.Module.Syntax + } + f.Go = &Go{ + Version: version, + Syntax: f.Syntax.addLine(hint, "go", version), + } + } else { + f.Go.Version = version + f.Syntax.updateLine(f.Go.Syntax, "go", version) + } + return nil +} + +// AddRequire sets the first require line for path to version vers, +// preserving any existing comments for that line and removing all +// other lines for path. +// +// If no line currently exists for path, AddRequire adds a new line +// at the end of the last require block. +func (f *File) AddRequire(path, vers string) error { + need := true + for _, r := range f.Require { + if r.Mod.Path == path { + if need { + r.Mod.Version = vers + f.Syntax.updateLine(r.Syntax, "require", AutoQuote(path), vers) + need = false + } else { + r.Syntax.markRemoved() + *r = Require{} + } + } + } + + if need { + f.AddNewRequire(path, vers, false) + } + return nil +} + +// AddNewRequire adds a new require line for path at version vers at the end of +// the last require block, regardless of any existing require lines for path. +func (f *File) AddNewRequire(path, vers string, indirect bool) { + line := f.Syntax.addLine(nil, "require", AutoQuote(path), vers) + r := &Require{ + Mod: module.Version{Path: path, Version: vers}, + Syntax: line, + } + r.setIndirect(indirect) + f.Require = append(f.Require, r) +} + +// SetRequire updates the requirements of f to contain exactly req, preserving +// the existing block structure and line comment contents (except for 'indirect' +// markings) for the first requirement on each named module path. +// +// The Syntax field is ignored for the requirements in req. +// +// Any requirements not already present in the file are added to the block +// containing the last require line. +// +// The requirements in req must specify at most one distinct version for each +// module path. +// +// If any existing requirements may be removed, the caller should call Cleanup +// after all edits are complete. +func (f *File) SetRequire(req []*Require) { + type elem struct { + version string + indirect bool + } + need := make(map[string]elem) + for _, r := range req { + if prev, dup := need[r.Mod.Path]; dup && prev.version != r.Mod.Version { + panic(fmt.Errorf("SetRequire called with conflicting versions for path %s (%s and %s)", r.Mod.Path, prev.version, r.Mod.Version)) + } + need[r.Mod.Path] = elem{r.Mod.Version, r.Indirect} + } + + // Update or delete the existing Require entries to preserve + // only the first for each module path in req. 
+ for _, r := range f.Require { + e, ok := need[r.Mod.Path] + if ok { + r.setVersion(e.version) + r.setIndirect(e.indirect) + } else { + r.markRemoved() + } + delete(need, r.Mod.Path) + } + + // Add new entries in the last block of the file for any paths that weren't + // already present. + // + // This step is nondeterministic, but the final result will be deterministic + // because we will sort the block. + for path, e := range need { + f.AddNewRequire(path, e.version, e.indirect) + } + + f.SortBlocks() +} + +// SetRequireSeparateIndirect updates the requirements of f to contain the given +// requirements. Comment contents (except for 'indirect' markings) are retained +// from the first existing requirement for each module path. Like SetRequire, +// SetRequireSeparateIndirect adds requirements for new paths in req, +// updates the version and "// indirect" comment on existing requirements, +// and deletes requirements on paths not in req. Existing duplicate requirements +// are deleted. +// +// As its name suggests, SetRequireSeparateIndirect puts direct and indirect +// requirements into two separate blocks, one containing only direct +// requirements, and the other containing only indirect requirements. +// SetRequireSeparateIndirect may move requirements between these two blocks +// when their indirect markings change. However, SetRequireSeparateIndirect +// won't move requirements from other blocks, especially blocks with comments. +// +// If the file initially has one uncommented block of requirements, +// SetRequireSeparateIndirect will split it into a direct-only and indirect-only +// block. This aids in the transition to separate blocks. +func (f *File) SetRequireSeparateIndirect(req []*Require) { + // hasComments returns whether a line or block has comments + // other than "indirect". + hasComments := func(c Comments) bool { + return len(c.Before) > 0 || len(c.After) > 0 || len(c.Suffix) > 1 || + (len(c.Suffix) == 1 && + strings.TrimSpace(strings.TrimPrefix(c.Suffix[0].Token, string(slashSlash))) != "indirect") + } + + // moveReq adds r to block. If r was in another block, moveReq deletes + // it from that block and transfers its comments. + moveReq := func(r *Require, block *LineBlock) { + var line *Line + if r.Syntax == nil { + line = &Line{Token: []string{AutoQuote(r.Mod.Path), r.Mod.Version}} + r.Syntax = line + if r.Indirect { + r.setIndirect(true) + } + } else { + line = new(Line) + *line = *r.Syntax + if !line.InBlock && len(line.Token) > 0 && line.Token[0] == "require" { + line.Token = line.Token[1:] + } + r.Syntax.Token = nil // Cleanup will delete the old line. + r.Syntax = line + } + line.InBlock = true + block.Line = append(block.Line, line) + } + + // Examine existing require lines and blocks. + var ( + // We may insert new requirements into the last uncommented + // direct-only and indirect-only blocks. We may also move requirements + // to the opposite block if their indirect markings change. + lastDirectIndex = -1 + lastIndirectIndex = -1 + + // If there are no direct-only or indirect-only blocks, a new block may + // be inserted after the last require line or block. + lastRequireIndex = -1 + + // If there's only one require line or block, and it's uncommented, + // we'll move its requirements to the direct-only or indirect-only blocks. + requireLineOrBlockCount = 0 + + // Track the block each requirement belongs to (if any) so we can + // move them later. 
+ lineToBlock = make(map[*Line]*LineBlock) + ) + for i, stmt := range f.Syntax.Stmt { + switch stmt := stmt.(type) { + case *Line: + if len(stmt.Token) == 0 || stmt.Token[0] != "require" { + continue + } + lastRequireIndex = i + requireLineOrBlockCount++ + if !hasComments(stmt.Comments) { + if isIndirect(stmt) { + lastIndirectIndex = i + } else { + lastDirectIndex = i + } + } + + case *LineBlock: + if len(stmt.Token) == 0 || stmt.Token[0] != "require" { + continue + } + lastRequireIndex = i + requireLineOrBlockCount++ + allDirect := len(stmt.Line) > 0 && !hasComments(stmt.Comments) + allIndirect := len(stmt.Line) > 0 && !hasComments(stmt.Comments) + for _, line := range stmt.Line { + lineToBlock[line] = stmt + if hasComments(line.Comments) { + allDirect = false + allIndirect = false + } else if isIndirect(line) { + allDirect = false + } else { + allIndirect = false + } + } + if allDirect { + lastDirectIndex = i + } + if allIndirect { + lastIndirectIndex = i + } + } + } + + oneFlatUncommentedBlock := requireLineOrBlockCount == 1 && + !hasComments(*f.Syntax.Stmt[lastRequireIndex].Comment()) + + // Create direct and indirect blocks if needed. Convert lines into blocks + // if needed. If we end up with an empty block or a one-line block, + // Cleanup will delete it or convert it to a line later. + insertBlock := func(i int) *LineBlock { + block := &LineBlock{Token: []string{"require"}} + f.Syntax.Stmt = append(f.Syntax.Stmt, nil) + copy(f.Syntax.Stmt[i+1:], f.Syntax.Stmt[i:]) + f.Syntax.Stmt[i] = block + return block + } + + ensureBlock := func(i int) *LineBlock { + switch stmt := f.Syntax.Stmt[i].(type) { + case *LineBlock: + return stmt + case *Line: + block := &LineBlock{ + Token: []string{"require"}, + Line: []*Line{stmt}, + } + stmt.Token = stmt.Token[1:] // remove "require" + stmt.InBlock = true + f.Syntax.Stmt[i] = block + return block + default: + panic(fmt.Sprintf("unexpected statement: %v", stmt)) + } + } + + var lastDirectBlock *LineBlock + if lastDirectIndex < 0 { + if lastIndirectIndex >= 0 { + lastDirectIndex = lastIndirectIndex + lastIndirectIndex++ + } else if lastRequireIndex >= 0 { + lastDirectIndex = lastRequireIndex + 1 + } else { + lastDirectIndex = len(f.Syntax.Stmt) + } + lastDirectBlock = insertBlock(lastDirectIndex) + } else { + lastDirectBlock = ensureBlock(lastDirectIndex) + } + + var lastIndirectBlock *LineBlock + if lastIndirectIndex < 0 { + lastIndirectIndex = lastDirectIndex + 1 + lastIndirectBlock = insertBlock(lastIndirectIndex) + } else { + lastIndirectBlock = ensureBlock(lastIndirectIndex) + } + + // Delete requirements we don't want anymore. + // Update versions and indirect comments on requirements we want to keep. + // If a requirement is in last{Direct,Indirect}Block with the wrong + // indirect marking after this, or if the requirement is in an single + // uncommented mixed block (oneFlatUncommentedBlock), move it to the + // correct block. + // + // Some blocks may be empty after this. Cleanup will remove them. + need := make(map[string]*Require) + for _, r := range req { + need[r.Mod.Path] = r + } + have := make(map[string]*Require) + for _, r := range f.Require { + path := r.Mod.Path + if need[path] == nil || have[path] != nil { + // Requirement not needed, or duplicate requirement. Delete. 
+ r.markRemoved() + continue + } + have[r.Mod.Path] = r + r.setVersion(need[path].Mod.Version) + r.setIndirect(need[path].Indirect) + if need[path].Indirect && + (oneFlatUncommentedBlock || lineToBlock[r.Syntax] == lastDirectBlock) { + moveReq(r, lastIndirectBlock) + } else if !need[path].Indirect && + (oneFlatUncommentedBlock || lineToBlock[r.Syntax] == lastIndirectBlock) { + moveReq(r, lastDirectBlock) + } + } + + // Add new requirements. + for path, r := range need { + if have[path] == nil { + if r.Indirect { + moveReq(r, lastIndirectBlock) + } else { + moveReq(r, lastDirectBlock) + } + f.Require = append(f.Require, r) + } + } + + f.SortBlocks() +} + +func (f *File) DropRequire(path string) error { + for _, r := range f.Require { + if r.Mod.Path == path { + r.Syntax.markRemoved() + *r = Require{} + } + } + return nil +} + +// AddExclude adds a exclude statement to the mod file. Errors if the provided +// version is not a canonical version string +func (f *File) AddExclude(path, vers string) error { + if err := checkCanonicalVersion(path, vers); err != nil { + return err + } + + var hint *Line + for _, x := range f.Exclude { + if x.Mod.Path == path && x.Mod.Version == vers { + return nil + } + if x.Mod.Path == path { + hint = x.Syntax + } + } + + f.Exclude = append(f.Exclude, &Exclude{Mod: module.Version{Path: path, Version: vers}, Syntax: f.Syntax.addLine(hint, "exclude", AutoQuote(path), vers)}) + return nil +} + +func (f *File) DropExclude(path, vers string) error { + for _, x := range f.Exclude { + if x.Mod.Path == path && x.Mod.Version == vers { + x.Syntax.markRemoved() + *x = Exclude{} + } + } + return nil +} + +func (f *File) AddReplace(oldPath, oldVers, newPath, newVers string) error { + return addReplace(f.Syntax, &f.Replace, oldPath, oldVers, newPath, newVers) +} + +func addReplace(syntax *FileSyntax, replace *[]*Replace, oldPath, oldVers, newPath, newVers string) error { + need := true + old := module.Version{Path: oldPath, Version: oldVers} + new := module.Version{Path: newPath, Version: newVers} + tokens := []string{"replace", AutoQuote(oldPath)} + if oldVers != "" { + tokens = append(tokens, oldVers) + } + tokens = append(tokens, "=>", AutoQuote(newPath)) + if newVers != "" { + tokens = append(tokens, newVers) + } + + var hint *Line + for _, r := range *replace { + if r.Old.Path == oldPath && (oldVers == "" || r.Old.Version == oldVers) { + if need { + // Found replacement for old; update to use new. + r.New = new + syntax.updateLine(r.Syntax, tokens...) + need = false + continue + } + // Already added; delete other replacements for same. + r.Syntax.markRemoved() + *r = Replace{} + } + if r.Old.Path == oldPath { + hint = r.Syntax + } + } + if need { + *replace = append(*replace, &Replace{Old: old, New: new, Syntax: syntax.addLine(hint, tokens...)}) + } + return nil +} + +func (f *File) DropReplace(oldPath, oldVers string) error { + for _, r := range f.Replace { + if r.Old.Path == oldPath && r.Old.Version == oldVers { + r.Syntax.markRemoved() + *r = Replace{} + } + } + return nil +} + +// AddRetract adds a retract statement to the mod file. 
Errors if the provided +// version interval does not consist of canonical version strings +func (f *File) AddRetract(vi VersionInterval, rationale string) error { + var path string + if f.Module != nil { + path = f.Module.Mod.Path + } + if err := checkCanonicalVersion(path, vi.High); err != nil { + return err + } + if err := checkCanonicalVersion(path, vi.Low); err != nil { + return err + } + + r := &Retract{ + VersionInterval: vi, + } + if vi.Low == vi.High { + r.Syntax = f.Syntax.addLine(nil, "retract", AutoQuote(vi.Low)) + } else { + r.Syntax = f.Syntax.addLine(nil, "retract", "[", AutoQuote(vi.Low), ",", AutoQuote(vi.High), "]") + } + if rationale != "" { + for _, line := range strings.Split(rationale, "\n") { + com := Comment{Token: "// " + line} + r.Syntax.Comment().Before = append(r.Syntax.Comment().Before, com) + } + } + return nil +} + +func (f *File) DropRetract(vi VersionInterval) error { + for _, r := range f.Retract { + if r.VersionInterval == vi { + r.Syntax.markRemoved() + *r = Retract{} + } + } + return nil +} + +func (f *File) SortBlocks() { + f.removeDups() // otherwise sorting is unsafe + + for _, stmt := range f.Syntax.Stmt { + block, ok := stmt.(*LineBlock) + if !ok { + continue + } + less := lineLess + if block.Token[0] == "retract" { + less = lineRetractLess + } + sort.SliceStable(block.Line, func(i, j int) bool { + return less(block.Line[i], block.Line[j]) + }) + } +} + +// removeDups removes duplicate exclude and replace directives. +// +// Earlier exclude directives take priority. +// +// Later replace directives take priority. +// +// require directives are not de-duplicated. That's left up to higher-level +// logic (MVS). +// +// retract directives are not de-duplicated since comments are +// meaningful, and versions may be retracted multiple times. +func (f *File) removeDups() { + removeDups(f.Syntax, &f.Exclude, &f.Replace) +} + +func removeDups(syntax *FileSyntax, exclude *[]*Exclude, replace *[]*Replace) { + kill := make(map[*Line]bool) + + // Remove duplicate excludes. + if exclude != nil { + haveExclude := make(map[module.Version]bool) + for _, x := range *exclude { + if haveExclude[x.Mod] { + kill[x.Syntax] = true + continue + } + haveExclude[x.Mod] = true + } + var excl []*Exclude + for _, x := range *exclude { + if !kill[x.Syntax] { + excl = append(excl, x) + } + } + *exclude = excl + } + + // Remove duplicate replacements. + // Later replacements take priority over earlier ones. + haveReplace := make(map[module.Version]bool) + for i := len(*replace) - 1; i >= 0; i-- { + x := (*replace)[i] + if haveReplace[x.Old] { + kill[x.Syntax] = true + continue + } + haveReplace[x.Old] = true + } + var repl []*Replace + for _, x := range *replace { + if !kill[x.Syntax] { + repl = append(repl, x) + } + } + *replace = repl + + // Duplicate require and retract directives are not removed. + + // Drop killed statements from the syntax tree. + var stmts []Expr + for _, stmt := range syntax.Stmt { + switch stmt := stmt.(type) { + case *Line: + if kill[stmt] { + continue + } + case *LineBlock: + var lines []*Line + for _, line := range stmt.Line { + if !kill[line] { + lines = append(lines, line) + } + } + stmt.Line = lines + if len(lines) == 0 { + continue + } + } + stmts = append(stmts, stmt) + } + syntax.Stmt = stmts +} + +// lineLess returns whether li should be sorted before lj. It sorts +// lexicographically without assigning any special meaning to tokens. 
+func lineLess(li, lj *Line) bool { + for k := 0; k < len(li.Token) && k < len(lj.Token); k++ { + if li.Token[k] != lj.Token[k] { + return li.Token[k] < lj.Token[k] + } + } + return len(li.Token) < len(lj.Token) +} + +// lineRetractLess returns whether li should be sorted before lj for lines in +// a "retract" block. It treats each line as a version interval. Single versions +// are compared as if they were intervals with the same low and high version. +// Intervals are sorted in descending order, first by low version, then by +// high version, using semver.Compare. +func lineRetractLess(li, lj *Line) bool { + interval := func(l *Line) VersionInterval { + if len(l.Token) == 1 { + return VersionInterval{Low: l.Token[0], High: l.Token[0]} + } else if len(l.Token) == 5 && l.Token[0] == "[" && l.Token[2] == "," && l.Token[4] == "]" { + return VersionInterval{Low: l.Token[1], High: l.Token[3]} + } else { + // Line in unknown format. Treat as an invalid version. + return VersionInterval{} + } + } + vii := interval(li) + vij := interval(lj) + if cmp := semver.Compare(vii.Low, vij.Low); cmp != 0 { + return cmp > 0 + } + return semver.Compare(vii.High, vij.High) > 0 +} + +// checkCanonicalVersion returns a non-nil error if vers is not a canonical +// version string or does not match the major version of path. +// +// If path is non-empty, the error text suggests a format with a major version +// corresponding to the path. +func checkCanonicalVersion(path, vers string) error { + _, pathMajor, pathMajorOk := module.SplitPathVersion(path) + + if vers == "" || vers != module.CanonicalVersion(vers) { + if pathMajor == "" { + return &module.InvalidVersionError{ + Version: vers, + Err: fmt.Errorf("must be of the form v1.2.3"), + } + } + return &module.InvalidVersionError{ + Version: vers, + Err: fmt.Errorf("must be of the form %s.2.3", module.PathMajorPrefix(pathMajor)), + } + } + + if pathMajorOk { + if err := module.CheckPathMajor(vers, pathMajor); err != nil { + if pathMajor == "" { + // In this context, the user probably wrote "v2.3.4" when they meant + // "v2.3.4+incompatible". Suggest that instead of "v0 or v1". + return &module.InvalidVersionError{ + Version: vers, + Err: fmt.Errorf("should be %s+incompatible (or module %s/%v)", vers, path, semver.Major(vers)), + } + } + return err + } + } + + return nil +} diff --git a/vendor/golang.org/x/mod/modfile/work.go b/vendor/golang.org/x/mod/modfile/work.go new file mode 100644 index 00000000000..0c0e521525a --- /dev/null +++ b/vendor/golang.org/x/mod/modfile/work.go @@ -0,0 +1,234 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package modfile + +import ( + "fmt" + "sort" + "strings" +) + +// A WorkFile is the parsed, interpreted form of a go.work file. +type WorkFile struct { + Go *Go + Use []*Use + Replace []*Replace + + Syntax *FileSyntax +} + +// A Use is a single directory statement. +type Use struct { + Path string // Use path of module. + ModulePath string // Module path in the comment. + Syntax *Line +} + +// ParseWork parses and returns a go.work file. +// +// file is the name of the file, used in positions and errors. +// +// data is the content of the file. +// +// fix is an optional function that canonicalizes module versions. +// If fix is nil, all module versions must be canonical (module.CanonicalVersion +// must return the same string). 
+func ParseWork(file string, data []byte, fix VersionFixer) (*WorkFile, error) { + fs, err := parse(file, data) + if err != nil { + return nil, err + } + f := &WorkFile{ + Syntax: fs, + } + var errs ErrorList + + for _, x := range fs.Stmt { + switch x := x.(type) { + case *Line: + f.add(&errs, x, x.Token[0], x.Token[1:], fix) + + case *LineBlock: + if len(x.Token) > 1 { + errs = append(errs, Error{ + Filename: file, + Pos: x.Start, + Err: fmt.Errorf("unknown block type: %s", strings.Join(x.Token, " ")), + }) + continue + } + switch x.Token[0] { + default: + errs = append(errs, Error{ + Filename: file, + Pos: x.Start, + Err: fmt.Errorf("unknown block type: %s", strings.Join(x.Token, " ")), + }) + continue + case "use", "replace": + for _, l := range x.Line { + f.add(&errs, l, x.Token[0], l.Token, fix) + } + } + } + } + + if len(errs) > 0 { + return nil, errs + } + return f, nil +} + +// Cleanup cleans up the file f after any edit operations. +// To avoid quadratic behavior, modifications like DropRequire +// clear the entry but do not remove it from the slice. +// Cleanup cleans out all the cleared entries. +func (f *WorkFile) Cleanup() { + w := 0 + for _, r := range f.Use { + if r.Path != "" { + f.Use[w] = r + w++ + } + } + f.Use = f.Use[:w] + + w = 0 + for _, r := range f.Replace { + if r.Old.Path != "" { + f.Replace[w] = r + w++ + } + } + f.Replace = f.Replace[:w] + + f.Syntax.Cleanup() +} + +func (f *WorkFile) AddGoStmt(version string) error { + if !GoVersionRE.MatchString(version) { + return fmt.Errorf("invalid language version string %q", version) + } + if f.Go == nil { + stmt := &Line{Token: []string{"go", version}} + f.Go = &Go{ + Version: version, + Syntax: stmt, + } + // Find the first non-comment-only block that's and add + // the go statement before it. That will keep file comments at the top. + i := 0 + for i = 0; i < len(f.Syntax.Stmt); i++ { + if _, ok := f.Syntax.Stmt[i].(*CommentBlock); !ok { + break + } + } + f.Syntax.Stmt = append(append(f.Syntax.Stmt[:i:i], stmt), f.Syntax.Stmt[i:]...) + } else { + f.Go.Version = version + f.Syntax.updateLine(f.Go.Syntax, "go", version) + } + return nil +} + +func (f *WorkFile) AddUse(diskPath, modulePath string) error { + need := true + for _, d := range f.Use { + if d.Path == diskPath { + if need { + d.ModulePath = modulePath + f.Syntax.updateLine(d.Syntax, "use", AutoQuote(diskPath)) + need = false + } else { + d.Syntax.markRemoved() + *d = Use{} + } + } + } + + if need { + f.AddNewUse(diskPath, modulePath) + } + return nil +} + +func (f *WorkFile) AddNewUse(diskPath, modulePath string) { + line := f.Syntax.addLine(nil, "use", AutoQuote(diskPath)) + f.Use = append(f.Use, &Use{Path: diskPath, ModulePath: modulePath, Syntax: line}) +} + +func (f *WorkFile) SetUse(dirs []*Use) { + need := make(map[string]string) + for _, d := range dirs { + need[d.Path] = d.ModulePath + } + + for _, d := range f.Use { + if modulePath, ok := need[d.Path]; ok { + d.ModulePath = modulePath + } else { + d.Syntax.markRemoved() + *d = Use{} + } + } + + // TODO(#45713): Add module path to comment. 
+ + for diskPath, modulePath := range need { + f.AddNewUse(diskPath, modulePath) + } + f.SortBlocks() +} + +func (f *WorkFile) DropUse(path string) error { + for _, d := range f.Use { + if d.Path == path { + d.Syntax.markRemoved() + *d = Use{} + } + } + return nil +} + +func (f *WorkFile) AddReplace(oldPath, oldVers, newPath, newVers string) error { + return addReplace(f.Syntax, &f.Replace, oldPath, oldVers, newPath, newVers) +} + +func (f *WorkFile) DropReplace(oldPath, oldVers string) error { + for _, r := range f.Replace { + if r.Old.Path == oldPath && r.Old.Version == oldVers { + r.Syntax.markRemoved() + *r = Replace{} + } + } + return nil +} + +func (f *WorkFile) SortBlocks() { + f.removeDups() // otherwise sorting is unsafe + + for _, stmt := range f.Syntax.Stmt { + block, ok := stmt.(*LineBlock) + if !ok { + continue + } + sort.SliceStable(block.Line, func(i, j int) bool { + return lineLess(block.Line[i], block.Line[j]) + }) + } +} + +// removeDups removes duplicate replace directives. +// +// Later replace directives take priority. +// +// require directives are not de-duplicated. That's left up to higher-level +// logic (MVS). +// +// retract directives are not de-duplicated since comments are +// meaningful, and versions may be retracted multiple times. +func (f *WorkFile) removeDups() { + removeDups(f.Syntax, nil, &f.Replace) +} diff --git a/vendor/gopkg.in/tomb.v1/LICENSE b/vendor/gopkg.in/tomb.v1/LICENSE deleted file mode 100644 index a4249bb31dd..00000000000 --- a/vendor/gopkg.in/tomb.v1/LICENSE +++ /dev/null @@ -1,29 +0,0 @@ -tomb - support for clean goroutine termination in Go. - -Copyright (c) 2010-2011 - Gustavo Niemeyer - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - * Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/tomb.v1/README.md b/vendor/gopkg.in/tomb.v1/README.md deleted file mode 100644 index 3ae8788e813..00000000000 --- a/vendor/gopkg.in/tomb.v1/README.md +++ /dev/null @@ -1,4 +0,0 @@ -Installation and usage ----------------------- - -See [gopkg.in/tomb.v1](https://gopkg.in/tomb.v1) for documentation and usage details. 
diff --git a/vendor/gopkg.in/tomb.v1/tomb.go b/vendor/gopkg.in/tomb.v1/tomb.go deleted file mode 100644 index 9aec56d821d..00000000000 --- a/vendor/gopkg.in/tomb.v1/tomb.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright (c) 2011 - Gustavo Niemeyer -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// * Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// * Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -// The tomb package offers a conventional API for clean goroutine termination. -// -// A Tomb tracks the lifecycle of a goroutine as alive, dying or dead, -// and the reason for its death. -// -// The zero value of a Tomb assumes that a goroutine is about to be -// created or already alive. Once Kill or Killf is called with an -// argument that informs the reason for death, the goroutine is in -// a dying state and is expected to terminate soon. Right before the -// goroutine function or method returns, Done must be called to inform -// that the goroutine is indeed dead and about to stop running. -// -// A Tomb exposes Dying and Dead channels. These channels are closed -// when the Tomb state changes in the respective way. They enable -// explicit blocking until the state changes, and also to selectively -// unblock select statements accordingly. -// -// When the tomb state changes to dying and there's still logic going -// on within the goroutine, nested functions and methods may choose to -// return ErrDying as their error value, as this error won't alter the -// tomb state if provided to the Kill method. This is a convenient way to -// follow standard Go practices in the context of a dying tomb. -// -// For background and a detailed example, see the following blog post: -// -// http://blog.labix.org/2011/10/09/death-of-goroutines-under-control -// -// For a more complex code snippet demonstrating the use of multiple -// goroutines with a single Tomb, see: -// -// http://play.golang.org/p/Xh7qWsDPZP -// -package tomb - -import ( - "errors" - "fmt" - "sync" -) - -// A Tomb tracks the lifecycle of a goroutine as alive, dying or dead, -// and the reason for its death. -// -// See the package documentation for details. 
-type Tomb struct { - m sync.Mutex - dying chan struct{} - dead chan struct{} - reason error -} - -var ( - ErrStillAlive = errors.New("tomb: still alive") - ErrDying = errors.New("tomb: dying") -) - -func (t *Tomb) init() { - t.m.Lock() - if t.dead == nil { - t.dead = make(chan struct{}) - t.dying = make(chan struct{}) - t.reason = ErrStillAlive - } - t.m.Unlock() -} - -// Dead returns the channel that can be used to wait -// until t.Done has been called. -func (t *Tomb) Dead() <-chan struct{} { - t.init() - return t.dead -} - -// Dying returns the channel that can be used to wait -// until t.Kill or t.Done has been called. -func (t *Tomb) Dying() <-chan struct{} { - t.init() - return t.dying -} - -// Wait blocks until the goroutine is in a dead state and returns the -// reason for its death. -func (t *Tomb) Wait() error { - t.init() - <-t.dead - t.m.Lock() - reason := t.reason - t.m.Unlock() - return reason -} - -// Done flags the goroutine as dead, and should be called a single time -// right before the goroutine function or method returns. -// If the goroutine was not already in a dying state before Done is -// called, it will be flagged as dying and dead at once with no -// error. -func (t *Tomb) Done() { - t.Kill(nil) - close(t.dead) -} - -// Kill flags the goroutine as dying for the given reason. -// Kill may be called multiple times, but only the first -// non-nil error is recorded as the reason for termination. -// -// If reason is ErrDying, the previous reason isn't replaced -// even if it is nil. It's a runtime error to call Kill with -// ErrDying if t is not in a dying state. -func (t *Tomb) Kill(reason error) { - t.init() - t.m.Lock() - defer t.m.Unlock() - if reason == ErrDying { - if t.reason == ErrStillAlive { - panic("tomb: Kill with ErrDying while still alive") - } - return - } - if t.reason == nil || t.reason == ErrStillAlive { - t.reason = reason - } - // If the receive on t.dying succeeds, then - // it can only be because we have already closed it. - // If it blocks, then we know that it needs to be closed. - select { - case <-t.dying: - default: - close(t.dying) - } -} - -// Killf works like Kill, but builds the reason providing the received -// arguments to fmt.Errorf. The generated error is also returned. -func (t *Tomb) Killf(f string, a ...interface{}) error { - err := fmt.Errorf(f, a...) - t.Kill(err) - return err -} - -// Err returns the reason for the goroutine death provided via Kill -// or Killf, or ErrStillAlive when the goroutine is still alive. 
-func (t *Tomb) Err() (reason error) { - t.init() - t.m.Lock() - reason = t.reason - t.m.Unlock() - return -} diff --git a/vendor/modules.txt b/vendor/modules.txt index af758a7b4bf..7fcce6e6c74 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -14,12 +14,6 @@ github.com/certifi/gocertifi # github.com/cespare/xxhash/v2 v2.1.2 ## explicit; go 1.11 github.com/cespare/xxhash/v2 -# github.com/cheekybits/genny v1.0.0 -## explicit -github.com/cheekybits/genny -github.com/cheekybits/genny/generic -github.com/cheekybits/genny/out -github.com/cheekybits/genny/parse # github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93 ## explicit; go 1.12 github.com/cloudflare/brotli-go @@ -160,6 +154,10 @@ github.com/gobwas/ws/wsutil # github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 ## explicit github.com/golang-collections/collections/queue +# github.com/golang/mock v1.6.0 +## explicit; go 1.11 +github.com/golang/mock/mockgen +github.com/golang/mock/mockgen/model # github.com/golang/protobuf v1.5.2 ## explicit; go 1.9 github.com/golang/protobuf/jsonpb @@ -172,6 +170,9 @@ github.com/golang/protobuf/ptypes/timestamp ## explicit; go 1.12 github.com/google/gopacket github.com/google/gopacket/layers +# github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 +## explicit; go 1.14 +github.com/google/pprof/profile # github.com/google/uuid v1.3.0 ## explicit github.com/google/uuid @@ -196,33 +197,6 @@ github.com/klauspost/compress/flate ## explicit # github.com/kylelemons/godebug v1.1.0 ## explicit; go 1.11 -# github.com/lucas-clemente/quic-go v0.28.1 => github.com/chungthuang/quic-go v0.27.1-0.20220809135021-ca330f1dec9f -## explicit; go 1.16 -github.com/lucas-clemente/quic-go -github.com/lucas-clemente/quic-go/internal/ackhandler -github.com/lucas-clemente/quic-go/internal/congestion -github.com/lucas-clemente/quic-go/internal/flowcontrol -github.com/lucas-clemente/quic-go/internal/handshake -github.com/lucas-clemente/quic-go/internal/logutils -github.com/lucas-clemente/quic-go/internal/protocol -github.com/lucas-clemente/quic-go/internal/qerr -github.com/lucas-clemente/quic-go/internal/qtls -github.com/lucas-clemente/quic-go/internal/utils -github.com/lucas-clemente/quic-go/internal/wire -github.com/lucas-clemente/quic-go/logging -github.com/lucas-clemente/quic-go/quicvarint -# github.com/marten-seemann/qtls-go1-16 v0.1.5 -## explicit; go 1.16 -github.com/marten-seemann/qtls-go1-16 -# github.com/marten-seemann/qtls-go1-17 v0.1.2 -## explicit; go 1.17 -github.com/marten-seemann/qtls-go1-17 -# github.com/marten-seemann/qtls-go1-18 v0.1.2 => github.com/cloudflare/qtls-pq v0.0.0-20230103171413-e7a2fb559a0e -## explicit; go 1.18 -github.com/marten-seemann/qtls-go1-18 -# github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 => github.com/cloudflare/qtls-pq v0.0.0-20230103171656-05e84f90909e -## explicit; go 1.19 -github.com/marten-seemann/qtls-go1-19 # github.com/mattn/go-colorable v0.1.13 ## explicit; go 1.15 github.com/mattn/go-colorable @@ -244,38 +218,24 @@ github.com/modern-go/concurrent # github.com/modern-go/reflect2 v1.0.2 ## explicit; go 1.12 github.com/modern-go/reflect2 -# github.com/nxadm/tail v1.4.8 -## explicit; go 1.13 -github.com/nxadm/tail -github.com/nxadm/tail/ratelimiter -github.com/nxadm/tail/util -github.com/nxadm/tail/watch -github.com/nxadm/tail/winfile -# github.com/onsi/ginkgo v1.16.5 -## explicit; go 1.16 -github.com/onsi/ginkgo/config -github.com/onsi/ginkgo/formatter -github.com/onsi/ginkgo/ginkgo -github.com/onsi/ginkgo/ginkgo/convert 
-github.com/onsi/ginkgo/ginkgo/interrupthandler -github.com/onsi/ginkgo/ginkgo/nodot -github.com/onsi/ginkgo/ginkgo/outline -github.com/onsi/ginkgo/ginkgo/testrunner -github.com/onsi/ginkgo/ginkgo/testsuite -github.com/onsi/ginkgo/ginkgo/watch -github.com/onsi/ginkgo/internal/codelocation -github.com/onsi/ginkgo/internal/containernode -github.com/onsi/ginkgo/internal/failer -github.com/onsi/ginkgo/internal/leafnodes -github.com/onsi/ginkgo/internal/remote -github.com/onsi/ginkgo/internal/spec -github.com/onsi/ginkgo/internal/spec_iterator -github.com/onsi/ginkgo/internal/writer -github.com/onsi/ginkgo/reporters -github.com/onsi/ginkgo/reporters/stenographer -github.com/onsi/ginkgo/reporters/stenographer/support/go-colorable -github.com/onsi/ginkgo/reporters/stenographer/support/go-isatty -github.com/onsi/ginkgo/types +# github.com/onsi/ginkgo/v2 v2.4.0 +## explicit; go 1.18 +github.com/onsi/ginkgo/v2/config +github.com/onsi/ginkgo/v2/formatter +github.com/onsi/ginkgo/v2/ginkgo +github.com/onsi/ginkgo/v2/ginkgo/build +github.com/onsi/ginkgo/v2/ginkgo/command +github.com/onsi/ginkgo/v2/ginkgo/generators +github.com/onsi/ginkgo/v2/ginkgo/internal +github.com/onsi/ginkgo/v2/ginkgo/labels +github.com/onsi/ginkgo/v2/ginkgo/outline +github.com/onsi/ginkgo/v2/ginkgo/run +github.com/onsi/ginkgo/v2/ginkgo/unfocus +github.com/onsi/ginkgo/v2/ginkgo/watch +github.com/onsi/ginkgo/v2/internal/interrupt_handler +github.com/onsi/ginkgo/v2/internal/parallel_support +github.com/onsi/ginkgo/v2/reporters +github.com/onsi/ginkgo/v2/types # github.com/onsi/gomega v1.23.0 ## explicit; go 1.18 # github.com/opentracing/opentracing-go v1.2.0 @@ -312,6 +272,28 @@ github.com/prometheus/common/model github.com/prometheus/procfs github.com/prometheus/procfs/internal/fs github.com/prometheus/procfs/internal/util +# github.com/quic-go/qtls-go1-19 v0.3.2 => github.com/cloudflare/qtls-pq v0.0.0-20230320123031-3faac1a945b2 +## explicit; go 1.19 +github.com/quic-go/qtls-go1-19 +# github.com/quic-go/qtls-go1-20 v0.2.2 => github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633 +## explicit; go 1.20 +github.com/quic-go/qtls-go1-20 +# github.com/quic-go/quic-go v0.0.0-00010101000000-000000000000 => github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7 +## explicit; go 1.19 +github.com/quic-go/quic-go +github.com/quic-go/quic-go/internal/ackhandler +github.com/quic-go/quic-go/internal/congestion +github.com/quic-go/quic-go/internal/flowcontrol +github.com/quic-go/quic-go/internal/handshake +github.com/quic-go/quic-go/internal/logutils +github.com/quic-go/quic-go/internal/protocol +github.com/quic-go/quic-go/internal/qerr +github.com/quic-go/quic-go/internal/qtls +github.com/quic-go/quic-go/internal/utils +github.com/quic-go/quic-go/internal/utils/linkedlist +github.com/quic-go/quic-go/internal/wire +github.com/quic-go/quic-go/logging +github.com/quic-go/quic-go/quicvarint # github.com/rs/zerolog v1.20.0 ## explicit github.com/rs/zerolog @@ -389,9 +371,13 @@ golang.org/x/crypto/salsa20/salsa golang.org/x/crypto/ssh golang.org/x/crypto/ssh/internal/bcrypt_pbkdf golang.org/x/crypto/ssh/terminal +# golang.org/x/exp v0.0.0-20221205204356-47842c84f3db +## explicit; go 1.18 +golang.org/x/exp/constraints # golang.org/x/mod v0.8.0 ## explicit; go 1.17 golang.org/x/mod/internal/lazyregexp +golang.org/x/mod/modfile golang.org/x/mod/module golang.org/x/mod/semver # golang.org/x/net v0.9.0 @@ -583,9 +569,6 @@ gopkg.in/square/go-jose.v2 gopkg.in/square/go-jose.v2/cipher gopkg.in/square/go-jose.v2/json 
gopkg.in/square/go-jose.v2/jwt -# gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 -## explicit -gopkg.in/tomb.v1 # gopkg.in/yaml.v2 v2.4.0 ## explicit; go 1.15 gopkg.in/yaml.v2 @@ -616,12 +599,8 @@ zombiezen.com/go/capnproto2/schemas zombiezen.com/go/capnproto2/server zombiezen.com/go/capnproto2/std/capnp/rpc # github.com/urfave/cli/v2 => github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d -# github.com/lucas-clemente/quic-go => github.com/chungthuang/quic-go v0.27.1-0.20220809135021-ca330f1dec9f # github.com/prometheus/golang_client => github.com/prometheus/golang_client v1.12.1 # gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1 -# github.com/marten-seemann/qtls-go1-18 => github.com/cloudflare/qtls-pq v0.0.0-20230103171413-e7a2fb559a0e -# github.com/marten-seemann/qtls-go1-19 => github.com/cloudflare/qtls-pq v0.0.0-20230103171656-05e84f90909e -# github.com/marten-seemann/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20230215110727-8b4e1699c2a8 -# github.com/quic-go/qtls-go1-18 => github.com/cloudflare/qtls-pq v0.0.0-20230103171413-e7a2fb559a0e -# github.com/quic-go/qtls-go1-19 => github.com/cloudflare/qtls-pq v0.0.0-20230103171656-05e84f90909e -# github.com/quic-go/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20230215110727-8b4e1699c2a8 +# github.com/quic-go/quic-go => github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7 +# github.com/quic-go/qtls-go1-19 => github.com/cloudflare/qtls-pq v0.0.0-20230320123031-3faac1a945b2 +# github.com/quic-go/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633
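
For context, the trailing "# module => replacement" comment lines in vendor/modules.txt above mirror the replace directives in the module's go.mod. A minimal sketch of how the quic-go-related directives would look in go.mod, assuming the same fork paths and pseudo-versions listed in this hunk (copied from the comments above; not the full go.mod from this patch):

    // go.mod (sketch): pin the devincarr quic-go fork and the Cloudflare post-quantum qtls forks
    replace (
        github.com/quic-go/quic-go => github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7
        github.com/quic-go/qtls-go1-19 => github.com/cloudflare/qtls-pq v0.0.0-20230320123031-3faac1a945b2
        github.com/quic-go/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633
    )

With directives like these in place, running go mod vendor regenerates vendor/modules.txt, which is where the "=>" annotations on the package stanzas and the trailing replace comments shown in this diff come from.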