Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RedisConnection destruction. #1618

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 91 additions & 32 deletions nosql_lib/redis/src/RedisConnection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ RedisConnection::RedisConnection(const trantor::InetAddress &serverAddress,
loop_->queueInLoop([this]() { startConnectionInLoop(); });
}

struct ControlBlock
{
std::weak_ptr<RedisConnection> weakConn;
std::shared_ptr<trantor::Channel> channel;
trantor::InetAddress addr;
};

RedisConnection::~RedisConnection()
{
LOG_TRACE << "status_: " << (int)status_;
if (redisContext_ && status_ != ConnectStatus::kEnd)
{
auto context = redisContext_;
redisContext_ = nullptr;
loop_->queueInLoop([context]() { redisAsyncDisconnect(context); });
}
}

void RedisConnection::startConnectionInLoop()
{
loop_->assertInLoopThread();
Expand All @@ -62,19 +80,32 @@ void RedisConnection::startConnectionInLoop()
return;
}

channel_ = std::make_shared<trantor::Channel>(loop_, redisContext_->c.fd);
channel_->setReadCallback([this]() { handleRedisRead(); });
channel_->setWriteCallback([this]() { handleRedisWrite(); });

auto *cb = new ControlBlock();
cb->weakConn = weak_from_this();
cb->channel = channel_;
cb->addr = serverAddr_;

redisContext_->ev.addWrite = addWrite;
redisContext_->ev.delWrite = delWrite;
redisContext_->ev.addRead = addRead;
redisContext_->ev.delRead = delRead;
redisContext_->ev.cleanup = cleanup;
redisContext_->ev.data = this;
redisContext_->ev.data = cb;

channel_ = std::make_unique<trantor::Channel>(loop_, redisContext_->c.fd);
channel_->setReadCallback([this]() { handleRedisRead(); });
channel_->setWriteCallback([this]() { handleRedisWrite(); });
redisAsyncSetConnectCallback(
redisContext_, [](const redisAsyncContext *context, int status) {
auto thisPtr = static_cast<RedisConnection *>(context->ev.data);
auto *cb = static_cast<ControlBlock *>(context->ev.data);
auto thisPtr = cb->weakConn.lock();
if (!thisPtr)
{
// TODO?
LOG_ERROR << "RedisConnection destruct unexpectedly!";
return;
}
if (status != REDIS_OK)
{
LOG_ERROR << "Failed to connect to "
Expand Down Expand Up @@ -224,14 +255,21 @@ void RedisConnection::startConnectionInLoop()
});
redisAsyncSetDisconnectCallback(
redisContext_, [](const redisAsyncContext *context, int /*status*/) {
auto thisPtr = static_cast<RedisConnection *>(context->ev.data);

auto *cb = static_cast<ControlBlock *>(context->ev.data);
auto thisPtr = cb->weakConn.lock();
if (!thisPtr)
{
LOG_TRACE << "Disconnected from " << cb->addr.toIpPort()
<< ", no more reconnect because RedisConnection has "
"been destructed";
delete cb;
return;
}
thisPtr->handleDisconnect();
if (thisPtr->disconnectCallback_)
{
thisPtr->disconnectCallback_(thisPtr->shared_from_this());
}

LOG_TRACE << "Disconnected from "
<< thisPtr->serverAddr_.toIpPort();
});
Expand Down Expand Up @@ -260,35 +298,44 @@ void RedisConnection::handleDisconnect()
redisContext_->ev.addRead = nullptr;
redisContext_->ev.delRead = nullptr;
redisContext_->ev.cleanup = nullptr;
delete (ControlBlock *)redisContext_->ev.data;
redisContext_->ev.data = nullptr;
}
void RedisConnection::addWrite(void *userData)
{
auto thisPtr = static_cast<RedisConnection *>(userData);
assert(thisPtr->channel_);
thisPtr->channel_->enableWriting();
auto *cb = static_cast<ControlBlock *>(userData);
assert(cb->channel);
cb->channel->enableWriting();
}
void RedisConnection::delWrite(void *userData)
{
auto thisPtr = static_cast<RedisConnection *>(userData);
assert(thisPtr->channel_);
thisPtr->channel_->disableWriting();
auto *cb = static_cast<ControlBlock *>(userData);
assert(cb->channel);
cb->channel->disableWriting();
}
void RedisConnection::addRead(void *userData)
{
auto thisPtr = static_cast<RedisConnection *>(userData);
assert(thisPtr->channel_);
thisPtr->channel_->enableReading();
auto *cb = static_cast<ControlBlock *>(userData);
assert(cb->channel);
cb->channel->enableReading();
}
void RedisConnection::delRead(void *userData)
{
auto thisPtr = static_cast<RedisConnection *>(userData);
assert(thisPtr->channel_);
thisPtr->channel_->disableReading();
auto *cb = static_cast<ControlBlock *>(userData);
assert(cb->channel);
cb->channel->disableReading();
}
void RedisConnection::cleanup(void * /*userData*/)
void RedisConnection::cleanup(void *userData)
{
LOG_TRACE << "cleanup";
LOG_TRACE << "RedisConnection::cleanup";
auto *cb = static_cast<ControlBlock *>(userData);
if (cb)
{
assert(cb->channel);
cb->channel->remove();
// cleanup() should only update socket status
// should not delete cb here
}
}

