Skip to content

Commit

Permalink
chore: pass reply_builder explicitly to pubsub module
Browse files Browse the repository at this point in the history
Also, deprecate `reply_builder()` access method.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
  • Loading branch information
romange committed Oct 30, 2024
1 parent daf8604 commit 0fe9e99
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 139 deletions.
9 changes: 1 addition & 8 deletions src/facade/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,10 @@ class ConnectionContext {
return protocol_;
}

SinkReplyBuilder* reply_builder() {
SinkReplyBuilder* reply_builder_old() {
return rbuilder_.get();
}

// Allows receiving the output data from the commands called from scripts.
SinkReplyBuilder* Inject(SinkReplyBuilder* new_i) {
SinkReplyBuilder* res = rbuilder_.release();
rbuilder_.reset(new_i);
return res;
}

virtual size_t UsedMemory() const;

// connection state / properties.
Expand Down
4 changes: 2 additions & 2 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ void Connection::HandleRequests() {
// down and return with an error accordingly.
if (http_res && socket_->IsOpen()) {
cc_.reset(service_->CreateContext(socket_.get(), this));
reply_builder_ = cc_->reply_builder();
reply_builder_ = cc_->reply_builder_old();

if (*http_res) {
VLOG(1) << "HTTP1.1 identified";
Expand Down Expand Up @@ -811,7 +811,7 @@ std::pair<std::string, std::string> Connection::GetClientInfoBeforeAfterTid() co
string_view phase_name = PHASE_NAMES[phase_];

if (cc_) {
DCHECK(cc_->reply_builder() && reply_builder_);
DCHECK(reply_builder_);
string cc_info = service_->GetContextInfo(cc_.get()).Format();
if (reply_builder_->IsSendActive())
phase_name = "send";
Expand Down
151 changes: 73 additions & 78 deletions src/server/conn_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ namespace dfly {
using namespace std;
using namespace facade;

static void SendSubscriptionChangedResponse(string_view action, std::optional<string_view> topic,
unsigned count, RedisReplyBuilder* rb) {
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action);
if (topic.has_value())
rb->SendBulkString(topic.value());
else
rb->SendNull();
rb->SendLong(count);
}

StoredCmd::StoredCmd(const CommandId* cid, ArgSlice args, facade::ReplyMode mode)
: cid_{cid}, buffer_{}, sizes_(args.size()), reply_mode_{mode} {
size_t total_size = 0;
Expand Down Expand Up @@ -98,8 +109,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own
}
}

ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx,
facade::CapturingReplyBuilder* crb)
ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx)
: facade::ConnectionContext(nullptr, nullptr), transaction{tx} {
if (owner) {
acl_commands = owner->acl_commands;
Expand All @@ -115,8 +125,6 @@ ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction
conn_state.db_index = owner->conn_state.db_index;
conn_state.squashing_info = {owner};
}
auto* prev_reply_builder = Inject(crb);
CHECK_EQ(prev_reply_builder, nullptr);
}

void ConnectionContext::ChangeMonitor(bool start) {
Expand All @@ -137,61 +145,13 @@ void ConnectionContext::ChangeMonitor(bool start) {
EnableMonitoring(start);
}

vector<unsigned> ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, bool to_reply,
ConnectionContext* conn) {
vector<unsigned> result(to_reply ? args.size() : 0, 0);

auto& conn_state = conn->conn_state;
if (!to_add && !conn_state.subscribe_info)
return result;

if (!conn_state.subscribe_info) {
DCHECK(to_add);

conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
conn->subscriptions++;
}

auto& sinfo = *conn->conn_state.subscribe_info.get();
auto& local_store = pattern ? sinfo.patterns : sinfo.channels;

int32_t tid = util::ProactorBase::me()->GetPoolIndex();
DCHECK_GE(tid, 0);

ChannelStoreUpdater csu{pattern, to_add, conn, uint32_t(tid)};

// Gather all the channels we need to subscribe to / remove.
size_t i = 0;
for (string_view channel : args) {
if (to_add && local_store.emplace(channel).second)
csu.Record(channel);
else if (!to_add && local_store.erase(channel) > 0)
csu.Record(channel);

if (to_reply)
result[i++] = sinfo.SubscriptionCount();
}

csu.Apply();

// Important to reset conn_state.subscribe_info only after all references to it were
// removed.
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
DCHECK_GE(conn->subscriptions, 1u);
conn->subscriptions--;
}

return result;
}

void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result = ChangeSubscriptions(false, args, to_add, to_reply, this);
void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb) {
vector<unsigned> result = ChangeSubscriptions(args, false, to_add, to_reply);

if (to_reply) {
for (size_t i = 0; i < result.size(); ++i) {
const char* action[2] = {"unsubscribe", "subscribe"};
auto rb = static_cast<RedisReplyBuilder*>(reply_builder());
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action[to_add]);
rb->SendBulkString(ArgS(args, i));
Expand All @@ -200,53 +160,41 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
}
}

void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result = ChangeSubscriptions(true, args, to_add, to_reply, this);
void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb) {
vector<unsigned> result = ChangeSubscriptions(args, true, to_add, to_reply);

if (to_reply) {
const char* action[2] = {"punsubscribe", "psubscribe"};
if (result.size() == 0) {
return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0);
return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0, rb);
}

for (size_t i = 0; i < result.size(); ++i) {
SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i]);
SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i], rb);
}
}
}

void ConnectionContext::UnsubscribeAll(bool to_reply) {
void ConnectionContext::UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) {
if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->channels.empty())) {
return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0);
return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0, rb);
}
StringVec channels(conn_state.subscribe_info->channels.begin(),
conn_state.subscribe_info->channels.end());
CmdArgVec arg_vec(channels.begin(), channels.end());
ChangeSubscription(false, to_reply, CmdArgList{arg_vec});
ChangeSubscription(false, to_reply, CmdArgList{arg_vec}, rb);
}

