Skip to content

Commit

Permalink
Make the Updater's multibot state thread safe (#98)
Browse files Browse the repository at this point in the history
* Add a proper threadsafe bot mapping solution so that we can add+remove bots on the fly

* Make sure the updater is in charge of stopping bots, not botmappings

* Make bot stopping a method on botdata rather than a function

* Make sure to return errors if trying to add a bot twice

* Bots should be removable by token rather than bot object

* Add benchmarks and tests for the updater and dispatcher

* Add missing if statement on error handling

* Fix bad test

* Name tests
  • Loading branch information
PaulSonOfLars authored Sep 2, 2023
1 parent 5758ff1 commit eab7bce
Show file tree
Hide file tree
Showing 7 changed files with 364 additions and 46 deletions.
97 changes: 97 additions & 0 deletions ext/botmapping.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package ext

import (
"encoding/json"
"errors"
"sync"

"github.com/PaulSonOfLars/gotgbot/v2"
)

// botData Keeps track of the necessary update channels for each gotgbot.Bot.
type botData struct {
// bot represents the bot for which this data is relevant.
bot *gotgbot.Bot
// updateChan represents the incoming updates channel.
updateChan chan json.RawMessage
// polling allows us to close the polling loop.
polling chan struct{}
// urlPath defines the incoming webhook URL path for this bot.
urlPath string
}

// botMapping Ensures that all botData is stored in a thread-safe manner.
type botMapping struct {
// mapping keeps track of the data required for each bot. The key is the bot token.
mapping map[string]botData
// mux attempts to keep the botMapping data concurrency-safe.
mux sync.RWMutex
}

var ErrBotAlreadyExists = errors.New("bot already exists in bot mapping")

// addBot Adds a new bot to the botMapping structure.
func (m *botMapping) addBot(b *gotgbot.Bot, updateChan chan json.RawMessage, pollChan chan struct{}, urlPath string) error {
m.mux.Lock()
defer m.mux.Unlock()

if m.mapping == nil {
m.mapping = make(map[string]botData)
}

if _, ok := m.mapping[b.Token]; ok {
return ErrBotAlreadyExists
}

m.mapping[b.Token] = botData{
bot: b,
updateChan: updateChan,
polling: pollChan,
urlPath: urlPath,
}
return nil
}

func (m *botMapping) removeBot(token string) (botData, bool) {
m.mux.Lock()
defer m.mux.Unlock()

bData, ok := m.mapping[token]
if !ok {
return botData{}, false
}

delete(m.mapping, token)
return bData, true
}

func (m *botMapping) removeAllBots() []botData {
m.mux.Lock()
defer m.mux.Unlock()

bots := make([]botData, 0, len(m.mapping))
for key, bData := range m.mapping {
bots = append(bots, bData)
delete(m.mapping, key)
}
return bots
}

func (m *botMapping) getBots() []botData {
m.mux.RLock()
defer m.mux.RUnlock()

bots := make([]botData, 0, len(m.mapping))
for _, bData := range m.mapping {
bots = append(bots, bData)
}
return bots
}

func (m *botMapping) getBot(token string) (botData, bool) {
m.mux.RLock()
defer m.mux.RUnlock()

bData, ok := m.mapping[token]
return bData, ok
}
79 changes: 79 additions & 0 deletions ext/botmapping_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package ext

import (
"encoding/json"
"testing"

"github.com/PaulSonOfLars/gotgbot/v2"
)

func Test_botMapping(t *testing.T) {
bm := botMapping{}
b := &gotgbot.Bot{
User: gotgbot.User{},
Token: "SOME_TOKEN",
BotClient: &gotgbot.BaseBotClient{},
}

updateChan := make(chan json.RawMessage)
pollChan := make(chan struct{})

t.Run("addBot", func(t *testing.T) {
// check that bots can be added fine
err := bm.addBot(b, updateChan, pollChan, "")
if err != nil {
t.Errorf("expected to be able to add a new bot fine: %s", err.Error())
t.FailNow()
}
if len(bm.getBots()) != 1 {
t.Errorf("expected 1 bot, got %d", len(bm.getBots()))
t.FailNow()
}
})

t.Run("doubleAdd", func(t *testing.T) {
// Adding the same bot twice should fail
err := bm.addBot(b, updateChan, pollChan, "")
if err == nil {
t.Errorf("adding the same bot twice should throw an error")
t.FailNow()
}
if len(bm.getBots()) != 1 {
t.Errorf("expected only haveing 1 bot when adding a duplicate, but got %d", len(bm.getBots()))
t.FailNow()
}
})

t.Run("getBot", func(t *testing.T) {
// check that bot data is correct
bdata, ok := bm.getBot(b.Token)
if !ok {
t.Errorf("failed to get bot with token %s", b.Token)
t.FailNow()
}
if bdata.polling != pollChan {
t.Errorf("polling channel was not the same")
t.FailNow()
}
if bdata.updateChan != updateChan {
t.Errorf("update channel was not the same")
t.FailNow()
}
})

t.Run("removeBot", func(t *testing.T) {
// check that bot cant be removed
_, ok := bm.removeBot(b.Token)
if !ok {
t.Errorf("failed to remove bot with token %s", b.Token)
t.FailNow()
}

_, ok = bm.getBot(b.Token)
if ok {
t.Errorf("bot with token %s should be gone", b.Token)
t.FailNow()
}
})

}
30 changes: 30 additions & 0 deletions ext/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package ext

import (
"github.com/PaulSonOfLars/gotgbot/v2"
)

type DummyHandler struct {
F func(b *gotgbot.Bot, ctx *Context) error
}

func (d DummyHandler) CheckUpdate(b *gotgbot.Bot, ctx *Context) bool {
return true
}

func (d DummyHandler) HandleUpdate(b *gotgbot.Bot, ctx *Context) error {
return d.F(b, ctx)
}

func (d DummyHandler) Name() string {
return "dummy"
}

func (u *Updater) InjectUpdate(token string, upd gotgbot.Update) error {
bData, ok := u.botMapping.getBot(token)
if !ok {
return ErrNotFound
}

return u.Dispatcher.ProcessUpdate(bData.bot, &upd, nil)
}
30 changes: 30 additions & 0 deletions ext/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package ext

import (
"encoding/json"
"sync"
"testing"
"time"

"github.com/PaulSonOfLars/gotgbot/v2"
)

func TestDispatcherStop(t *testing.T) {
Expand Down Expand Up @@ -56,3 +59,30 @@ func TestUnlimitedDispatcherStop(t *testing.T) {
go d.Start(nil, make(chan json.RawMessage))
d.Stop() // ensure no panics
}

func BenchmarkDispatcher(b *testing.B) {
d := NewDispatcher(nil)

wg := sync.WaitGroup{}
d.AddHandler(DummyHandler{F: func(b *gotgbot.Bot, ctx *Context) error {
wg.Done()
return nil
}})

updateChan := make(chan json.RawMessage)

go d.Start(&gotgbot.Bot{}, updateChan)

upd, err := json.Marshal(gotgbot.Update{Message: &gotgbot.Message{Text: "test"}})
if err != nil {
b.Fatalf("failed to marshal test msg: %s", err.Error())
}

for i := 0; i < b.N; i++ {
wg.Add(1)
go func() { updateChan <- upd }()
}

wg.Wait()
d.Stop()
}
Loading

0 comments on commit eab7bce

Please sign in to comment.