Skip to content

Commit

Permalink
Improve support to forward REMB packets to each stream (#1186)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcague committed Apr 6, 2018
1 parent 27ddc9b commit 99917f3
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 13 deletions.
5 changes: 5 additions & 0 deletions erizo/src/erizo/MediaStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ MediaStream::~MediaStream() {
ELOG_DEBUG("%s message: Destructor ended", toLog());
}

uint32_t MediaStream::getMaxVideoBW() {
uint32_t bitrate = rtcp_processor_ ? rtcp_processor_->getMaxVideoBW() : 0;
return bitrate;
}

void MediaStream::syncClose() {
ELOG_DEBUG("%s message:Close called", toLog());
if (!sending_) {
Expand Down
3 changes: 2 additions & 1 deletion erizo/src/erizo/MediaStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class MediaStream: public MediaSink, public MediaSource, public FeedbackSink,
virtual ~MediaStream();
bool init();
void close() override;
virtual uint32_t getMaxVideoBW();
void syncClose();
bool setRemoteSdp(std::shared_ptr<SdpInfo> sdp);
bool setLocalSdp(std::shared_ptr<SdpInfo> sdp);
Expand All @@ -84,7 +85,7 @@ class MediaStream: public MediaSink, public MediaSource, public FeedbackSink,

void getJSONStats(std::function<void(std::string)> callback);

void onTransportData(std::shared_ptr<DataPacket> packet, Transport *transport);
virtual void onTransportData(std::shared_ptr<DataPacket> packet, Transport *transport);

void sendPacketAsync(std::shared_ptr<DataPacket> packet);

Expand Down
53 changes: 43 additions & 10 deletions erizo/src/erizo/WebRtcConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ void WebRtcConnection::syncClose() {
return;
}
sending_ = false;
media_streams_.clear();
if (video_transport_.get()) {
video_transport_->close();
}
Expand Down Expand Up @@ -470,22 +471,49 @@ void WebRtcConnection::onCandidate(const CandidateInfo& cand, Transport *transpo
}
}

void WebRtcConnection::onREMBFromTransport(RtcpHeader *chead, Transport *transport) {
std::vector<std::shared_ptr<MediaStream>> streams;

for (uint8_t index = 0; index < chead->getREMBNumSSRC(); index++) {
uint32_t ssrc_feed = chead->getREMBFeedSSRC(index);
forEachMediaStream([ssrc_feed, &streams] (const std::shared_ptr<MediaStream> &media_stream) {
if (media_stream->isSinkSSRC(ssrc_feed)) {
streams.push_back(media_stream);
}
});
}

std::sort(streams.begin(), streams.end(),
[](const std::shared_ptr<MediaStream> &i, const std::shared_ptr<MediaStream> &j) {
return i->getMaxVideoBW() < j->getMaxVideoBW();
});

uint8_t remaining_streams = streams.size();
uint32_t remaining_bitrate = chead->getREMBBitRate();
std::for_each(streams.begin(), streams.end(),
[&remaining_bitrate, &remaining_streams, transport, chead](const std::shared_ptr<MediaStream> &stream) {
uint32_t max_bitrate = stream->getMaxVideoBW();
uint32_t remaining_avg_bitrate = remaining_bitrate / remaining_streams;
uint32_t bitrate = std::min(max_bitrate, remaining_avg_bitrate);
auto generated_remb = RtpUtils::createREMB(chead->getSSRC(), {stream->getVideoSinkSSRC()}, bitrate);
stream->onTransportData(generated_remb, transport);
remaining_bitrate -= bitrate;
remaining_streams--;
});
}

void WebRtcConnection::onRtcpFromTransport(std::shared_ptr<DataPacket> packet, Transport *transport) {
RtpUtils::forEachRtcpBlock(packet, [this, packet, transport](RtcpHeader *chead) {
uint32_t ssrc = chead->isFeedback() ? chead->getSourceSSRC() : chead->getSSRC();
if (chead->isREMB()) {
onREMBFromTransport(chead, transport);
return;
}
std::shared_ptr<DataPacket> rtcp = std::make_shared<DataPacket>(*packet);
rtcp->length = (ntohs(chead->length) + 1) * 4;
std::memcpy(rtcp->data, chead, rtcp->length);
forEachMediaStream([rtcp, transport, ssrc, chead] (const std::shared_ptr<MediaStream> &media_stream) {
if (chead->isREMB()) {
for (uint8_t index = 0; index < chead->getREMBNumSSRC(); index++) {
uint32_t ssrc_feed = chead->getREMBFeedSSRC(index);
if (media_stream->isSourceSSRC(ssrc_feed) || media_stream->isSinkSSRC(ssrc_feed)) {
// TODO(javier): Calculate the portion of bitrate that corresponds to this stream.
media_stream->onTransportData(rtcp, transport);
}
}
} else if (media_stream->isSourceSSRC(ssrc) || media_stream->isSinkSSRC(ssrc)) {
forEachMediaStream([rtcp, transport, ssrc] (const std::shared_ptr<MediaStream> &media_stream) {
if (media_stream->isSourceSSRC(ssrc) || media_stream->isSinkSSRC(ssrc)) {
media_stream->onTransportData(rtcp, transport);
}
});
Expand Down Expand Up @@ -681,4 +709,9 @@ void WebRtcConnection::syncWrite(std::shared_ptr<DataPacket> packet) {
transport->write(packet->data, packet->length);
}

void WebRtcConnection::setTransport(std::shared_ptr<Transport> transport) { // Only for Testing purposes
video_transport_ = transport;
bundle_ = true;
}

} // namespace erizo
3 changes: 3 additions & 0 deletions erizo/src/erizo/WebRtcConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ class WebRtcConnection: public TransportListener, public LogContext,
void forEachMediaStream(std::function<void(const std::shared_ptr<MediaStream>&)> func);
void forEachMediaStreamAsync(std::function<void(const std::shared_ptr<MediaStream>&)> func);

void setTransport(std::shared_ptr<Transport> transport); // Only for Testing purposes

std::shared_ptr<Stats> getStatsService() { return stats_; }

RtpExtensionProcessor& getRtpExtensionProcessor() { return extension_processor_; }
Expand All @@ -159,6 +161,7 @@ class WebRtcConnection: public TransportListener, public LogContext,
std::string getJSONCandidate(const std::string& mid, const std::string& sdp);
void trackTransportInfo();
void onRtcpFromTransport(std::shared_ptr<DataPacket> packet, Transport *transport);
void onREMBFromTransport(RtcpHeader *chead, Transport *transport);

private:
std::string connection_id_;
Expand Down
20 changes: 20 additions & 0 deletions erizo/src/erizo/rtp/RtpUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,26 @@ std::shared_ptr<DataPacket> RtpUtils::createFIR(uint32_t source_ssrc, uint32_t s
return std::make_shared<DataPacket>(0, buf, len, VIDEO_PACKET);
}

std::shared_ptr<DataPacket> RtpUtils::createREMB(uint32_t ssrc, std::vector<uint32_t> ssrc_list, uint32_t bitrate) {
erizo::RtcpHeader remb;
remb.setPacketType(RTCP_PS_Feedback_PT);
remb.setBlockCount(RTCP_AFB);
memcpy(&remb.report.rembPacket.uniqueid, "REMB", 4);

remb.setSSRC(ssrc);
remb.setSourceSSRC(0);
remb.setLength(4 + ssrc_list.size());
remb.setREMBBitRate(bitrate);
remb.setREMBNumSSRC(ssrc_list.size());
uint8_t index = 0;
for (uint32_t feed_ssrc : ssrc_list) {
remb.setREMBFeedSSRC(index++, feed_ssrc);
}
int len = (remb.getLength() + 1) * 4;
char *buf = reinterpret_cast<char*>(&remb);
return std::make_shared<erizo::DataPacket>(0, buf, len, erizo::OTHER_PACKET);
}


int RtpUtils::getPaddingLength(std::shared_ptr<DataPacket> packet) {
RtpHeader *rtp_header = reinterpret_cast<RtpHeader*>(packet->data);
Expand Down
1 change: 1 addition & 0 deletions erizo/src/erizo/rtp/RtpUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class RtpUtils {
static std::shared_ptr<DataPacket> createPLI(uint32_t source_ssrc, uint32_t sink_ssrc);

static std::shared_ptr<DataPacket> createFIR(uint32_t source_ssrc, uint32_t sink_ssrc, uint8_t seq_number);
static std::shared_ptr<DataPacket> createREMB(uint32_t ssrc, std::vector<uint32_t> ssrc_list, uint32_t bitrate);

static int getPaddingLength(std::shared_ptr<DataPacket> packet);

Expand Down
182 changes: 182 additions & 0 deletions erizo/src/test/WebRtcConnectionTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <rtp/RtpHeaders.h>
#include <rtp/RtpUtils.h>
#include <MediaDefinitions.h>
#include <WebRtcConnection.h>

#include <string>
#include <tuple>

#include "utils/Mocks.h"
#include "utils/Matchers.h"

using testing::_;
using testing::Return;
using testing::Eq;
using testing::Args;
using testing::AtLeast;
using erizo::DataPacket;
using erizo::ExtMap;
using erizo::IceConfig;
using erizo::RtpMap;
using erizo::RtpUtils;
using erizo::WebRtcConnection;

typedef std::vector<uint32_t> MaxList;
typedef std::vector<bool> EnabledList;
typedef std::vector<int32_t> ExpectedList;

class WebRtcConnectionTest :
public ::testing::TestWithParam<std::tr1::tuple<MaxList,
uint32_t,
EnabledList,
ExpectedList>> {
protected:
virtual void SetUp() {
index = 0;
simulated_clock = std::make_shared<erizo::SimulatedClock>();
simulated_worker = std::make_shared<erizo::SimulatedWorker>(simulated_clock);
simulated_worker->start();
io_worker = std::make_shared<erizo::IOWorker>();
io_worker->start();
connection = std::make_shared<WebRtcConnection>(simulated_worker, io_worker,
"test_connection", ice_config, rtp_maps, ext_maps, nullptr);
transport = std::make_shared<erizo::MockTransport>("test_connection", true, ice_config,
simulated_worker, io_worker);
connection->setTransport(transport);
connection->updateState(TRANSPORT_READY, transport.get());
max_video_bw_list = std::tr1::get<0>(GetParam());
bitrate_value = std::tr1::get<1>(GetParam());
add_to_remb_list = std::tr1::get<2>(GetParam());
expected_bitrates = std::tr1::get<3>(GetParam());

setUpStreams();
}

void setUpStreams() {
for (uint32_t max_video_bw : max_video_bw_list) {
streams.push_back(addMediaStream(false, max_video_bw));
}
}

std::shared_ptr<erizo::MockMediaStream> addMediaStream(bool is_publisher, uint32_t max_video_bw) {
std::string id = std::to_string(index);
std::string label = std::to_string(index);
uint32_t video_sink_ssrc = getSsrcFromIndex(index);
uint32_t audio_sink_ssrc = getSsrcFromIndex(index) + 1;
uint32_t video_source_ssrc = getSsrcFromIndex(index) + 2;
uint32_t audio_source_ssrc = getSsrcFromIndex(index) + 3;
auto media_stream = std::make_shared<erizo::MockMediaStream>(simulated_worker, connection, id, label,
rtp_maps, is_publisher);
media_stream->setVideoSinkSSRC(video_sink_ssrc);
media_stream->setAudioSinkSSRC(audio_sink_ssrc);
media_stream->setVideoSourceSSRC(video_source_ssrc);
media_stream->setAudioSourceSSRC(audio_source_ssrc);
connection->addMediaStream(media_stream);
simulated_worker->executeTasks();
EXPECT_CALL(*media_stream, getMaxVideoBW()).Times(AtLeast(0)).WillRepeatedly(Return(max_video_bw));
index++;
return media_stream;
}

void onRembReceived(uint32_t bitrate, std::vector<uint32_t> ids) {
std::transform(ids.begin(), ids.end(), ids.begin(), [](uint32_t id) {
return id * 1000;
});
auto remb = RtpUtils::createREMB(ids[0], ids, bitrate);
connection->onTransportData(remb, transport.get());
}

void onRembReceived() {
uint32_t index = 0;
std::vector<uint32_t> ids;
for (bool enabled : add_to_remb_list) {
if (enabled) {
ids.push_back(index);
}
index++;
}
onRembReceived(bitrate_value, ids);
}

uint32_t getIndexFromSsrc(uint32_t ssrc) {
return ssrc / 1000;
}

uint32_t getSsrcFromIndex(uint32_t index) {
return index * 1000;
}

virtual void TearDown() {
connection->close();
simulated_worker->executeTasks();
streams.clear();
}

std::vector<std::shared_ptr<erizo::MockMediaStream>> streams;
MaxList max_video_bw_list;
uint32_t bitrate_value;
EnabledList add_to_remb_list;
ExpectedList expected_bitrates;
IceConfig ice_config;
std::vector<RtpMap> rtp_maps;
std::vector<ExtMap> ext_maps;
uint32_t index;
std::shared_ptr<erizo::MockTransport> transport;
std::shared_ptr<WebRtcConnection> connection;
std::shared_ptr<erizo::MockRtcpProcessor> processor;
std::shared_ptr<erizo::SimulatedClock> simulated_clock;
std::shared_ptr<erizo::SimulatedWorker> simulated_worker;
std::shared_ptr<erizo::IOWorker> io_worker;
std::queue<std::shared_ptr<DataPacket>> packet_queue;
};

TEST_P(WebRtcConnectionTest, forwardRembToStreams_When_StreamTheyExist) {
uint32_t index = 0;
for (int32_t expected_bitrate : expected_bitrates) {
if (expected_bitrate > 0) {
EXPECT_CALL(*(streams[index]), onTransportData(_, _))
.With(Args<0>(erizo::RembHasBitrateValue(static_cast<uint32_t>(expected_bitrate)))).Times(1);
} else {
EXPECT_CALL(*streams[index], onTransportData(_, _)).Times(0);
}
index++;
}

onRembReceived();
}

INSTANTIATE_TEST_CASE_P(
REMB_values, WebRtcConnectionTest, testing::Values(
std::make_tuple(MaxList{300}, 100, EnabledList{1}, ExpectedList{100}),
std::make_tuple(MaxList{300}, 600, EnabledList{1}, ExpectedList{300}),

std::make_tuple(MaxList{300, 300}, 300, EnabledList{1, 0}, ExpectedList{300, -1}),
std::make_tuple(MaxList{300, 300}, 300, EnabledList{0, 1}, ExpectedList{-1, 300}),
std::make_tuple(MaxList{300, 300}, 300, EnabledList{1, 1}, ExpectedList{150, 150}),
std::make_tuple(MaxList{100, 300}, 300, EnabledList{1, 1}, ExpectedList{100, 200}),
std::make_tuple(MaxList{300, 100}, 300, EnabledList{1, 1}, ExpectedList{200, 100}),
std::make_tuple(MaxList{100, 100}, 300, EnabledList{1, 1}, ExpectedList{100, 100}),

std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{1, 0, 0}, ExpectedList{300, -1, -1}),
std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 1, 0}, ExpectedList{ -1, 300, -1}),
std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{150, 150, -1}),
std::make_tuple(MaxList{100, 300, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{100, 200, -1}),
std::make_tuple(MaxList{300, 100, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{200, 100, -1}),
std::make_tuple(MaxList{100, 100, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{100, 100, -1}),

std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 1, 0}, ExpectedList{-1, 300, -1}),
std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 0, 1}, ExpectedList{-1, -1, 300}),
std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 150, 150}),
std::make_tuple(MaxList{300, 100, 300}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 100, 200}),
std::make_tuple(MaxList{300, 300, 100}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 200, 100}),
std::make_tuple(MaxList{300, 100, 100}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 100, 100}),

std::make_tuple(MaxList{100, 100, 100}, 300, EnabledList{1, 1, 1}, ExpectedList{100, 100, 100}),
std::make_tuple(MaxList{100, 100, 100}, 600, EnabledList{1, 1, 1}, ExpectedList{100, 100, 100}),
std::make_tuple(MaxList{300, 300, 300}, 600, EnabledList{1, 1, 1}, ExpectedList{200, 200, 200}),
std::make_tuple(MaxList{100, 200, 300}, 600, EnabledList{1, 1, 1}, ExpectedList{100, 200, 300}),
std::make_tuple(MaxList{300, 200, 100}, 600, EnabledList{1, 1, 1}, ExpectedList{300, 200, 100}),
std::make_tuple(MaxList{100, 500, 500}, 800, EnabledList{1, 1, 1}, ExpectedList{100, 350, 350})));
34 changes: 32 additions & 2 deletions erizo/src/test/utils/Mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,33 @@ class MockMediaSink : public MediaSink {
}
};

