-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make the Updater's multibot state thread safe (#98)
* Add a proper threadsafe bot mapping solution so that we can add+remove bots on the fly * Make sure the updater is in charge of stopping bots, not botmappings * Make bot stopping a method on botdata rather than a function * Make sure to return errors if trying to add a bot twice * Bots should be removable by token rather than bot object * Add benchmarks and tests for the updater and dispatcher * Add missing if statement on error handling * Fix bad test * Name tests
- Loading branch information
1 parent
5758ff1
commit eab7bce
Showing
7 changed files
with
364 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
package ext | ||
|
||
import ( | ||
"encoding/json" | ||
"errors" | ||
"sync" | ||
|
||
"github.com/PaulSonOfLars/gotgbot/v2" | ||
) | ||
|
||
// botData Keeps track of the necessary update channels for each gotgbot.Bot. | ||
type botData struct { | ||
// bot represents the bot for which this data is relevant. | ||
bot *gotgbot.Bot | ||
// updateChan represents the incoming updates channel. | ||
updateChan chan json.RawMessage | ||
// polling allows us to close the polling loop. | ||
polling chan struct{} | ||
// urlPath defines the incoming webhook URL path for this bot. | ||
urlPath string | ||
} | ||
|
||
// botMapping Ensures that all botData is stored in a thread-safe manner. | ||
type botMapping struct { | ||
// mapping keeps track of the data required for each bot. The key is the bot token. | ||
mapping map[string]botData | ||
// mux attempts to keep the botMapping data concurrency-safe. | ||
mux sync.RWMutex | ||
} | ||
|
||
var ErrBotAlreadyExists = errors.New("bot already exists in bot mapping") | ||
|
||
// addBot Adds a new bot to the botMapping structure. | ||
func (m *botMapping) addBot(b *gotgbot.Bot, updateChan chan json.RawMessage, pollChan chan struct{}, urlPath string) error { | ||
m.mux.Lock() | ||
defer m.mux.Unlock() | ||
|
||
if m.mapping == nil { | ||
m.mapping = make(map[string]botData) | ||
} | ||
|
||
if _, ok := m.mapping[b.Token]; ok { | ||
return ErrBotAlreadyExists | ||
} | ||
|
||
m.mapping[b.Token] = botData{ | ||
bot: b, | ||
updateChan: updateChan, | ||
polling: pollChan, | ||
urlPath: urlPath, | ||
} | ||
return nil | ||
} | ||
|
||
func (m *botMapping) removeBot(token string) (botData, bool) { | ||
m.mux.Lock() | ||
defer m.mux.Unlock() | ||
|
||
bData, ok := m.mapping[token] | ||
if !ok { | ||
return botData{}, false | ||
} | ||
|
||
delete(m.mapping, token) | ||
return bData, true | ||
} | ||
|
||
func (m *botMapping) removeAllBots() []botData { | ||
m.mux.Lock() | ||
defer m.mux.Unlock() | ||
|
||
bots := make([]botData, 0, len(m.mapping)) | ||
for key, bData := range m.mapping { | ||
bots = append(bots, bData) | ||
delete(m.mapping, key) | ||
} | ||
return bots | ||
} | ||
|
||
func (m *botMapping) getBots() []botData { | ||
m.mux.RLock() | ||
defer m.mux.RUnlock() | ||
|
||
bots := make([]botData, 0, len(m.mapping)) | ||
for _, bData := range m.mapping { | ||
bots = append(bots, bData) | ||
} | ||
return bots | ||
} | ||
|
||
func (m *botMapping) getBot(token string) (botData, bool) { | ||
m.mux.RLock() | ||
defer m.mux.RUnlock() | ||
|
||
bData, ok := m.mapping[token] | ||
return bData, ok | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package ext | ||
|
||
import ( | ||
"encoding/json" | ||
"testing" | ||
|
||
"github.com/PaulSonOfLars/gotgbot/v2" | ||
) | ||
|
||
func Test_botMapping(t *testing.T) { | ||
bm := botMapping{} | ||
b := &gotgbot.Bot{ | ||
User: gotgbot.User{}, | ||
Token: "SOME_TOKEN", | ||
BotClient: &gotgbot.BaseBotClient{}, | ||
} | ||
|
||
updateChan := make(chan json.RawMessage) | ||
pollChan := make(chan struct{}) | ||
|
||
t.Run("addBot", func(t *testing.T) { | ||
// check that bots can be added fine | ||
err := bm.addBot(b, updateChan, pollChan, "") | ||
if err != nil { | ||
t.Errorf("expected to be able to add a new bot fine: %s", err.Error()) | ||
t.FailNow() | ||
} | ||
if len(bm.getBots()) != 1 { | ||
t.Errorf("expected 1 bot, got %d", len(bm.getBots())) | ||
t.FailNow() | ||
} | ||
}) | ||
|
||
t.Run("doubleAdd", func(t *testing.T) { | ||
// Adding the same bot twice should fail | ||
err := bm.addBot(b, updateChan, pollChan, "") | ||
if err == nil { | ||
t.Errorf("adding the same bot twice should throw an error") | ||
t.FailNow() | ||
} | ||
if len(bm.getBots()) != 1 { | ||
t.Errorf("expected only haveing 1 bot when adding a duplicate, but got %d", len(bm.getBots())) | ||
t.FailNow() | ||
} | ||
}) | ||
|
||
t.Run("getBot", func(t *testing.T) { | ||
// check that bot data is correct | ||
bdata, ok := bm.getBot(b.Token) | ||
if !ok { | ||
t.Errorf("failed to get bot with token %s", b.Token) | ||
t.FailNow() | ||
} | ||
if bdata.polling != pollChan { | ||
t.Errorf("polling channel was not the same") | ||
t.FailNow() | ||
} | ||
if bdata.updateChan != updateChan { | ||
t.Errorf("update channel was not the same") | ||
t.FailNow() | ||
} | ||
}) | ||
|
||
t.Run("removeBot", func(t *testing.T) { | ||
// check that bot cant be removed | ||
_, ok := bm.removeBot(b.Token) | ||
if !ok { | ||
t.Errorf("failed to remove bot with token %s", b.Token) | ||
t.FailNow() | ||
} | ||
|
||
_, ok = bm.getBot(b.Token) | ||
if ok { | ||
t.Errorf("bot with token %s should be gone", b.Token) | ||
t.FailNow() | ||
} | ||
}) | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package ext | ||
|
||
import ( | ||
"github.com/PaulSonOfLars/gotgbot/v2" | ||
) | ||
|
||
type DummyHandler struct { | ||
F func(b *gotgbot.Bot, ctx *Context) error | ||
} | ||
|
||
func (d DummyHandler) CheckUpdate(b *gotgbot.Bot, ctx *Context) bool { | ||
return true | ||
} | ||
|
||
func (d DummyHandler) HandleUpdate(b *gotgbot.Bot, ctx *Context) error { | ||
return d.F(b, ctx) | ||
} | ||
|
||
func (d DummyHandler) Name() string { | ||
return "dummy" | ||
} | ||
|
||
func (u *Updater) InjectUpdate(token string, upd gotgbot.Update) error { | ||
bData, ok := u.botMapping.getBot(token) | ||
if !ok { | ||
return ErrNotFound | ||
} | ||
|
||
return u.Dispatcher.ProcessUpdate(bData.bot, &upd, nil) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.