Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dispatch all exchanges to execution context #74

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 57 additions & 48 deletions cask/src/cask/main/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import io.undertow.server.handlers.BlockingHandler
import io.undertow.util.HttpString

import scala.concurrent.ExecutionContext
import java.util.concurrent.Executor

/**
* A combination of [[cask.Main]] and [[cask.Routes]], ideal for small
Expand Down Expand Up @@ -46,9 +47,14 @@ abstract class Main{

def dispatchTrie = Main.prepareDispatchTrie(allRoutes)

def defaultHandler = new BlockingHandler(
new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError)
)
def defaultHandler = new BlockingHandler(new Main.DefaultHandler(
dispatchTrie,
mainDecorators,
debugMode,
handleNotFound,
handleMethodNotAllowed,
handleEndpointError,
executionContext))

def handleNotFound() = Main.defaultHandleNotFound()

Expand Down Expand Up @@ -77,55 +83,58 @@ object Main{
debugMode: Boolean,
handleNotFound: () => Response.Raw,
handleMethodNotAllowed: () => Response.Raw,
handleError: (Routes, EndpointMetadata[_], Result.Error) => Response.Raw)
handleError: (Routes, EndpointMetadata[_], Result.Error) => Response.Raw,
executor: Executor)
(implicit log: Logger) extends HttpHandler() {
def handleRequest(exchange: HttpServerExchange): Unit = try {
// println("Handling Request: " + exchange.getRequestPath)
val (effectiveMethod, runner) = if (exchange.getRequestHeaders.getFirst("Upgrade") == "websocket") {
Tuple2(
"websocket",
(r: Any) =>
r.asInstanceOf[WebsocketResult] match{
case l: WsHandler =>
io.undertow.Handlers.websocket(l).handleRequest(exchange)
case l: WebsocketResult.Listener =>
io.undertow.Handlers.websocket(l.value).handleRequest(exchange)
case r: WebsocketResult.Response[Response.Data] =>
Main.writeResponse(exchange, r.value)
}
def handleRequest(exchange: HttpServerExchange): Unit = exchange.dispatch(executor, new Runnable {
def run(): Unit = try {
// println("Handling Request: " + exchange.getRequestPath)
val (effectiveMethod, runner) = if (exchange.getRequestHeaders.getFirst("Upgrade") == "websocket") {
Tuple2(
"websocket",
(r: Any) =>
r.asInstanceOf[WebsocketResult] match{
case l: WsHandler =>
io.undertow.Handlers.websocket(l).handleRequest(exchange)
case l: WebsocketResult.Listener =>
io.undertow.Handlers.websocket(l.value).handleRequest(exchange)
case r: WebsocketResult.Response[Response.Data] =>
Main.writeResponse(exchange, r.value)
}
)
} else Tuple2(
exchange.getRequestMethod.toString.toLowerCase(),
(r: Any) => Main.writeResponse(exchange, r.asInstanceOf[Response.Raw])
)
} else Tuple2(
exchange.getRequestMethod.toString.toLowerCase(),
(r: Any) => Main.writeResponse(exchange, r.asInstanceOf[Response.Raw])
)

dispatchTrie.lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match {
case None => Main.writeResponse(exchange, handleNotFound())
case Some((methodMap, routeBindings, remaining)) =>
methodMap.get(effectiveMethod) match {
case None => Main.writeResponse(exchange, handleMethodNotAllowed())
case Some((routes, metadata)) =>
Decorator.invoke(
Request(exchange, remaining),
metadata.endpoint,
metadata.entryPoint.asInstanceOf[EntryPoint[Routes, _]],
routes,
routeBindings,
(mainDecorators ++ routes.decorators ++ metadata.decorators).toList,
Nil
) match {
case Result.Success(res) => runner(res)
case e: Result.Error =>
Main.writeResponse(
exchange,
handleError(routes, metadata, e)
)
}
}
dispatchTrie.lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match {
case None => Main.writeResponse(exchange, handleNotFound())
case Some((methodMap, routeBindings, remaining)) =>
methodMap.get(effectiveMethod) match {
case None => Main.writeResponse(exchange, handleMethodNotAllowed())
case Some((routes, metadata)) =>
Decorator.invoke(
Request(exchange, remaining),
metadata.endpoint,
metadata.entryPoint.asInstanceOf[EntryPoint[Routes, _]],
routes,
routeBindings,
(mainDecorators ++ routes.decorators ++ metadata.decorators).toList,
Nil
) match {
case Result.Success(res) => runner(res)
case e: Result.Error =>
Main.writeResponse(
exchange,
handleError(routes, metadata, e)
)
}
}
}
}catch{case e: Throwable =>
e.printStackTrace()
}
}catch{case e: Throwable =>
e.printStackTrace()
}
})
}

def defaultHandleNotFound(): Response.Raw = {
Expand Down
23 changes: 23 additions & 0 deletions docs/pages/2 - Main Customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,26 @@ useful stack traces or metadata for debugging if `debugMode = true`.
Any `cask.Decorator`s that you want to apply to all routes and all endpoints in
the entire web application. Useful for inserting application-wide
instrumentation, logging, security-checks, and similar things.

## def createExecutionContext

A `scala.concurrent.ExecutionContextExecutorService` to which all requests
(including WebSockets) are dispatched. By default uses a fixed thread pool with
N threads where N is the number of CPUs.

Can be overridden by a custom executor, e.g. suppose you want to use an
unbounded pool of JDK 19's [virtual threads](https://openjdk.org/jeps/425):

```scala
import java.util.concurrent.Executors
import scala.concurrent.ExecutionContext

object MyCaskApp extends cask.MainRoutes {
override def createExecutionContext =
ExecutionContext.fromExecutorService(Executors.newVirtualThreadPerTaskExecutor())

@cask.get("/")
def hello() = "Hello, World!"
}
```