Skip to content

Commit

Permalink
Merge WebsocketAcceptor and WebsocketClientFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
trowski committed Oct 8, 2023
1 parent ec8a6f5 commit 15b3ce1
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 166 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use Amp\Http\Server\StaticContent\DocumentRoot;
use Amp\Log\ConsoleFormatter;
use Amp\Log\StreamHandler;
use Amp\Socket;
use Amp\Websocket\Server\AllowOriginAcceptor;
use Amp\Websocket\Server\AllowOriginClientFactory;
use Amp\Websocket\Server\Websocket;
use Amp\Websocket\Server\WebsocketClientGateway;
use Amp\Websocket\Server\WebsocketClientHandler;
Expand All @@ -61,7 +61,7 @@ $server->expose(new Socket\InternetAddress('[::1]', 1337));

$errorHandler = new DefaultErrorHandler();

$handshakeHandler = new AllowOriginAcceptor(
$clientFactory = new AllowOriginClientFactory(
['http://localhost:1337', 'http://127.0.0.1:1337', 'http://[::1]:1337'],
);

Expand All @@ -73,8 +73,8 @@ $clientHandler = new class implements WebsocketClientHandler {

public function handleClient(
WebsocketClient $client,
Request $request,
Response $response
Request $request,
Response $response,
): void {
$this->gateway->addClient($client);

Expand All @@ -88,7 +88,7 @@ $clientHandler = new class implements WebsocketClientHandler {
}
};

$websocket = new Websocket($server, $logger, $handshakeHandler, $clientHandler);
$websocket = new Websocket($server, $logger, $clientHandler, $clientFactory);

$router = new Router($server, new NullLogger(), $errorHandler);
$router->addRoute('GET', '/broadcast', $websocket);
Expand Down
18 changes: 14 additions & 4 deletions examples/broadcast-server/server.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
use Amp\Log\ConsoleFormatter;
use Amp\Log\StreamHandler;
use Amp\Socket;
use Amp\Websocket\Server\AllowOriginAcceptor;
use Amp\Websocket\Compression\Rfc7692CompressionFactory;
use Amp\Websocket\Server\AllowOriginClientFactory;
use Amp\Websocket\Server\Rfc6455ClientFactory;
use Amp\Websocket\Server\Websocket;
use Amp\Websocket\Server\WebsocketClientGateway;
use Amp\Websocket\Server\WebsocketClientHandler;
Expand All @@ -25,7 +27,7 @@
require __DIR__ . '/../../vendor/autoload.php';

$logHandler = new StreamHandler(getStdout());
$logHandler->setFormatter(new ConsoleFormatter);
$logHandler->setFormatter(new ConsoleFormatter());
$logger = new Logger('server');
$logger->pushHandler($logHandler);

Expand All @@ -36,8 +38,11 @@

$errorHandler = new DefaultErrorHandler();

$acceptor = new AllowOriginAcceptor(
$compressionFactory = new Rfc7692CompressionFactory();

$clientFactory = new AllowOriginClientFactory(
['http://localhost:1337', 'http://127.0.0.1:1337', 'http://[::1]:1337'],
clientFactory: new Rfc6455ClientFactory(compressionContextFactory: $compressionFactory),
);

$clientHandler = new class implements WebsocketClientHandler {
Expand All @@ -56,7 +61,12 @@ public function handleClient(WebsocketClient $client, Request $request, Response
}
};

$websocket = new Websocket($server, $logger, $acceptor, $clientHandler);
$websocket = new Websocket(
httpServer: $server,
logger: $logger,
clientHandler: $clientHandler,
clientFactory: $clientFactory,
);

$router = new Router($server, new NullLogger(), $errorHandler);
$router->addRoute('GET', '/broadcast', $websocket);
Expand Down
6 changes: 3 additions & 3 deletions examples/stackexchange-questions/server.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use Amp\Log\ConsoleFormatter;
use Amp\Log\StreamHandler;
use Amp\Socket;
use Amp\Websocket\Server\AllowOriginAcceptor;
use Amp\Websocket\Server\AllowOriginClientFactory;
use Amp\Websocket\Server\Websocket;
use Amp\Websocket\Server\WebsocketClientGateway;
use Amp\Websocket\Server\WebsocketClientHandler;
Expand All @@ -40,7 +40,7 @@

$errorHandler = new DefaultErrorHandler();

$acceptor = new AllowOriginAcceptor(
$clientFactory = new AllowOriginClientFactory(
['http://localhost:1337', 'http://127.0.0.1:1337', 'http://[::1]:1337'],
);

Expand Down Expand Up @@ -100,8 +100,8 @@ public function handleClient(WebsocketClient $client, Request $request, Response
$websocket = new Websocket(
httpServer: $server,
logger: $logger,
acceptor: $acceptor,
clientHandler: $clientHandler,
clientFactory: $clientFactory,
);

$router = new Router($server, new NullLogger(), $errorHandler);
Expand Down
13 changes: 10 additions & 3 deletions src/AllowOriginAcceptor.php → src/AllowOriginClientFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
use Amp\Http\Server\ErrorHandler;
use Amp\Http\Server\Request;
use Amp\Http\Server\Response;
use Amp\Socket\Socket;
use Amp\Websocket\WebsocketClient;

final class AllowOriginAcceptor implements WebsocketAcceptor
final class AllowOriginClientFactory implements WebsocketClientFactory
{
use ForbidCloning;
use ForbidSerialization;
Expand All @@ -20,7 +22,7 @@ final class AllowOriginAcceptor implements WebsocketAcceptor
public function __construct(
private readonly array $allowOrigins,
private readonly ErrorHandler $errorHandler = new Internal\UpgradeErrorHandler(),
private readonly WebsocketAcceptor $acceptor = new Rfc6455Acceptor(),
private readonly WebsocketClientFactory $clientFactory = new Rfc6455ClientFactory(),
) {
}

Expand All @@ -30,6 +32,11 @@ public function handleHandshake(Request $request): Response
return $this->errorHandler->handleError(HttpStatus::FORBIDDEN, 'Origin forbidden', $request);
}

return $this->acceptor->handleHandshake($request);
return $this->clientFactory->handleHandshake($request);
}

public function createClient(Request $request, Response $response, Socket $socket): WebsocketClient
{
return $this->clientFactory->createClient($request, $response, $socket);
}
}
90 changes: 0 additions & 90 deletions src/Rfc6455Acceptor.php

This file was deleted.

111 changes: 99 additions & 12 deletions src/Rfc6455ClientFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
use Amp\ByteStream\ResourceStream;
use Amp\ForbidCloning;
use Amp\ForbidSerialization;
use Amp\Http;
use Amp\Http\HttpStatus;
use Amp\Http\Server\ErrorHandler;
use Amp\Http\Server\Request;
use Amp\Http\Server\Response;
use Amp\Socket\Socket;
use Amp\Websocket\Compression\WebsocketCompressionContext;
use Amp\Websocket\Compression\WebsocketCompressionContextFactory;
use Amp\Websocket\ConstantRateLimit;
use Amp\Websocket\Parser\Rfc6455ParserFactory;
Expand All @@ -17,12 +21,16 @@
use Amp\Websocket\WebsocketClient;
use Amp\Websocket\WebsocketHeartbeatQueue;
use Amp\Websocket\WebsocketRateLimit;
use function Amp\Websocket\generateAcceptFromKey;

final class Rfc6455ClientFactory implements WebsocketClientFactory
{
use ForbidCloning;
use ForbidSerialization;

/** @var \WeakMap<Request, WebsocketCompressionContext> */
private readonly \WeakMap $compressionMap;

/**
* @param WebsocketCompressionContextFactory|null $compressionContextFactory Use null to disable compression.
* @param WebsocketHeartbeatQueue|null $heartbeatQueue Use null to disable automatic heartbeats (pings).
Expand All @@ -32,10 +40,99 @@ public function __construct(
private readonly ?WebsocketCompressionContextFactory $compressionContextFactory = null,
private readonly ?WebsocketHeartbeatQueue $heartbeatQueue = new PeriodicHeartbeatQueue(),
private readonly ?WebsocketRateLimit $rateLimit = new ConstantRateLimit(),
private readonly ErrorHandler $errorHandler = new Internal\UpgradeErrorHandler(),
private readonly WebsocketParserFactory $parserFactory = new Rfc6455ParserFactory(),
private readonly int $frameSplitThreshold = Rfc6455Client::DEFAULT_FRAME_SPLIT_THRESHOLD,
private readonly float $closePeriod = Rfc6455Client::DEFAULT_CLOSE_PERIOD,
) {
/** @var \WeakMap<Request, WebsocketCompressionContext> */
$this->compressionMap = new \WeakMap();
}

public function handleHandshake(Request $request): Response
{
if ($request->getMethod() !== 'GET') {
$response = $this->errorHandler->handleError(HttpStatus::METHOD_NOT_ALLOWED, request: $request);
$response->setHeader('allow', 'GET');
return $response;
}

if ($request->getProtocolVersion() !== '1.1') {
$response = $this->errorHandler->handleError(HttpStatus::HTTP_VERSION_NOT_SUPPORTED, request: $request);
$response->setHeader('upgrade', 'websocket');
return $response;
}

if ($request->getBody()->buffer() !== '') {
return $this->errorHandler->handleError(HttpStatus::BAD_REQUEST, request: $request);
}

$hasUpgradeWebsocket = false;
foreach ($request->getHeaderArray('upgrade') as $value) {
if (\strcasecmp($value, 'websocket') === 0) {
$hasUpgradeWebsocket = true;
break;
}
}
if (!$hasUpgradeWebsocket) {
$response = $this->errorHandler->handleError(HttpStatus::UPGRADE_REQUIRED, request: $request);
$response->setHeader('upgrade', 'websocket');
return $response;
}

$hasConnectionUpgrade = false;
foreach ($request->getHeaderArray('connection') as $value) {
$values = \array_map('trim', \explode(',', $value));

foreach ($values as $token) {
if (\strcasecmp($token, 'upgrade') === 0) {
$hasConnectionUpgrade = true;
break;
}
}
}

if (!$hasConnectionUpgrade) {
$reason = 'Bad Request: "Connection: Upgrade" header required';
$response = $this->errorHandler->handleError(HttpStatus::UPGRADE_REQUIRED, $reason, $request);
$response->setHeader('upgrade', 'websocket');
return $response;
}

if (!$acceptKey = $request->getHeader('sec-websocket-key')) {
$reason = 'Bad Request: "Sec-Websocket-Key" header required';
return $this->errorHandler->handleError(HttpStatus::BAD_REQUEST, $reason, $request);
}

if (!\in_array('13', $request->getHeaderArray('sec-websocket-version'), true)) {
$reason = 'Bad Request: Requested Websocket version unavailable';
$response = $this->errorHandler->handleError(HttpStatus::BAD_REQUEST, $reason, $request);
$response->setHeader('sec-websocket-version', '13');
return $response;
}

$response = new Response(HttpStatus::SWITCHING_PROTOCOLS, [
'connection' => 'upgrade',
'upgrade' => 'websocket',
'sec-websocket-accept' => generateAcceptFromKey($acceptKey),
]);

if ($this->compressionContextFactory) {
$extensions = Http\splitHeader($request, 'sec-websocket-extensions') ?? [];

foreach ($extensions as $extension) {
if ($compressionContext = $this->compressionContextFactory->fromClientHeader($extension, $headerLine)) {
/** @psalm-suppress InaccessibleProperty WeakMap implements ArrayAccess. */
$this->compressionMap[$request] = $compressionContext;

/** @psalm-suppress PossiblyNullArgument */
$response->setHeader('sec-websocket-extensions', $headerLine);
break;
}
}
}

return $response;
}

public function createClient(
Expand Down Expand Up @@ -64,18 +161,8 @@ public function createClient(
}
}

$compressionContext = null;
if ($this->compressionContextFactory) {
$extensions = \array_map('trim', \explode(',', (string) $request->getHeader('sec-websocket-extensions')));

foreach ($extensions as $extension) {
if ($compressionContext = $this->compressionContextFactory->fromClientHeader($extension, $headerLine)) {
/** @psalm-suppress PossiblyNullArgument */
$response->setHeader('sec-websocket-extensions', $headerLine);
break;
}
}
}
$compressionContext = $this->compressionMap[$request] ?? null;
unset($this->compressionMap[$request]);

return new Rfc6455Client(
socket: $socket,
Expand Down
Loading

0 comments on commit 15b3ce1

Please sign in to comment.