void RedisConnection::handleRedisRead()
Expand Down Expand Up @@ -318,8 +365,12 @@ void RedisConnection::sendCommandInLoop(
redisAsyncFormattedCommand(
redisContext_,
[](redisAsyncContext *context, void *r, void * /*userData*/) {
auto thisPtr = static_cast<RedisConnection *>(context->ev.data);
thisPtr->handleResult(static_cast<redisReply *>(r));
auto *cb = static_cast<ControlBlock *>(context->ev.data);
auto thisPtr = cb->weakConn.lock();
if (thisPtr)
{
thisPtr->handleResult(static_cast<redisReply *>(r));
}
},
nullptr,
command.c_str(),
Expand Down Expand Up @@ -405,10 +456,14 @@ void RedisConnection::sendSubscribeInLoop(
redisAsyncFormattedCommand(
redisContext_,
[](redisAsyncContext *context, void *r, void *subCtx) {
auto thisPtr = static_cast<RedisConnection *>(context->ev.data);
thisPtr->handleSubscribeResult(static_cast<redisReply *>(r),
static_cast<SubscribeContext *>(
subCtx));
auto *cb = static_cast<ControlBlock *>(context->ev.data);
auto thisPtr = cb->weakConn.lock();
if (thisPtr)
{
thisPtr->handleSubscribeResult(static_cast<redisReply *>(r),
static_cast<SubscribeContext *>(
subCtx));
}
},
subCtx.get(),
subCtx->subscribeCommand().c_str(),
Expand All @@ -427,10 +482,14 @@ void RedisConnection::sendUnsubscribeInLoop(
redisAsyncFormattedCommand(
redisContext_,
[](redisAsyncContext *context, void *r, void *subCtx) {
auto thisPtr = static_cast<RedisConnection *>(context->ev.data);
thisPtr->handleSubscribeResult(static_cast<redisReply *>(r),
static_cast<SubscribeContext *>(
subCtx));
auto *cb = static_cast<ControlBlock *>(context->ev.data);
auto thisPtr = cb->weakConn.lock();
if (thisPtr)
{
thisPtr->handleSubscribeResult(static_cast<redisReply *>(r),
static_cast<SubscribeContext *>(
subCtx));
}
},
subCtx.get(),
subCtx->unsubscribeCommand().c_str(),
Expand Down
9 changes: 2 additions & 7 deletions nosql_lib/redis/src/RedisConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class RedisConnection : public trantor::NonCopyable,
const std::string &password,
unsigned int db,
trantor::EventLoop *loop);
~RedisConnection();
void setConnectCallback(
const std::function<void(std::shared_ptr<RedisConnection> &&)>
&callback)
Expand Down Expand Up @@ -149,12 +150,6 @@ class RedisConnection : public trantor::NonCopyable,
void sendSubscribe(const std::shared_ptr<SubscribeContext> &subCtx);
void sendUnsubscribe(const std::shared_ptr<SubscribeContext> &subCtx);

~RedisConnection()
{
LOG_TRACE << (int)status_;
if (redisContext_ && status_ != ConnectStatus::kEnd)
redisAsyncDisconnect(redisContext_);
}
void disconnect();
void sendCommand(RedisResultCallback &&resultCallback,
RedisExceptionCallback &&exceptionCallback,
Expand All @@ -181,7 +176,7 @@ class RedisConnection : public trantor::NonCopyable,
const std::string password_;
const unsigned int db_;
trantor::EventLoop *loop_{nullptr};
std::unique_ptr<trantor::Channel> channel_{nullptr};
std::shared_ptr<trantor::Channel> channel_{nullptr};
std::function<void(std::shared_ptr<RedisConnection> &&)> connectCallback_;
std::function<void(std::shared_ptr<RedisConnection> &&)>
disconnectCallback_;
Expand Down