From 616c53736b5ba767048b8e729c06c84296ab55f7 Mon Sep 17 00:00:00 2001 From: Paul Larsen Date: Sun, 3 Nov 2024 19:43:16 +0000 Subject: [PATCH] Extend ext.Context to store bot information (#198) * Add the bot's userinfo to the context struct to ensure we have all the necessary information to determine update ownership at runtime * Use botID instead of full bot info * Improve overall data --- bot.go | 27 ++++++++++++- ext/context.go | 8 +++- ext/dispatcher.go | 2 +- ext/dispatcher_ext_test.go | 2 +- ext/handlers/common_test.go | 14 +++---- ext/handlers/conversation/key_strategies.go | 7 ++-- ext/handlers/conversation_test.go | 45 +++++++++++++-------- 7 files changed, 71 insertions(+), 34 deletions(-) diff --git a/bot.go b/bot.go index e029ed24..7f6ab2fd 100644 --- a/bot.go +++ b/bot.go @@ -6,11 +6,18 @@ import ( "errors" "fmt" "net/http" + "strconv" + "strings" "time" ) //go:generate go run ./scripts/generate +var ( + ErrNilBotClient = errors.New("nil BotClient") + ErrInvalidTokenFormat = errors.New("invalid token format") +) + // Bot is the default Bot struct used to send and receive messages to the telegram API. type Bot struct { // Token stores the bot's secret token obtained from t.me/BotFather, and used to interact with telegram's API. @@ -76,6 +83,24 @@ func NewBot(token string, opts *BotOpts) (*Bot, error) { return nil, fmt.Errorf("failed to check bot token: %w", err) } b.User = *botUser + } else { + // If token checks are disabled, we populate the bot's ID from the token. + split := strings.Split(token, ":") + if len(split) != 2 { + return nil, fmt.Errorf("%w: expected '123:abcd', got %s", ErrInvalidTokenFormat, token) + } + + id, err := strconv.ParseInt(split[0], 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse bot ID from token: %w", err) + } + b.User = User{ + Id: id, + IsBot: true, + // We mark these fields as missing so we can know why they're not available + FirstName: "", + Username: "", + } } return &b, nil @@ -89,8 +114,6 @@ func (bot *Bot) UseMiddleware(mw func(client BotClient) BotClient) *Bot { return bot } -var ErrNilBotClient = errors.New("nil BotClient") - func (bot *Bot) Request(method string, params map[string]string, data map[string]FileReader, opts *RequestOpts) (json.RawMessage, error) { return bot.RequestWithContext(context.Background(), method, params, data, opts) } diff --git a/ext/context.go b/ext/context.go index cda02969..2a55152a 100644 --- a/ext/context.go +++ b/ext/context.go @@ -10,6 +10,9 @@ import ( type Context struct { // gotgbot.Update is inlined so that we can access all fields immediately if necessary. *gotgbot.Update + // Bot represents gotgbot.User behind the Bot that received this update, so we can keep track of update ownership. + // Note: this information may be incomplete in the case where token validation is disabled. + Bot gotgbot.User // Data represents update-local storage. // This can be used to pass data across handlers - for example, to cache operations relevant to the current update, // such as admin checks. @@ -35,9 +38,9 @@ type Context struct { EffectiveSender *gotgbot.Sender } -// NewContext populates a context with the relevant fields from the current update. +// NewContext populates a context with the relevant fields from the current bot and update. // It takes a data field in the case where custom data needs to be passed. -func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context { +func NewContext(b *gotgbot.Bot, update *gotgbot.Update, data map[string]interface{}) *Context { var msg *gotgbot.Message var chat *gotgbot.Chat var user *gotgbot.User @@ -162,6 +165,7 @@ func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context { return &Context{ Update: update, + Bot: b.User, Data: data, EffectiveMessage: msg, EffectiveChat: chat, diff --git a/ext/dispatcher.go b/ext/dispatcher.go index b7a063e0..5074ad38 100644 --- a/ext/dispatcher.go +++ b/ext/dispatcher.go @@ -268,7 +268,7 @@ func (d *Dispatcher) processRawUpdate(b *gotgbot.Bot, r json.RawMessage) error { // ProcessUpdate iterates over the list of groups to execute the matching handlers. // This is also where we recover from any panics that are thrown by user code, to avoid taking down the bot. func (d *Dispatcher) ProcessUpdate(b *gotgbot.Bot, u *gotgbot.Update, data map[string]interface{}) (err error) { - ctx := NewContext(u, data) + ctx := NewContext(b, u, data) defer func() { if r := recover(); r != nil { diff --git a/ext/dispatcher_ext_test.go b/ext/dispatcher_ext_test.go index 6a122e8f..0f5e85f4 100644 --- a/ext/dispatcher_ext_test.go +++ b/ext/dispatcher_ext_test.go @@ -100,7 +100,7 @@ func TestDispatcher(t *testing.T) { } t.Log("Processing one update...") - err := d.ProcessUpdate(nil, &gotgbot.Update{ + err := d.ProcessUpdate(&gotgbot.Bot{}, &gotgbot.Update{ Message: &gotgbot.Message{Text: "test text"}, }, nil) if err != nil { diff --git a/ext/handlers/common_test.go b/ext/handlers/common_test.go index b876b1f1..1cd2a1ef 100644 --- a/ext/handlers/common_test.go +++ b/ext/handlers/common_test.go @@ -17,7 +17,7 @@ func NewTestBot() *gotgbot.Bot { return &gotgbot.Bot{ Token: "use-me", User: gotgbot.User{ - Id: 0, + Id: rand.Int63(), IsBot: false, FirstName: "gobot", LastName: "", @@ -33,13 +33,13 @@ func NewTestBot() *gotgbot.Bot { } } -func NewMessage(userId int64, chatId int64, message string) *ext.Context { - return newMessage(userId, chatId, message, nil) +func NewMessage(b *gotgbot.Bot, userId int64, chatId int64, message string) *ext.Context { + return newMessage(b, userId, chatId, message, nil) } -func NewCommandMessage(userId int64, chatId int64, command string, args []string) *ext.Context { +func NewCommandMessage(b *gotgbot.Bot, userId int64, chatId int64, command string, args []string) *ext.Context { msg, ents := buildCommand(command, args) - return newMessage(userId, chatId, msg, ents) + return newMessage(b, userId, chatId, msg, ents) } func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) { @@ -53,13 +53,13 @@ func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) { } } -func newMessage(userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context { +func newMessage(b *gotgbot.Bot, userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context { chatType := "supergroup" if userId == chatId { chatType = "private" } - return ext.NewContext(&gotgbot.Update{ + return ext.NewContext(b, &gotgbot.Update{ UpdateId: rand.Int63(), // should this be consistent? Message: &gotgbot.Message{ MessageId: rand.Int63(), // should this be consistent? diff --git a/ext/handlers/conversation/key_strategies.go b/ext/handlers/conversation/key_strategies.go index 472c63a1..b8f9b9b6 100644 --- a/ext/handlers/conversation/key_strategies.go +++ b/ext/handlers/conversation/key_strategies.go @@ -3,7 +3,6 @@ package conversation import ( "errors" "fmt" - "strconv" "github.com/PaulSonOfLars/gotgbot/v2/ext" ) @@ -27,7 +26,7 @@ func KeyStrategySenderAndChat(ctx *ext.Context) (string, error) { if ctx.EffectiveSender == nil || ctx.EffectiveChat == nil { return "", fmt.Errorf("missing sender or chat fields: %w", ErrEmptyKey) } - return fmt.Sprintf("%d/%d", ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil + return fmt.Sprintf("%d/%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil } // KeyStrategySender gives a unique conversation to each sender, and that single conversation is available in all chats. @@ -35,7 +34,7 @@ func KeyStrategySender(ctx *ext.Context) (string, error) { if ctx.EffectiveSender == nil { return "", fmt.Errorf("missing sender field: %w", ErrEmptyKey) } - return strconv.FormatInt(ctx.EffectiveSender.Id(), 10), nil + return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id()), nil } // KeyStrategyChat gives a unique conversation to each chat, which all senders can interact in together. @@ -43,7 +42,7 @@ func KeyStrategyChat(ctx *ext.Context) (string, error) { if ctx.EffectiveChat == nil { return "", fmt.Errorf("missing chat field: %w", ErrEmptyKey) } - return strconv.FormatInt(ctx.EffectiveChat.Id, 10), nil + return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveChat.Id), nil } // StateKey provides a sane default for handling incoming updates. diff --git a/ext/handlers/conversation_test.go b/ext/handlers/conversation_test.go index d77881a6..697ac674 100644 --- a/ext/handlers/conversation_test.go +++ b/ext/handlers/conversation_test.go @@ -37,14 +37,14 @@ func TestBasicConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - startCommand := NewCommandMessage(userId, chatId, "start", []string{}) + startCommand := NewCommandMessage(b, userId, chatId, "start", []string{}) runHandler(t, b, &conv, startCommand, "", nextStep) if !started { t.Fatalf("expected the entrypoint handler to have run") } // Emulate sending the "message" text, triggering the internal handler (and causing it to "end"). - textMessage := NewMessage(userId, chatId, "message") + textMessage := NewMessage(b, userId, chatId, "message") runHandler(t, b, &conv, textMessage, nextStep, "") if !ended { t.Fatalf("expected the internal handler to have run") @@ -79,8 +79,8 @@ func TestBasicKeyedConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - startFromUserOne := NewCommandMessage(userIdOne, chatId, "start", []string{}) - messageFromTwo := NewMessage(userIdTwo, chatId, "message") + startFromUserOne := NewCommandMessage(b, userIdOne, chatId, "start", []string{}) + messageFromTwo := NewMessage(b, userIdTwo, chatId, "message") runHandler(t, b, &conv, startFromUserOne, "", nextStep) @@ -89,6 +89,11 @@ func TestBasicKeyedConversation(t *testing.T) { // But user two doesnt exist checkExpectedState(t, &conv, messageFromTwo, "") + + b2 := NewTestBot() + messageTo2 := NewMessage(b2, userIdOne, chatId, "message") + // And bot two hasn't changed either + checkExpectedState(t, &conv, messageTo2, "") } func TestBasicConversationExit(t *testing.T) { @@ -121,14 +126,14 @@ func TestBasicConversationExit(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint, and starting the conversation. - startCommand := NewCommandMessage(userId, chatId, "start", []string{}) + startCommand := NewCommandMessage(b, userId, chatId, "start", []string{}) runHandler(t, b, &conv, startCommand, "", nextStep) if !started { t.Fatalf("expected the entrypoint handler to have run") } // Emulate sending the "cancel" command, triggering the exitpoint, and immediately ending the conversation. - cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{}) + cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{}) runHandler(t, b, &conv, cancelCommand, nextStep, "") if !ended { t.Fatalf("expected the cancel command to have run") @@ -138,7 +143,7 @@ func TestBasicConversationExit(t *testing.T) { checkExpectedState(t, &conv, cancelCommand, "") // Emulate sending the "message" text, which now should not interact with the conversation. - textMessage := NewMessage(userId, chatId, "message") + textMessage := NewMessage(b, userId, chatId, "message") if conv.CheckUpdate(b, textMessage) { t.Fatalf("did not expect the internal handler to run") } @@ -177,14 +182,14 @@ func TestFallbackConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - startCommand := NewCommandMessage(userId, chatId, "start", []string{}) + startCommand := NewCommandMessage(b, userId, chatId, "start", []string{}) runHandler(t, b, &conv, startCommand, "", nextStep) if !started { t.Fatalf("expected the entrypoint handler to have run") } // Emulate sending the "cancel" command, triggering the fallback handler (and causing it to "end"). - cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{}) + cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{}) runHandler(t, b, &conv, cancelCommand, nextStep, "") if !fallback { t.Fatalf("expected the fallback handler to have run") @@ -220,14 +225,14 @@ func TestReEntryConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - startCommand := NewCommandMessage(userId, chatId, "start", []string{}) + startCommand := NewCommandMessage(b, userId, chatId, "start", []string{}) runHandler(t, b, &conv, startCommand, "", nextStep) if startCount != 1 { t.Fatalf("expected the entrypoint handler to have run") } // Send a message which matches both the entrypoint, and the "nextStep" state. - cancelCommand := NewCommandMessage(userId, chatId, "start", []string{"message"}) + cancelCommand := NewCommandMessage(b, userId, chatId, "start", []string{"message"}) runHandler(t, b, &conv, cancelCommand, nextStep, nextStep) // Should hit if startCount != 2 { t.Fatalf("expected the entrypoint handler to have run a second time") @@ -285,20 +290,20 @@ func TestNestedConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - start := NewCommandMessage(userId, chatId, startCmd, []string{}) + start := NewCommandMessage(b, userId, chatId, startCmd, []string{}) runHandler(t, b, &conv, start, "", firstStep) // Emulate sending the "message" text, triggering the internal handler (and causing it to "end"). - textMessage := NewMessage(userId, chatId, messageText) + textMessage := NewMessage(b, userId, chatId, messageText) runHandler(t, b, &conv, textMessage, firstStep, secondStep) // Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation. - nestedStart := NewCommandMessage(userId, chatId, nestedStartCmd, []string{}) + nestedStart := NewCommandMessage(b, userId, chatId, nestedStartCmd, []string{}) willRunHandler(t, b, &nestedConv, nestedStart, "") runHandler(t, b, &conv, nestedStart, secondStep, secondStep) // Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation. - nestedFinish := NewMessage(userId, chatId, finishNestedText) + nestedFinish := NewMessage(b, userId, chatId, finishNestedText) willRunHandler(t, b, &nestedConv, nestedFinish, nestedStep) runHandler(t, b, &conv, nestedFinish, secondStep, thirdStep) @@ -307,7 +312,7 @@ func TestNestedConversation(t *testing.T) { t.Log("Nested conversation finished") // Emulate sending the "message" text, triggering the internal handler (and causing it to "end"). - finish := NewMessage(userId, chatId, finishText) + finish := NewMessage(b, userId, chatId, finishText) runHandler(t, b, &conv, finish, thirdStep, "") checkExpectedState(t, &conv, textMessage, "") @@ -329,7 +334,7 @@ func TestEmptyKeyConversation(t *testing.T) { ) // Run an empty - pollUpd := ext.NewContext(&gotgbot.Update{ + pollUpd := ext.NewContext(b, &gotgbot.Update{ UpdateId: rand.Int63(), // should this be consistent? Poll: &gotgbot.Poll{ Id: "some_id", @@ -358,6 +363,8 @@ func TestEmptyKeyConversation(t *testing.T) { // runHandler ensures that the incoming update will trigger the conversation. func runHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, message *ext.Context, currentState string, nextState string) { + t.Helper() + willRunHandler(t, b, conv, message, currentState) if err := conv.HandleUpdate(b, message); err != nil { t.Fatalf("unexpected error from handler: %s", err.Error()) @@ -368,6 +375,8 @@ func runHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, messa // willRunHandler ensures that the incoming update will trigger the conversation. func willRunHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, message *ext.Context, expectedState string) { + t.Helper() + t.Logf("conv %p: checking message for %d in %d with text: %s", conv, message.EffectiveSender.Id(), message.EffectiveChat.Id, message.Message.Text) checkExpectedState(t, conv, message, expectedState) @@ -378,6 +387,8 @@ func willRunHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, m } func checkExpectedState(t *testing.T, conv *handlers.Conversation, message *ext.Context, nextState string) { + t.Helper() + currentState, err := conv.StateStorage.Get(message) if err != nil { if nextState == "" && errors.Is(err, conversation.ErrKeyNotFound) {