Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extra headers #1457

Merged
merged 3 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/oidc-client-ts.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ export class OidcClient {
// (undocumented)
processResourceOwnerPasswordCredentials({ username, password, skipUserInfo, extraTokenParams, }: ProcessResourceOwnerPasswordCredentialsArgs): Promise<SigninResponse>;
// (undocumented)
processSigninResponse(url: string): Promise<SigninResponse>;
processSigninResponse(url: string, extraHeaders?: Record<string, ExtraHeader>): Promise<SigninResponse>;
// (undocumented)
processSignoutResponse(url: string): Promise<SignoutResponse>;
// (undocumented)
Expand All @@ -333,7 +333,7 @@ export class OidcClient {
// (undocumented)
protected readonly _tokenClient: TokenClient;
// (undocumented)
useRefreshToken({ state, redirect_uri, resource, timeoutInSeconds, extraTokenParams, }: UseRefreshTokenArgs): Promise<SigninResponse>;
useRefreshToken({ state, redirect_uri, resource, timeoutInSeconds, extraHeaders, extraTokenParams, }: UseRefreshTokenArgs): Promise<SigninResponse>;
// Warning: (ae-forgotten-export) The symbol "ResponseValidator" needs to be exported by the entry point index.d.ts
//
// (undocumented)
Expand Down Expand Up @@ -924,6 +924,8 @@ export class User {

// @public (undocumented)
export interface UseRefreshTokenArgs {
// (undocumented)
extraHeaders?: Record<string, ExtraHeader>;
// (undocumented)
extraTokenParams?: Record<string, unknown>;
// (undocumented)
Expand Down Expand Up @@ -1007,7 +1009,7 @@ export class UserManager {
// (undocumented)
storeUser(user: User | null): Promise<void>;
// (undocumented)
protected _useRefreshToken(args: UseRefreshTokenArgs): Promise<User>;
protected _useRefreshToken(args: UseRefreshTokenArgs, extraHeaders?: Record<string, ExtraHeader>): Promise<User>;
// (undocumented)
protected get _userStoreKey(): string;
}
Expand Down
26 changes: 26 additions & 0 deletions src/JsonService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { ErrorResponse } from "./errors";
import { JsonService } from "./JsonService";

import { mocked } from "jest-mock";
import type { ExtraHeader } from "./OidcClientSettings";

describe("JsonService", () => {
let subject: JsonService;
Expand Down Expand Up @@ -339,6 +340,31 @@ describe("JsonService", () => {
);
});

it("should fetch with extraHeaders if supplied", async () => {
// act
const extraHeaders: Record<string, ExtraHeader> = {
"extraHeader": "some random header value",
};
await expect(subject.postForm("http://test", { body: new URLSearchParams("payload=dummy"), extraHeaders })).rejects.toThrow();
await expect(subject.postForm("http://test", { body: new URLSearchParams("payload=dummy"), basicAuth: "basicAuth", extraHeaders })).rejects.toThrow();

// assert
expect(fetch).toBeCalledTimes(2);
expect(fetch).toHaveBeenLastCalledWith(
"http://test",
expect.objectContaining({
headers: {
Accept: "application/json",
Authorization: "Basic basicAuth",
"Content-Type": "application/x-www-form-urlencoded",
extraHeader: expect.any(String),
},
method: "POST",
body: new URLSearchParams(),
}),
);
});

it("should set payload as body", async () => {
// act
await expect(subject.postForm("http://test", { body: new URLSearchParams("payload=dummy") })).rejects.toThrow();
Expand Down
3 changes: 3 additions & 0 deletions src/JsonService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export interface PostFormOpts {
basicAuth?: string;
timeoutInSeconds?: number;
initCredentials?: "same-origin" | "include" | "omit";
extraHeaders?: Record<string, ExtraHeader>;
}

/**
Expand Down Expand Up @@ -131,11 +132,13 @@ export class JsonService {
basicAuth,
timeoutInSeconds,
initCredentials,
extraHeaders,
}: PostFormOpts): Promise<Record<string, unknown>> {
const logger = this._logger.create("postForm");
const headers: HeadersInit = {
"Accept": this._contentTypes.join(", "),
"Content-Type": "application/x-www-form-urlencoded",
...extraHeaders,
};
if (basicAuth !== undefined) {
headers["Authorization"] = "Basic " + basicAuth;
Expand Down
64 changes: 62 additions & 2 deletions src/OidcClient.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { JwtUtils } from "./utils";
import type { ErrorResponse } from "./errors";
import type { JwtClaims } from "./Claims";
import { OidcClient } from "./OidcClient";
import { type OidcClientSettings, OidcClientSettingsStore } from "./OidcClientSettings";
import { type ExtraHeader, type OidcClientSettings, OidcClientSettingsStore } from "./OidcClientSettings";
import { SigninState } from "./SigninState";
import { State } from "./State";
import { SigninRequest } from "./SigninRequest";
Expand Down Expand Up @@ -335,7 +335,32 @@ describe("OidcClient", () => {
const response = await subject.processSigninResponse("http://app/cb?state=1");

// assert
expect(validateSigninResponseMock).toHaveBeenCalledWith(response, item);
expect(validateSigninResponseMock).toHaveBeenCalledWith(response, item, undefined);
});

it("should pass on extraHeaders if supplied", async () => {
// arrange
const item = await SigninState.create({
id: "1",
authority: "authority",
client_id: "client",
redirect_uri: "http://app/cb",
scope: "scope",
request_type: "type",
});

const extraHeaders: Record<string, ExtraHeader> = { "foo": "bar" };

jest.spyOn(subject.settings.stateStore, "remove")
.mockImplementation(async () => item.toStorageString());
const validateSigninResponseMock = jest.spyOn(subject["_validator"], "validateSigninResponse")
.mockResolvedValue();

// act
const response = await subject.processSigninResponse("http://app/cb?state=1", extraHeaders);

// assert
expect(validateSigninResponseMock).toHaveBeenCalledWith(response, item, extraHeaders);
});
});

Expand Down Expand Up @@ -514,6 +539,41 @@ describe("OidcClient", () => {
// assert
.rejects.toThrow("sub in id_token does not match current sub");
});

it("should pass extraHeaders to tokenClient.exchangeRefreshToken if supplied", async () => {
// arrange
const tokenResponse = {
access_token: "new_access_token",
};
const exchangeRefreshTokenMock =
jest.spyOn(subject["_tokenClient"], "exchangeRefreshToken")
.mockResolvedValue(tokenResponse);
jest.spyOn(JwtUtils, "decode").mockReturnValue({ sub: "sub" });
const state = new RefreshState({
refresh_token: "refresh_token",
id_token: "id_token",
session_state: "session_state",
scope: "openid",
profile: {} as UserProfile,
});
const extraHeaders: Record<string, ExtraHeader> = { "foo": "bar" };

// act
const response = await subject.useRefreshToken({ state, resource: "resource", extraHeaders: extraHeaders });

// assert
expect(exchangeRefreshTokenMock).toHaveBeenCalledWith( {
refresh_token: "refresh_token",
scope: "openid",
timeoutInSeconds: undefined,
resource: "resource",
extraHeaders: extraHeaders,
});
expect(response).toBeInstanceOf(SigninResponse);
expect(response).toMatchObject(tokenResponse);
expect(response).toHaveProperty("session_state", state.session_state);
expect(response).toHaveProperty("scope", state.scope);
});
});

describe("createSignoutRequest", () => {
Expand Down
10 changes: 7 additions & 3 deletions src/OidcClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import { Logger, UrlUtils } from "./utils";
import { ErrorResponse } from "./errors";
import { type OidcClientSettings, OidcClientSettingsStore } from "./OidcClientSettings";
import { type ExtraHeader, type OidcClientSettings, OidcClientSettingsStore } from "./OidcClientSettings";
import { ResponseValidator } from "./ResponseValidator";
import { MetadataService } from "./MetadataService";
import type { RefreshState } from "./RefreshState";
Expand Down Expand Up @@ -39,6 +39,8 @@ export interface UseRefreshTokenArgs {
timeoutInSeconds?: number;

state: RefreshState;

extraHeaders?: Record<string, ExtraHeader>;
}

/**
Expand Down Expand Up @@ -164,12 +166,12 @@ export class OidcClient {
return { state, response };
}

public async processSigninResponse(url: string): Promise<SigninResponse> {
public async processSigninResponse(url: string, extraHeaders?: Record<string, ExtraHeader>): Promise<SigninResponse> {
const logger = this._logger.create("processSigninResponse");

const { state, response } = await this.readSigninResponseState(url, true);
logger.debug("received state from storage; validating response");
await this._validator.validateSigninResponse(response, state);
await this._validator.validateSigninResponse(response, state, extraHeaders);
return response;
}

Expand All @@ -191,6 +193,7 @@ export class OidcClient {
redirect_uri,
resource,
timeoutInSeconds,
extraHeaders,
extraTokenParams,
}: UseRefreshTokenArgs): Promise<SigninResponse> {
const logger = this._logger.create("useRefreshToken");
Expand All @@ -215,6 +218,7 @@ export class OidcClient {
redirect_uri,
resource,
timeoutInSeconds,
extraHeaders,
...extraTokenParams,
});
const response = new SigninResponse(new URLSearchParams());
Expand Down
9 changes: 5 additions & 4 deletions src/ResponseValidator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { ErrorResponse } from "./errors";
import type { MetadataService } from "./MetadataService";
import { UserInfoService } from "./UserInfoService";
import { TokenClient } from "./TokenClient";
import type { OidcClientSettingsStore } from "./OidcClientSettings";
import type { ExtraHeader, OidcClientSettingsStore } from "./OidcClientSettings";
import type { SigninState } from "./SigninState";
import type { SigninResponse } from "./SigninResponse";
import type { State } from "./State";
Expand All @@ -30,13 +30,13 @@ export class ResponseValidator {
protected readonly _claimsService: ClaimsService,
) {}

public async validateSigninResponse(response: SigninResponse, state: SigninState): Promise<void> {
public async validateSigninResponse(response: SigninResponse, state: SigninState, extraHeaders?: Record<string, ExtraHeader>): Promise<void> {
const logger = this._logger.create("validateSigninResponse");

this._processSigninState(response, state);
logger.debug("state processed");

await this._processCode(response, state);
await this._processCode(response, state, extraHeaders);
logger.debug("code processed");

if (response.isOpenId) {
Expand Down Expand Up @@ -169,7 +169,7 @@ export class ResponseValidator {
logger.debug("user info claims received, updated profile:", response.profile);
}

protected async _processCode(response: SigninResponse, state: SigninState): Promise<void> {
protected async _processCode(response: SigninResponse, state: SigninState, extraHeaders?: Record<string, ExtraHeader>): Promise<void> {
const logger = this._logger.create("_processCode");
if (response.code) {
logger.debug("Validating code");
Expand All @@ -179,6 +179,7 @@ export class ResponseValidator {
code: response.code,
redirect_uri: state.redirect_uri,
code_verifier: state.code_verifier,
extraHeaders: extraHeaders,
...state.extraTokenParams,
});
Object.assign(response, tokenResponse);
Expand Down
47 changes: 46 additions & 1 deletion src/TokenClient.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { CryptoUtils } from "./utils";
import { TokenClient } from "./TokenClient";
import { MetadataService } from "./MetadataService";
import { type OidcClientSettings, OidcClientSettingsStore } from "./OidcClientSettings";
import { type ExtraHeader, type OidcClientSettings, OidcClientSettingsStore } from "./OidcClientSettings";

describe("TokenClient", () => {
let settings: OidcClientSettings;
Expand Down Expand Up @@ -142,6 +142,28 @@ describe("TokenClient", () => {
}),
);
});

it("should call postForm with extraHeaders if extraHeaders are supplied", async () => {
// arrange
const getTokenEndpointMock = jest.spyOn(subject["_metadataService"], "getTokenEndpoint")
.mockResolvedValue("http://sts/token_endpoint");
const postFormMock = jest.spyOn(subject["_jsonService"], "postForm")
.mockResolvedValue({});
const extraHeaders: Record<string, ExtraHeader> = { "foo": "bar" };
// act
await subject.exchangeCode({ code: "code", code_verifier: "code_verifier", extraHeaders: extraHeaders });

// assert
expect(getTokenEndpointMock).toHaveBeenCalledWith(false);
expect(postFormMock).toHaveBeenCalledWith(
"http://sts/token_endpoint",
expect.objectContaining({
body: expect.any(URLSearchParams),
basicAuth: undefined,
extraHeaders: extraHeaders,
}),
);
});
});

describe("exchangeCredentials", () => {
Expand Down Expand Up @@ -360,6 +382,29 @@ describe("TokenClient", () => {
}),
);
});

it("should call postForm with extraHeaders if extraHeaders are supplied", async () => {
// arrange
const getTokenEndpointMock = jest.spyOn(subject["_metadataService"], "getTokenEndpoint")
.mockResolvedValue("http://sts/token_endpoint");
const postFormMock = jest.spyOn(subject["_jsonService"], "postForm")
.mockResolvedValue({});
const extraHeaders: Record<string, ExtraHeader> = { "foo": "bar" };
// act
await subject.exchangeRefreshToken({ refresh_token: "refresh_token", extraHeaders: extraHeaders });

// assert
expect(getTokenEndpointMock).toHaveBeenCalledWith(false);
expect(postFormMock).toHaveBeenCalledWith(
"http://sts/token_endpoint",
expect.objectContaining({
body: expect.any(URLSearchParams),
basicAuth: undefined,
timeoutInSeconds: undefined,
extraHeaders: extraHeaders,
}),
);
});
});

describe("revoke", () => {
Expand Down
Loading
Loading