void ConnectionContext::PUnsubscribeAll(bool to_reply) {
void ConnectionContext::PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) {
if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->patterns.empty())) {
return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0);
return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0, rb);
}

StringVec patterns(conn_state.subscribe_info->patterns.begin(),
conn_state.subscribe_info->patterns.end());
CmdArgVec arg_vec(patterns.begin(), patterns.end());
ChangePSubscription(false, to_reply, CmdArgList{arg_vec});
}

void ConnectionContext::SendSubscriptionChangedResponse(string_view action,
std::optional<string_view> topic,
unsigned count) {
auto rb = static_cast<RedisReplyBuilder*>(reply_builder());
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action);
if (topic.has_value())
rb->SendBulkString(topic.value());
else
rb->SendNull();
rb->SendLong(count);
ChangePSubscription(false, to_reply, CmdArgList{arg_vec}, rb);
}

size_t ConnectionState::ExecInfo::UsedMemory() const {
Expand All @@ -269,6 +217,53 @@ size_t ConnectionContext::UsedMemory() const {
return facade::ConnectionContext::UsedMemory() + dfly::HeapSize(conn_state);
}

vector<unsigned> ConnectionContext::ChangeSubscriptions(CmdArgList channels, bool pattern,
bool to_add, bool to_reply) {
vector<unsigned> result(to_reply ? channels.size() : 0, 0);

if (!to_add && !conn_state.subscribe_info)
return result;

if (!conn_state.subscribe_info) {
DCHECK(to_add);

conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
subscriptions++;
}

auto& sinfo = *conn_state.subscribe_info.get();
auto& local_store = pattern ? sinfo.patterns : sinfo.channels;

int32_t tid = util::ProactorBase::me()->GetPoolIndex();
DCHECK_GE(tid, 0);

ChannelStoreUpdater csu{pattern, to_add, this, uint32_t(tid)};

// Gather all the channels we need to subscribe to / remove.
size_t i = 0;
for (string_view channel : channels) {
if (to_add && local_store.emplace(channel).second)
csu.Record(channel);
else if (!to_add && local_store.erase(channel) > 0)
csu.Record(channel);

if (to_reply)
result[i++] = sinfo.SubscriptionCount();
}

csu.Apply();

// Important to reset conn_state.subscribe_info only after all references to it were
// removed.
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
DCHECK_GE(subscriptions, 1u);
subscriptions--;
}

return result;
}

void ConnectionState::ExecInfo::Clear() {
DCHECK(!preborrowed_interpreter); // Must have been released properly
state = EXEC_INACTIVE;
Expand Down
19 changes: 10 additions & 9 deletions src/server/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,7 @@ struct ConnectionState {
class ConnectionContext : public facade::ConnectionContext {
public:
ConnectionContext(::io::Sink* stream, facade::Connection* owner, dfly::acl::UserCredentials cred);

ConnectionContext(const ConnectionContext* owner, Transaction* tx,
facade::CapturingReplyBuilder* crb);
ConnectionContext(const ConnectionContext* owner, Transaction* tx);

struct DebugInfo {
uint32_t shards_count = 0;
Expand All @@ -292,10 +290,13 @@ class ConnectionContext : public facade::ConnectionContext {
return conn_state.db_index;
}

void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args);
void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args);
void UnsubscribeAll(bool to_reply);
void PUnsubscribeAll(bool to_reply);
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb);

void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb);
void UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb);
void PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb);
void ChangeMonitor(bool start); // either start or stop monitor on a given connection

size_t UsedMemory() const override;
Expand All @@ -317,8 +318,8 @@ class ConnectionContext : public facade::ConnectionContext {
monitor = enable;
}

void SendSubscriptionChangedResponse(std::string_view action,
std::optional<std::string_view> topic, unsigned count);
std::vector<unsigned> ChangeSubscriptions(CmdArgList channels, bool pattern, bool to_add,
bool to_reply);
};

} // namespace dfly
3 changes: 1 addition & 2 deletions src/server/debugcmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool

absl::InlinedVector<string_view, 5> args_view;
facade::CapturingReplyBuilder crb;
ConnectionContext local_cntx{cntx, stub_tx.get(), &crb};
ConnectionContext local_cntx{cntx, stub_tx.get()};

absl::InsecureBitGen gen;
for (unsigned i = 0; i < batch.sz; ++i) {
Expand All @@ -175,7 +175,6 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool
sf->service().InvokeCmd(cid, args_span, &crb, &local_cntx);
}

local_cntx.Inject(nullptr);
local_tx->UnlockMulti();
}

Expand Down
5 changes: 1 addition & 4 deletions src/server/journal/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,14 @@ template <typename... Ts> journal::ParsedEntry::CmdData BuildFromParts(Ts... par
} // namespace

JournalExecutor::JournalExecutor(Service* service)
: service_{service},
reply_builder_{facade::ReplyMode::NONE},
conn_context_{nullptr, nullptr, &reply_builder_} {
: service_{service}, reply_builder_{facade::ReplyMode::NONE}, conn_context_{nullptr, nullptr} {
conn_context_.is_replicating = true;
conn_context_.journal_emulated = true;
conn_context_.skip_acl_validation = true;
conn_context_.ns = &namespaces.GetDefaultNamespace();
}

JournalExecutor::~JournalExecutor() {
conn_context_.Inject(nullptr);
}

void JournalExecutor::Execute(DbIndex dbid, absl::Span<journal::ParsedEntry::CmdData> cmds) {
Expand Down
Loading

0 comments on commit 0fe9e99

Please sign in to comment.