Skip to content

Commit

Permalink
add optional support for PKCE
Browse files Browse the repository at this point in the history
  • Loading branch information
bodewig committed Mar 27, 2020
1 parent 7d3cbe5 commit 650ad2a
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 2 deletions.
3 changes: 3 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
03/27/2020
- added optional support for RFC 7636 "Proof Key for Code Exchange"

02/06/2020
- ability to disable keepalive from lua-resty-http
By disabling keepalive we disable the native connection pool,
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ h2JHukolz9xf6qN61QMLSd83+kwoBr2drp6xg3eGDLIkQCQLrkY=
-- return req
-- end,
-- use_pkce = false,
-- when set to true the "Proof Key for Code Exchange" as
-- defined in RFC 7636 will be used. The code challenge
-- method will alwas be S256
}
-- call authenticate for OpenID Connect user authentication
Expand Down
25 changes: 23 additions & 2 deletions lib/resty/openidc.lua
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,12 @@ local function openidc_base64_url_decode(input)
return unb64(input)
end

-- perform base64url encoding
local function openidc_base64_url_encode(input)
local output = b64(input, true)
return output:gsub('+', '-'):gsub('/', '_')
end

local function openidc_combine_uri(uri, params)
if params == nil or next(params) == nil then
return uri
Expand All @@ -310,6 +316,12 @@ local function decorate_request(http_request_decorator, req)
return http_request_decorator and http_request_decorator(req) or req
end

local function openidc_s256(verifier)
local sha256 = (require 'resty.sha256'):new()
sha256:update(verifier)
return openidc_base64_url_encode(sha256:final())
end

-- send the browser of to the OP's authorization endpoint
local function openidc_authorize(opts, session, target_url, prompt)
local resty_random = require("resty.random")
Expand All @@ -319,6 +331,7 @@ local function openidc_authorize(opts, session, target_url, prompt)
local state = resty_string.to_hex(resty_random.bytes(16))
local nonce = (opts.use_nonce == nil or opts.use_nonce)
and resty_string.to_hex(resty_random.bytes(16))
local code_verifier = opts.use_pkce and openidc_base64_url_encode(resty_random.bytes(32))

-- assemble the parameters to the authentication request
local params = {
Expand All @@ -341,6 +354,11 @@ local function openidc_authorize(opts, session, target_url, prompt)
params.display = opts.display
end

if code_verifier then
params.code_challenge_method = 'S256'
params.code_challenge = openidc_s256(code_verifier)
end

-- merge any provided extra parameters
if opts.authorization_params then
for k, v in pairs(opts.authorization_params) do params[k] = v end
Expand All @@ -350,6 +368,7 @@ local function openidc_authorize(opts, session, target_url, prompt)
session.data.original_url = target_url
session.data.state = state
session.data.nonce = nonce
session.data.code_verifier = code_verifier
session.data.last_authenticated = ngx.time()

if opts.lifecycle and opts.lifecycle.on_created then
Expand Down Expand Up @@ -1093,7 +1112,8 @@ local function openidc_authorization_response(opts, session)
grant_type = "authorization_code",
code = args.code,
redirect_uri = openidc_get_redirect_uri(opts),
state = session.data.state
state = session.data.state,
code_verifier = session.data.code_verifier
}

log(DEBUG, "Authentication with OP done -> Calling OP Token Endpoint to obtain tokens")
Expand All @@ -1113,9 +1133,10 @@ local function openidc_authorization_response(opts, session)

-- mark this sessions as authenticated
session.data.authenticated = true
-- clear state and nonce to protect against potential misuse
-- clear state, nonce and code_verifier to protect against potential misuse
session.data.nonce = nil
session.data.state = nil
session.data.code_verifier = nil
if store_in_session(opts, 'id_token') then
session.data.id_token = id_token
end
Expand Down
90 changes: 90 additions & 0 deletions tests/spec/pkce_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
local http = require('socket.http')
local test_support = require('test_support')
require 'busted.runner'()

describe('when pkce is disabled in authorize request', function()
test_support.start_server()
teardown(test_support.stop_server)

local _, status, headers = http.request({
url = 'http://127.0.0.1/default/t',
redirect = false
})
it('there is no code_challenge parameter', function()
assert.falsy(string.match(headers['location'], '.*code_challenge=.*'))
end)
it('there is no code_challenge_method parameter', function()
assert.falsy(string.match(headers['location'], '.*code_challenge_method=.*'))
end)
end)

describe('when pkce is enabled in authorize request', function()
test_support.start_server({oidc_opts = { use_pkce = true } })
teardown(test_support.stop_server)

local _, status, headers = http.request({
url = 'http://127.0.0.1/default/t',
redirect = false
})
it('there is a code_challenge parameter', function()
assert.truthy(string.match(headers['location'], '.*code_challenge=.*'))
end)
it('there is a S256 code_challenge_method parameter', function()
assert.truthy(string.match(headers['location'], '.*code_challenge_method=S256.*'))
end)
end)

local function assert_token_endpoint_call_contains(s, case_insensitive)
assert.error_log_contains("Received token request: .*" .. s .. ".*",
case_insensitive)
end

local function assert_token_endpoint_call_doesnt_contain(s, case_insensitive)
assert.is_not.error_log_contains("Received token request: .*" .. s .. ".*",
case_insensitive)
end

describe('when pkce is disabled and the token endpoint is invoked', function()
test_support.start_server()
teardown(test_support.stop_server)
test_support.login()
it("the request doesn't contain a code_verifier", function()
assert_token_endpoint_call_doesnt_contain('.*code_verifier=.*')
end)
end)

