Skip to content

Commit

Permalink
feat: add session and ticket store implementation (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
nic-chen authored Oct 25, 2022
1 parent 5c5c57b commit 36891af
Show file tree
Hide file tree
Showing 8 changed files with 666 additions and 0 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: Test

on:
push:
branches:
- main
pull_request:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
build:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- uses: actions/checkout@v2
- name: Start etcd
run: |
docker run -p 2379:2379 -d -e ALLOW_NONE_AUTHENTICATION=yes --name etcd bitnami/etcd
- name: Run Test
run: go test --cover -v
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

.idea

*.test

83 changes: 83 additions & 0 deletions etcd_session_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package casstore

import (
"context"
"fmt"

clientv3 "go.etcd.io/etcd/client/v3"
"gopkg.in/cas.v2"
)

var _ cas.SessionStore = &etcdSessionStore{}

// NewEtcdSessionStore create a session store using etcd.
func NewEtcdSessionStore(config clientv3.Config, ctx context.Context,
prefix string, maxAge int64) (cas.SessionStore, error) {
client, err := clientv3.New(config)
if err != nil {
return nil, err
}

if prefix == "" {
prefix = "/cas/sessions"
}

if maxAge == 0 {
maxAge = 86400
}

return &etcdSessionStore{
cli: client,
ctx: ctx,
prefix: prefix,
maxAge: maxAge,
}, nil
}

type etcdSessionStore struct {
cli *clientv3.Client
ctx context.Context
prefix string
maxAge int64
}

func (s *etcdSessionStore) Get(sessionID string) (string, bool) {
key := s.prefix + "/" + sessionID
resp, err := s.cli.Get(s.ctx, key)
if err != nil {
return "", false
}

if resp.Count == 0 {
return "", false
}

return string(resp.Kvs[0].Value), true
}

func (s *etcdSessionStore) Set(sessionID, ticket string) error {
key := s.prefix + "/" + sessionID

grant, err := s.cli.Grant(s.ctx, s.maxAge+1)
if err != nil {
return err
}

_, err = s.cli.Put(s.ctx, key, ticket, clientv3.WithLease(grant.ID))

return err
}

func (s *etcdSessionStore) Delete(sessionID string) error {
key := s.prefix + "/" + sessionID
resp, err := s.cli.Delete(s.ctx, key)
if err != nil {
return err
}

if resp.Deleted == 0 {
return fmt.Errorf("key: %s is not found in etcd", key)
}

return nil
}
91 changes: 91 additions & 0 deletions etcd_session_store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package casstore

import (
"context"
"testing"

"github.com/stretchr/testify/require"
clientv3 "go.etcd.io/etcd/client/v3"
"gopkg.in/cas.v2"
)

var (
store cas.SessionStore
_defaultEtcd = "http://127.0.0.1:2379"
)

func init() {
store, _ = NewEtcdSessionStore(clientv3.Config{Endpoints: []string{_defaultEtcd}},
context.Background(), "/cas/sessions", 3600)
}

func TestSessionStore_Get(t *testing.T) {

v, ok := store.Get("key1")
require.False(t, ok)
require.Equal(t, "", v)

err := store.Set("key1", "value1")
require.Nil(t, err)

v, ok = store.Get("key1")
require.True(t, ok)
require.Equal(t, "value1", v)
}

func TestSessionStore_Set(t *testing.T) {
err := store.Set("key1", "value1")
require.Nil(t, err)

err = store.Set("key2", "value2")
require.Nil(t, err)

v, ok := store.Get("key1")
require.True(t, ok)
require.Equal(t, "value1", v)

v, ok = store.Get("key2")
require.True(t, ok)
require.Equal(t, "value2", v)

err = store.Set("key2", "value2-new")
require.Nil(t, err)

v, ok = store.Get("key2")
require.True(t, ok)
require.Equal(t, "value2-new", v)
}

func TestSessionStore_Delete(t *testing.T) {
err := store.Set("key1", "value1")
require.Nil(t, err)

err = store.Set("key2", "value2")
require.Nil(t, err)

v, ok := store.Get("key1")
require.True(t, ok)
require.Equal(t, "value1", v)

v, ok = store.Get("key2")
require.True(t, ok)
require.Equal(t, "value2", v)

err = store.Delete("key2")
require.Nil(t, err)

v, ok = store.Get("key1")
require.True(t, ok)
require.Equal(t, "value1", v)

v, ok = store.Get("key2")
require.False(t, ok)
require.Equal(t, "", v)

err = store.Delete("key1")
require.Nil(t, err)

v, ok = store.Get("key1")
require.False(t, ok)
require.Equal(t, "", v)
}
107 changes: 107 additions & 0 deletions etcd_ticket_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package casstore

import (
"context"
"encoding/json"
"fmt"

clientv3 "go.etcd.io/etcd/client/v3"
"gopkg.in/cas.v2"
)

var _ cas.TicketStore = &etcdTicketStore{}

// NewEtcdTicketStore create a ticket store using etcd.
func NewEtcdTicketStore(config clientv3.Config, ctx context.Context,
prefix string, maxAge int64) (cas.TicketStore, error) {
client, err := clientv3.New(config)
if err != nil {
return nil, err
}
if prefix == "" {
prefix = "/cas/tickets"
}
if maxAge == 0 {
maxAge = 3600
}

return &etcdTicketStore{
cli: client,
ctx: ctx,
prefix: prefix,
maxAge: maxAge,
}, nil
}

// etcdTicketStore implements the TicketStore interface storing ticket data in etcd.
type etcdTicketStore struct {
cli *clientv3.Client
ctx context.Context
prefix string
maxAge int64
}

// Read returns the AuthenticationResponse for a ticket
func (s *etcdTicketStore) Read(id string) (*cas.AuthenticationResponse, error) {
key := s.prefix + "/" + id
resp, err := s.cli.Get(s.ctx, key)
if err != nil {
return nil, cas.ErrInvalidTicket
}
if resp.Count == 0 {
return nil, cas.ErrInvalidTicket
}

var rsp *cas.AuthenticationResponse
err = json.Unmarshal(resp.Kvs[0].Value, &rsp)
if err != nil {
return nil, cas.ErrInvalidTicket
}

return rsp, nil
}

// Write stores the AuthenticationResponse for a ticket
func (s *etcdTicketStore) Write(id string, ticket *cas.AuthenticationResponse) error {
key := s.prefix + "/" + id
grant, err := s.cli.Grant(s.ctx, s.maxAge+1)
if err != nil {
return err
}
data, err := json.Marshal(ticket)
if err != nil {
return err
}
_, err = s.cli.Put(s.ctx, key, string(data), clientv3.WithLease(grant.ID))
return err
}

// Delete removes the AuthenticationResponse for a ticket
func (s *etcdTicketStore) Delete(id string) error {
key := s.prefix + "/" + id
resp, err := s.cli.Delete(s.ctx, key)
if err != nil {
return err
}

if resp.Deleted == 0 {
return fmt.Errorf("key: %s is not found in etcd", key)
}

return nil
}

// Clear removes all ticket data
func (s *etcdTicketStore) Clear() error {
key := s.prefix + "/"
resp, err := s.cli.Delete(s.ctx, key, clientv3.WithPrefix())
if err != nil {
return err
}

if resp.Deleted == 0 {
return fmt.Errorf("key: %s is not found in etcd", key)
}

return nil
}
55 changes: 55 additions & 0 deletions etcd_ticket_store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package casstore

import (
"context"
"reflect"
"testing"

clientv3 "go.etcd.io/etcd/client/v3"
"gopkg.in/cas.v2"
)

func TestEtcdTicketStore(t *testing.T) {
user1 := &cas.AuthenticationResponse{User: "user1"}
user2 := &cas.AuthenticationResponse{User: "user2"}
store, err := NewEtcdTicketStore(clientv3.Config{Endpoints: []string{_defaultEtcd}},
context.Background(), "/cas/tickets", 3600)

if err := store.Write("user1", user1); err != nil {
t.Errorf("Expected store.Write(user1) to succeed, got error: %v", err)
}

if err := store.Write("user2", user2); err != nil {
t.Errorf("Expected store.Write(user2) to succeed, got error: %v", err)
}

ar, err := store.Read("user2")
if err != nil {
t.Errorf("Expected store.Read(user2) to succeed, got error: %v", err)
}

if !reflect.DeepEqual(*ar, *user2) {
t.Errorf("Expected retrieved AuthenticationResponse to be %v, got %v", *user2, *ar)
}

if err := store.Delete("user2"); err != nil {
t.Errorf("Error while deleting user2, got %v", err)
}

if _, err := store.Read("user2"); err != cas.ErrInvalidTicket {
t.Errorf("Expected store.Read(user2) to fail")
}

if err := store.Clear(); err != nil {
t.Errorf("Expected store.Clear() to succeed, got error: %v", err)
}

_, err = store.Read("user1")
if err == nil {
t.Errorf("Expected an error from store.Read(user1), got nil")
}

if err != cas.ErrInvalidTicket {
t.Errorf("Expected ErrInvalidTicket from store.Read(user1), got %v", err)
}
}
Loading

0 comments on commit 36891af

Please sign in to comment.