class MockTransport: public Transport {
public:
MockTransport(std::string connection_id, bool bundle, const IceConfig &ice_config,
std::shared_ptr<Worker> worker, std::shared_ptr<IOWorker> io_worker) :
Transport(VIDEO_TYPE, "video", connection_id, bundle, true,
std::shared_ptr<erizo::TransportListener>(nullptr), ice_config,
worker, io_worker) {}

virtual ~MockTransport() {
}

void updateIceState(IceState state, IceConnection *conn) override {
}
void onIceData(packetPtr packet) override {
}
void onCandidate(const CandidateInfo &candidate, IceConnection *conn) override {
}
void write(char* data, int len) override {
}
void processLocalSdp(SdpInfo *localSdp_) override {
}
void start() override {
}
void close() override {
}
};

class MockWebRtcConnection: public WebRtcConnection {
public:
MockWebRtcConnection(std::shared_ptr<Worker> worker, std::shared_ptr<IOWorker> io_worker, const IceConfig &ice_config,
Expand All @@ -74,11 +101,14 @@ class MockMediaStream: public MediaStream {
public:
MockMediaStream(std::shared_ptr<Worker> worker, std::shared_ptr<WebRtcConnection> connection,
const std::string& media_stream_id, const std::string& media_stream_label,
std::vector<RtpMap> rtp_mappings) :
MediaStream(worker, connection, media_stream_id, media_stream_label, true) {
std::vector<RtpMap> rtp_mappings, bool is_publisher = true) :
MediaStream(worker, connection, media_stream_id, media_stream_label, is_publisher) {
local_sdp_ = std::make_shared<SdpInfo>(rtp_mappings);
remote_sdp_ = std::make_shared<SdpInfo>(rtp_mappings);
}

MOCK_METHOD0(getMaxVideoBW, uint32_t());
MOCK_METHOD2(onTransportData, void(std::shared_ptr<DataPacket>, Transport*));
};

class Reader : public InboundHandler {
Expand Down

0 comments on commit 99917f3

Please sign in to comment.