describe('when pkce is disabled and the token endpoint is invoked', function()
test_support.start_server({oidc_opts = { use_pkce = true } })
teardown(test_support.stop_server)

local _, _, headers = http.request({
url = "http://127.0.0.1/default/t",
redirect = false
})
local state = test_support.grab(headers, 'state')
local code_challenge = test_support.grab(headers, 'code_challenge')
test_support.register_nonce(headers)
http.request({
url = "http://127.0.0.1/default/redirect_uri?code=foo&state=" .. state,
headers = { cookie = test_support.extract_cookies(headers) },
redirect = false
})

local log = test_support.load("/tmp/server/logs/error.log")
local code_verifier = log:match('Received token request: .*code_verifier=([^&]-)[&,]')
it('the request contains a code_verifier', function()
assert.truthy(code_verifier)
end)
it('hashing the code verifier leads to the challenge', function()
local as_base64 = function(s)
local rem = #s % 4
if rem > 0 then
s = s .. string.rep('=', 4 - rem)
end
return s:gsub('-', '+'):gsub('_', '/')
end
local challenge = as_base64(code_challenge)
local hashed_verifier = (require 'mime').b64((require 'sha2').bytes(code_verifier))
assert.are.equals(hashed_verifier, challenge)
end)
end)
133 changes: 133 additions & 0 deletions tests/spec/sha2.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
-- Copy of https://github.com/catwell/cw-lua/blob/master/sha256/sha256.lua

--- SHA-256 implementation by Pierre 'catwell' Chapuis
--- MIT licensed (see LICENSE.txt)
--- Only works on little endian platforms.

local ffi = require "ffi"
assert(ffi.abi("le"))

local bit = require "bit"
local band, bxor, bnot = bit.band, bit.bxor, bit.bnot
local rshift, rrot = bit.rshift, bit.ror

local ui32_8 = ffi.typeof("uint32_t[8]")
local uchar_32 = ffi.typeof("unsigned char[32]")
local uchar_256 = ffi.typeof("unsigned char[256]")
local uchar_vla = ffi.typeof("unsigned char[?]")

local H = ui32_8(
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19
)

local K = ffi.new("uint32_t[64]",
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
)

local tohex = function(buf, n)
local x = ffi.cast("unsigned char*", buf)
local t = {}
for i=0,n-1 do t[i+1] = string.format("%02x", x[i]) end
return table.concat(t)
end

local write_i64 = function(buf, n)
local q, r = math.floor(n/2^32), n%(2^32)
local j = 0
for i=24,0,-8 do
buf[j] = band(rshift(q, i), 0xff)
j = j + 1
end
for i=24,0,-8 do
buf[j] = band(rshift(r, i), 0xff)
j = j + 1
end
end

local pad = function(msg) -- lua string -> buffer
local l0 = #msg
local n0 = ((56 - (l0 + 1)) % 64)
local l1 = l0 + 1 + n0 + 8
assert(l1%64 == 0, l1)
local nchunks = l1/64
local r = uchar_vla(l1)
ffi.copy(r, msg, l0)
r[l0] = 0x80
write_i64(r+l1-8, 8*l0)
return r, nchunks
end

local rcopy_32 = function(to, from, n) -- copy n 4B chunks, reversing them
local x = ffi.cast("unsigned char*", to)
local y = ffi.cast("unsigned char*", from)
for i=0,n-1 do
for j=0,3 do
x[4*i+j] = y[4*i+3-j]
end
end
end

local transform = function(chunk, res) -- works on a 64B chunk
local a, b, c, d = res[0], res[1], res[2], res[3]
local e, f, g, h = res[4], res[5], res[6], res[7]
local buf = uchar_256()
rcopy_32(buf, chunk,16)
local w = ffi.cast("uint32_t *", buf)
local t1, t2, s0, s1
for i=16,64-1 do
t1, t2 = w[i-15], w[i-2]
s0 = bxor( rrot(t1, 07), rrot(t1, 18), rshift(t1, 03) )
s1 = bxor( rrot(t2, 17), rrot(t2, 19), rshift(t2, 10) )
w[i] = w[i-16] + w[i-7] + s0 + s1 + 0LL -- 32 bit LuaJIT, see issue 1
end
for i=0,64-1 do
s0 = bxor( rrot(a, 02), rrot(a, 13), rrot(a, 22) )
s1 = bxor( rrot(e, 06), rrot(e, 11), rrot(e, 25) )
t1 = s1 + h + w[i] + K[i] + bxor( band(e, f), band(bnot(e), g) )
t2 = s0 + bxor( band(a ,b), band(a, c), band(b, c) )
a, b, c, d, e, f, g, h = t1+t2, a, b, c, d+t1, e, f, g
end
res[0], res[1], res[2], res[3] = res[0]+a, res[1]+b, res[2]+c, res[3]+d
res[4], res[5], res[6], res[7] = res[4]+e, res[5]+f, res[6]+g, res[7]+h
end

local sha256_calc = function(input)
local chunks, nchunks = pad(input)
local buf = ui32_8()
ffi.copy(buf, H, ffi.sizeof(ui32_8))
for i=0,nchunks-1 do transform(chunks+64*i, buf) end
local res = uchar_32()
rcopy_32(res, buf, 8)
return res
end

local sha256_bytes = function(input)
local r = sha256_calc(input)
return ffi.string(r, 32)
end

local sha256_hex = function(input)
local r = sha256_calc(input)
return tohex(r, 32)
end

return {
bytes = sha256_bytes,
hex = sha256_hex,
}

0 comments on commit 650ad2a

Please sign in to comment.