Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context aware funcs for all existing api functions #144

Merged
merged 4 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion appservice/appservice_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package appservice

import (
"context"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -35,7 +36,7 @@ func TestClient_UnixSocket(t *testing.T) {
err = as.SetHomeserverURL(fmt.Sprintf("unix://%s", socket))
assert.NoError(t, err)
client := as.Client("user1")
resp, err := client.Whoami()
resp, err := client.Whoami(context.Background())
assert.NoError(t, err)
assert.Equal(t, "@joe:example.org", string(resp.UserID))
}
171 changes: 86 additions & 85 deletions appservice/intent.go

Large diffs are not rendered by default.

34 changes: 18 additions & 16 deletions bridge/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ type Crypto interface {
Decrypt(*event.Event) (*event.Event, error)
Encrypt(id.RoomID, event.Type, *event.Content) error
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
ResetSession(id.RoomID)
Init() error
Start()
Expand Down Expand Up @@ -287,9 +287,9 @@ func (br *Bridge) InitVersion(tag, commit, buildTime string) {

var MinSpecVersion = mautrix.SpecV11

func (br *Bridge) ensureConnection() {
func (br *Bridge) ensureConnection(ctx context.Context) {
for {
versions, err := br.Bot.Versions()
versions, err := br.Bot.Versions(ctx)
if err != nil {
br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...")
time.Sleep(10 * time.Second)
Expand All @@ -315,7 +315,7 @@ func (br *Bridge) ensureConnection() {
}
}

resp, err := br.Bot.Whoami()
resp, err := br.Bot.Whoami(ctx)
if err != nil {
if errors.Is(err, mautrix.MUnknownToken) {
br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?")
Expand Down Expand Up @@ -346,7 +346,7 @@ func (br *Bridge) ensureConnection() {
const maxRetries = 6
for {
txnID = br.Bot.TxnID()
pingResp, err = br.Bot.AppservicePing(br.Config.AppService.ID, txnID)
pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID)
if err == nil {
break
}
Expand Down Expand Up @@ -385,42 +385,42 @@ func (br *Bridge) ensureConnection() {
Msg("Homeserver -> bridge connection works")
}

func (br *Bridge) fetchMediaConfig() {
cfg, err := br.Bot.GetMediaConfig()
func (br *Bridge) fetchMediaConfig(ctx context.Context) {
cfg, err := br.Bot.GetMediaConfig(ctx)
if err != nil {
br.ZLog.Warn().Err(err).Msg("Failed to fetch media config")
} else {
br.MediaConfig = *cfg
}
}

func (br *Bridge) UpdateBotProfile() {
func (br *Bridge) UpdateBotProfile(ctx context.Context) {
br.ZLog.Debug().Msg("Updating bot profile")
botConfig := &br.Config.AppService.Bot

var err error
var mxc id.ContentURI
if botConfig.Avatar == "remove" {
err = br.Bot.SetAvatarURL(mxc)
err = br.Bot.SetAvatarURL(ctx, mxc)
} else if !botConfig.ParsedAvatar.IsEmpty() {
err = br.Bot.SetAvatarURL(botConfig.ParsedAvatar)
err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar)
}
if err != nil {
br.ZLog.Warn().Err(err).Msg("Failed to update bot avatar")
}

if botConfig.Displayname == "remove" {
err = br.Bot.SetDisplayName("")
err = br.Bot.SetDisplayName(ctx, "")
} else if len(botConfig.Displayname) > 0 {
err = br.Bot.SetDisplayName(botConfig.Displayname)
err = br.Bot.SetDisplayName(ctx, botConfig.Displayname)
}
if err != nil {
br.ZLog.Warn().Err(err).Msg("Failed to update bot displayname")
}

if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) && br.BeeperNetworkName != "" {
br.ZLog.Debug().Msg("Setting contact info on the appservice bot")
br.Bot.BeeperUpdateProfile(map[string]any{
br.Bot.BeeperUpdateProfile(ctx, map[string]any{
"com.beeper.bridge.service": br.BeeperServiceName,
"com.beeper.bridge.network": br.BeeperNetworkName,
"com.beeper.bridge.is_bridge_bot": true,
Expand Down Expand Up @@ -633,8 +633,10 @@ func (br *Bridge) start() {
os.Exit(23)
}
br.ZLog.Debug().Msg("Checking connection to homeserver")
br.ensureConnection()
go br.fetchMediaConfig()

ctx := context.Background()
br.ensureConnection(ctx)
go br.fetchMediaConfig(ctx)

if br.Crypto != nil {
err = br.Crypto.Init()
Expand All @@ -647,7 +649,7 @@ func (br *Bridge) start() {
br.ZLog.Debug().Msg("Starting event processor")
br.EventProcessor.Start()

go br.UpdateBotProfile()
go br.UpdateBotProfile(ctx)
if br.Crypto != nil {
go br.Crypto.Start()
}
Expand Down
3 changes: 2 additions & 1 deletion bridge/commands/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package commands

import (
"context"
"strconv"

"maunium.net/go/mautrix/id"
Expand Down Expand Up @@ -57,7 +58,7 @@ func fnSetPowerLevel(ce *Event) {
ce.Reply("**Usage:** `set-pl [user] <level>`")
return
}
_, err = ce.Portal.MainIntent().SetPowerLevel(ce.RoomID, userID, level)
_, err = ce.Portal.MainIntent().SetPowerLevel(context.Background(), ce.RoomID, userID, level)
if err != nil {
ce.Reply("Failed to set power levels: %v", err)
}
Expand Down
4 changes: 3 additions & 1 deletion bridge/commands/doublepuppet.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package commands

import "context"

var CommandLoginMatrix = &FullHandler{
Func: fnLoginMatrix,
Name: "login-matrix",
Expand Down Expand Up @@ -54,7 +56,7 @@ func fnPingMatrix(ce *Event) {
ce.Reply("You are not logged in with your Matrix account.")
return
}
resp, err := puppet.CustomIntent().Whoami()
resp, err := puppet.CustomIntent().Whoami(context.Background())
if err != nil {
ce.Reply("Failed to validate Matrix login: %v", err)
} else {
Expand Down
8 changes: 4 additions & 4 deletions bridge/commands/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,31 @@ func (ce *Event) Reply(msg string, args ...interface{}) {
func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) {
content := format.RenderMarkdown(msg, allowMarkdown, allowHTML)
content.MsgType = event.MsgNotice
_, err := ce.MainIntent().SendMessageEvent(ce.RoomID, event.EventMessage, content)
_, err := ce.MainIntent().SendMessageEvent(context.Background(), ce.RoomID, event.EventMessage, content)
if err != nil {
ce.ZLog.Error().Err(err).Msgf("Failed to reply to command")
}
}

// React sends a reaction to the command.
func (ce *Event) React(key string) {
_, err := ce.MainIntent().SendReaction(ce.RoomID, ce.EventID, key)
_, err := ce.MainIntent().SendReaction(context.Background(), ce.RoomID, ce.EventID, key)
if err != nil {
ce.ZLog.Error().Err(err).Msgf("Failed to react to command")
}
}

// Redact redacts the command.
func (ce *Event) Redact(req ...mautrix.ReqRedact) {
_, err := ce.MainIntent().RedactEvent(ce.RoomID, ce.EventID, req...)
_, err := ce.MainIntent().RedactEvent(context.Background(), ce.RoomID, ce.EventID, req...)
if err != nil {
ce.ZLog.Error().Err(err).Msgf("Failed to redact command")
}
}

// MarkRead marks the command event as read.
func (ce *Event) MarkRead() {
err := ce.MainIntent().SendReceipt(ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil)
err := ce.MainIntent().SendReceipt(context.Background(), ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil)
if err != nil {
ce.ZLog.Error().Err(err).Msgf("Failed to mark command as read")
}
Expand Down
4 changes: 3 additions & 1 deletion bridge/commands/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package commands

import (
"context"

"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/bridge/bridgeconfig"
"maunium.net/go/mautrix/event"
Expand Down Expand Up @@ -76,7 +78,7 @@ func (fh *FullHandler) ShowInHelp(ce *Event) bool {
}

func (fh *FullHandler) userHasRoomPermission(ce *Event) bool {
levels, err := ce.MainIntent().PowerLevels(ce.RoomID)
levels, err := ce.MainIntent().PowerLevels(context.Background(), ce.RoomID)
if err != nil {
ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels")
ce.Reply("Failed to get room power levels to see if you're allowed to use that command")
Expand Down
23 changes: 13 additions & 10 deletions bridge/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ func (helper *CryptoHelper) Init() error {
}

func (helper *CryptoHelper) resyncEncryptionInfo() {
ctx := context.Background()
log := helper.log.With().Str("action", "resync encryption event").Logger()
rows, err := helper.bridge.DB.Query(`SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
rows, err := helper.bridge.DB.QueryContext(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
if err != nil {
log.Err(err).Msg("Failed to query rooms for resync")
return
Expand All @@ -158,10 +159,10 @@ func (helper *CryptoHelper) resyncEncryptionInfo() {
log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms")
for _, roomID := range roomIDs {
var evt event.EncryptionEventContent
err = helper.client.StateEvent(roomID, event.StateEncryption, "", &evt)
err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt)
if err != nil {
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event")
_, err = helper.bridge.DB.Exec(`
_, err = helper.bridge.DB.ExecContext(ctx, `
UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}'
`, roomID)
if err != nil {
Expand All @@ -182,7 +183,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo() {
Int("max_messages", maxMessages).
Interface("content", &evt).
Msg("Resynced encryption event")
_, err = helper.bridge.DB.Exec(`
_, err = helper.bridge.DB.ExecContext(ctx, `
UPDATE crypto_megolm_inbound_session
SET max_age=$1, max_messages=$2
WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL
Expand Down Expand Up @@ -223,20 +224,21 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device
}

func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) {
ctx := context.Background()
deviceID := helper.store.FindDeviceID()
if len(deviceID) > 0 {
helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database")
}
// Create a new client instance with the default AS settings (including as_token),
// the Login call will then override the access token in the client.
client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID())
flows, err := client.GetLoginFlows()
flows, err := client.GetLoginFlows(ctx)
if err != nil {
return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err)
} else if !flows.HasFlow(mautrix.AuthTypeAppservice) {
return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login")
}
resp, err := client.Login(&mautrix.ReqLogin{
resp, err := client.Login(ctx, &mautrix.ReqLogin{
Type: mautrix.AuthTypeAppservice,
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
Expand All @@ -255,8 +257,9 @@ func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) {
}

func (helper *CryptoHelper) verifyKeysAreOnServer() {
ctx := context.Background()
helper.log.Debug().Msg("Making sure keys are still on server")
resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{
resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
helper.client.UserID: {helper.client.DeviceID},
},
Expand Down Expand Up @@ -333,7 +336,7 @@ func (helper *CryptoHelper) Reset(startAfterReset bool) {
helper.log.Debug().Msg("Crypto syncer stopped, clearing database")
helper.clearDatabase()
helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions")
_, err := helper.client.LogoutAll()
_, err := helper.client.LogoutAll(context.Background())
if err != nil {
helper.log.Warn().Err(err).Msg("Failed to log out all devices")
}
Expand Down Expand Up @@ -395,13 +398,13 @@ func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.Sender
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
}

func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
helper.lock.RLock()
defer helper.lock.RUnlock()
if deviceID == "" {
deviceID = "*"
}
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
if err != nil {
helper.log.Warn().Err(err).
Str("user_id", userID.String()).
Expand Down
Loading