Skip to content

Commit

Permalink
Replace Request#setHeaders with addHeader (#30588)
Browse files Browse the repository at this point in the history
Adding headers rather than setting them all at once seems more
user-friendly and we already do it in a similar way for parameters
(see Request#addParameter).
  • Loading branch information
javanna authored May 22, 2018
1 parent 0d37ac4 commit a17d6ca
Show file tree
Hide file tree
Showing 20 changed files with 166 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.admin.cluster.repositories.get.GetRepositoriesRequest;
import org.elasticsearch.action.admin.cluster.repositories.get.GetRepositoriesResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.delete.DeleteRequest;
Expand Down Expand Up @@ -592,7 +590,7 @@ protected final <Req extends ActionRequest, Resp> Resp performRequest(Req reques
throw validationException;
}
Request req = requestConverter.apply(request);
req.setHeaders(headers);
addHeaders(req, headers);
Response response;
try {
response = client.performRequest(req);
Expand Down Expand Up @@ -642,12 +640,19 @@ protected final <Req extends ActionRequest, Resp> void performRequestAsync(Req r
listener.onFailure(e);
return;
}
req.setHeaders(headers);
addHeaders(req, headers);

ResponseListener responseListener = wrapResponseListener(responseConverter, listener, ignores);
client.performRequestAsync(req, responseListener);
}

private static void addHeaders(Request request, Header... headers) {
Objects.requireNonNull(headers, "headers cannot be null");
for (Header header : headers) {
request.addHeader(header.getName(), header.getValue());
}
}

final <Resp> ResponseListener wrapResponseListener(CheckedFunction<Response, Resp, IOException> responseConverter,
ActionListener<Resp> actionListener, Set<Integer> ignores) {
return new ResponseListener() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ public void initClients() throws IOException {
final RestClient restClient = mock(RestClient.class);
restHighLevelClient = new CustomRestClient(restClient);

doAnswer(inv -> mockPerformRequest(((Request) inv.getArguments()[0]).getHeaders()[0]))
doAnswer(inv -> mockPerformRequest(((Request) inv.getArguments()[0]).getHeaders().iterator().next()))
.when(restClient)
.performRequest(any(Request.class));

doAnswer(inv -> mockPerformRequestAsync(
((Request) inv.getArguments()[0]).getHeaders()[0],
((Request) inv.getArguments()[0]).getHeaders().iterator().next(),
(ResponseListener) inv.getArguments()[1]))
.when(restClient)
.performRequestAsync(any(Request.class), any(ResponseListener.class));
Expand Down
64 changes: 46 additions & 18 deletions client/rest/src/main/java/org/elasticsearch/client/Request.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@

package org.elasticsearch.client;

import org.apache.http.entity.ContentType;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.message.BasicHeader;
import org.apache.http.nio.entity.NStringEntity;
import org.apache.http.nio.protocol.HttpAsyncResponseConsumer;

import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

Expand All @@ -36,13 +39,12 @@
* HTTP Request to Elasticsearch.
*/
public final class Request {
private static final Header[] NO_HEADERS = new Header[0];
private final String method;
private final String endpoint;
private final Map<String, String> parameters = new HashMap<>();
private final List<Header> headers = new ArrayList<>();

private HttpEntity entity;
private Header[] headers = NO_HEADERS;
private HttpAsyncResponseConsumerFactory httpAsyncResponseConsumerFactory =
HttpAsyncResponseConsumerFactory.DEFAULT;

Expand Down Expand Up @@ -125,21 +127,19 @@ public HttpEntity getEntity() {
}

/**
* Set the headers to attach to the request.
* Add the provided header to the request.
*/
public void setHeaders(Header... headers) {
Objects.requireNonNull(headers, "headers cannot be null");
for (Header header : headers) {
Objects.requireNonNull(header, "header cannot be null");
}
this.headers = headers;
public void addHeader(String name, String value) {
Objects.requireNonNull(name, "header name cannot be null");
Objects.requireNonNull(value, "header value cannot be null");
this.headers.add(new ReqHeader(name, value));
}

/**
* Headers to attach to the request.
*/
public Header[] getHeaders() {
return headers;
List<Header> getHeaders() {
return Collections.unmodifiableList(headers);
}

/**
Expand Down Expand Up @@ -175,13 +175,13 @@ public String toString() {
if (entity != null) {
b.append(", entity=").append(entity);
}
if (headers.length > 0) {
if (headers.size() > 0) {
b.append(", headers=");
for (int h = 0; h < headers.length; h++) {
for (int h = 0; h < headers.size(); h++) {
if (h != 0) {
b.append(',');
}
b.append(headers[h].toString());
b.append(headers.get(h).toString());
}
}
if (httpAsyncResponseConsumerFactory != HttpAsyncResponseConsumerFactory.DEFAULT) {
Expand All @@ -204,12 +204,40 @@ public boolean equals(Object obj) {
&& endpoint.equals(other.endpoint)
&& parameters.equals(other.parameters)
&& Objects.equals(entity, other.entity)
&& Arrays.equals(headers, other.headers)
&& headers.equals(other.headers)
&& httpAsyncResponseConsumerFactory.equals(other.httpAsyncResponseConsumerFactory);
}

@Override
public int hashCode() {
return Objects.hash(method, endpoint, parameters, entity, Arrays.hashCode(headers), httpAsyncResponseConsumerFactory);
return Objects.hash(method, endpoint, parameters, entity, headers.hashCode(), httpAsyncResponseConsumerFactory);
}

/**
* Custom implementation of {@link BasicHeader} that overrides equals and hashCode.
*/
static final class ReqHeader extends BasicHeader {

ReqHeader(String name, String value) {
super(name, value);
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other instanceof ReqHeader) {
Header otherHeader = (Header) other;
return Objects.equals(getName(), otherHeader.getName()) &&
Objects.equals(getValue(), otherHeader.getValue());
}
return false;
}

@Override
public int hashCode() {
return Objects.hash(getName(), getValue());
}
}
}
34 changes: 24 additions & 10 deletions client/rest/src/main/java/org/elasticsearch/client/RestClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public void performRequestAsync(Request request, ResponseListener responseListen
@Deprecated
public Response performRequest(String method, String endpoint, Header... headers) throws IOException {
Request request = new Request(method, endpoint);
request.setHeaders(headers);
addHeaders(request, headers);
return performRequest(request);
}

Expand All @@ -237,7 +237,7 @@ public Response performRequest(String method, String endpoint, Header... headers
public Response performRequest(String method, String endpoint, Map<String, String> params, Header... headers) throws IOException {
Request request = new Request(method, endpoint);
addParameters(request, params);
request.setHeaders(headers);
addHeaders(request, headers);
return performRequest(request);
}

Expand All @@ -264,7 +264,7 @@ public Response performRequest(String method, String endpoint, Map<String, Strin
Request request = new Request(method, endpoint);
addParameters(request, params);
request.setEntity(entity);
request.setHeaders(headers);
addHeaders(request, headers);
return performRequest(request);
}

Expand Down Expand Up @@ -305,7 +305,7 @@ public Response performRequest(String method, String endpoint, Map<String, Strin
addParameters(request, params);
request.setEntity(entity);
request.setHttpAsyncResponseConsumerFactory(httpAsyncResponseConsumerFactory);
request.setHeaders(headers);
addHeaders(request, headers);
return performRequest(request);
}

Expand All @@ -325,7 +325,7 @@ public void performRequestAsync(String method, String endpoint, ResponseListener
Request request;
try {
request = new Request(method, endpoint);
request.setHeaders(headers);
addHeaders(request, headers);
} catch (Exception e) {
responseListener.onFailure(e);
return;
Expand All @@ -352,7 +352,7 @@ public void performRequestAsync(String method, String endpoint, Map<String, Stri
try {
request = new Request(method, endpoint);
addParameters(request, params);
request.setHeaders(headers);
addHeaders(request, headers);
} catch (Exception e) {
responseListener.onFailure(e);
return;
Expand Down Expand Up @@ -383,7 +383,7 @@ public void performRequestAsync(String method, String endpoint, Map<String, Stri
request = new Request(method, endpoint);
addParameters(request, params);
request.setEntity(entity);
request.setHeaders(headers);
addHeaders(request, headers);
} catch (Exception e) {
responseListener.onFailure(e);
return;
Expand Down Expand Up @@ -420,7 +420,7 @@ public void performRequestAsync(String method, String endpoint, Map<String, Stri
addParameters(request, params);
request.setEntity(entity);
request.setHttpAsyncResponseConsumerFactory(httpAsyncResponseConsumerFactory);
request.setHeaders(headers);
addHeaders(request, headers);
} catch (Exception e) {
responseListener.onFailure(e);
return;
Expand Down Expand Up @@ -539,9 +539,9 @@ public void cancelled() {
});
}

private void setHeaders(HttpRequest httpRequest, Header[] requestHeaders) {
private void setHeaders(HttpRequest httpRequest, Collection<Header> requestHeaders) {
// request headers override default headers, so we don't add default headers if they exist as request headers
final Set<String> requestNames = new HashSet<>(requestHeaders.length);
final Set<String> requestNames = new HashSet<>(requestHeaders.size());
for (Header requestHeader : requestHeaders) {
httpRequest.addHeader(requestHeader);
requestNames.add(requestHeader.getName());
Expand Down Expand Up @@ -877,10 +877,24 @@ private static class HostTuple<T> {
}
}

/**
* Add all headers from the provided varargs argument to a {@link Request}. This only exists
* to support methods that exist for backwards compatibility.
*/
@Deprecated
private static void addHeaders(Request request, Header... headers) {
Objects.requireNonNull(headers, "headers cannot be null");
for (Header header : headers) {
Objects.requireNonNull(header, "header cannot be null");
request.addHeader(header.getName(), header.getValue());
}
}

/**
* Add all parameters from a map to a {@link Request}. This only exists
* to support methods that exist for backwards compatibility.
*/
@Deprecated
private static void addParameters(Request request, Map<String, String> parameters) {
Objects.requireNonNull(parameters, "parameters cannot be null");
for (Map.Entry<String, String> entry : parameters.entrySet()) {
Expand Down
Loading

0 comments on commit a17d6ca

Please sign in to comment.