Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NAM: Implement direct media requests fallback #576

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions Quotient/mxcreply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@ class Q_DECL_HIDDEN MxcReply::Private

MxcReply::MxcReply(QNetworkReply* reply,
const EncryptedFileMetadata& fileMetadata)
: d(makeImpl<Private>(reply, fileMetadata.isValid() ? nullptr : reply))
{
reply->setParent(this);
setNetworkReply(reply, fileMetadata);
}

void MxcReply::setNetworkReply(QNetworkReply* reply,
const EncryptedFileMetadata& fileMetadata)
{
d = makeImpl<Private>(reply, fileMetadata.isValid() ? nullptr : reply);
d->m_reply->setParent(this);
connect(d->m_reply, &QNetworkReply::finished, this, [this, fileMetadata] {
setError(d->m_reply->error(), d->m_reply->errorString());

Expand All @@ -34,16 +40,18 @@ MxcReply::MxcReply(QNetworkReply* reply,
d->m_device = buffer;
}
#endif
setFinished(true);
setOpenMode(ReadOnly);
emit finished();
});
}

MxcReply::MxcReply()
: d(ZeroImpl<Private>())
MxcReply::MxcReply(DeferredFlag) {}

MxcReply::MxcReply(ErrorFlag)
{
static const auto BadRequestPhrase = tr("Bad Request");
QMetaObject::invokeMethod(this, [this]() {
QMetaObject::invokeMethod(this, [this] {
setAttribute(QNetworkRequest::HttpStatusCodeAttribute, 400);
setAttribute(QNetworkRequest::HttpReasonPhraseAttribute,
BadRequestPhrase);
Expand All @@ -55,7 +63,7 @@ MxcReply::MxcReply()
}, Qt::QueuedConnection);
}

qint64 MxcReply::readData(char *data, qint64 maxSize)
qint64 MxcReply::readData(char* data, qint64 maxSize)
{
if(d != nullptr && d->m_device != nullptr) {
return d->m_device->read(data, maxSize);
Expand Down
13 changes: 10 additions & 3 deletions Quotient/mxcreply.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@ class QUOTIENT_API MxcReply : public QNetworkReply
{
Q_OBJECT
public:
explicit MxcReply();
enum DeferredFlag { Deferred };
enum ErrorFlag { Error };

explicit MxcReply(QNetworkReply* reply,
const EncryptedFileMetadata& fileMetadata);
explicit MxcReply(DeferredFlag);
explicit MxcReply(ErrorFlag);

void setNetworkReply(QNetworkReply* newReply,
const EncryptedFileMetadata& fileMetadata = {});

qint64 bytesAvailable() const override;

Expand All @@ -26,6 +33,6 @@ public Q_SLOTS:

private:
class Private;
ImplPtr<Private> d;
ImplPtr<Private> d = ZeroImpl<Private>();
};
}
} // namespace Quotient
87 changes: 72 additions & 15 deletions Quotient/networkaccessmanager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "logging.h"
#include "mxcreply.h"
#include "connection.h"

#include "events/filesourceinfo.h"

Expand All @@ -20,6 +21,10 @@
using namespace Quotient;

namespace {

static constexpr auto DirectMediaRequestsSetting =
"Network/allow_direct_media_requests"_ls;

class {
public:
void addBaseUrl(const QString& accountId, const QUrl& baseUrl)
Expand All @@ -46,15 +51,42 @@ class {
{
return QReadLocker{ &namLock }, ignoredSslErrors;
}
void allowDirectMediaRequests(bool allow)
{
if (allow)
directMediaRequestsAreAllowed.test_and_set();
else
directMediaRequestsAreAllowed.clear();
}
bool directMediaRequestsAllowed() const
{
return directMediaRequestsAreAllowed.test();
}

private:
mutable QReadWriteLock namLock{};
QHash<QString, QUrl> baseUrls{};
QList<QSslError> ignoredSslErrors{};
// This one is small enough to be atomic and not need a read-write lock
std::atomic_flag directMediaRequestsAreAllowed{};
} d;

std::once_flag directMediaRequestsInitFlag;

} // anonymous namespace

void NetworkAccessManager::allowDirectMediaRequests(bool allow, bool permanent)
{
d.allowDirectMediaRequests(allow);
if (permanent)
QSettings().setValue(DirectMediaRequestsSetting, allow);
}

bool NetworkAccessManager::directMediaRequestsAllowed()
{
return d.directMediaRequestsAllowed();
}

void NetworkAccessManager::addBaseUrl(const QString& accountId,
const QUrl& homeserver)
{
Expand Down Expand Up @@ -84,6 +116,11 @@ void NetworkAccessManager::clearIgnoredSslErrors()

NetworkAccessManager* NetworkAccessManager::instance()
{
// Initialise direct media requests allowance at the very first NAM creation
std::call_once(directMediaRequestsInitFlag, [] {
NetworkAccessManager::allowDirectMediaRequests(
QSettings().value(DirectMediaRequestsSetting).toBool());
});
thread_local auto* nam = [] {
auto* namInit = new NetworkAccessManager();
connect(QThread::currentThread(), &QThread::finished, namInit,
Expand All @@ -103,38 +140,58 @@ QNetworkReply* NetworkAccessManager::createRequest(
reply->ignoreSslErrors(d.getIgnoredSslErrors());
return reply;
}
Q_ASSERT(!url.isRelative());

const auto createImplRequest = [this, op, request, outgoingData,
url](const QUrl& baseUrl) {
QNetworkRequest rewrittenRequest(request);
rewrittenRequest.setUrl(DownloadFileJob::makeRequestUrl(baseUrl, url));
auto* implReply = QNetworkAccessManager::createRequest(op,
rewrittenRequest,
outgoingData);
implReply->ignoreSslErrors(d.getIgnoredSslErrors());
return implReply;
};

const QUrlQuery query{ url.query() };
const auto accountId = query.queryItemValue(QStringLiteral("user_id"));
if (accountId.isEmpty()) {
// Using QSettings here because Quotient::NetworkSettings
// doesn't provide multi-threading guarantees
if (static thread_local const QSettings s;
s.value("Network/allow_direct_media_requests"_ls).toBool()) //
{
// TODO: Make the best effort with a direct unauthenticated request
// to the media server
qCWarning(NETWORK)
<< "Direct unauthenticated mxc requests are not implemented";
return new MxcReply();
if (directMediaRequestsAllowed()) {
// Best effort with an unauthenticated request directly to the media
// homeserver (rather than via own homeserver)
auto* mxcReply = new MxcReply(MxcReply::Deferred);
// Connection class is, by the moment of this call, reentrant (it
// is not early on when user/room object factories and E2EE are set;
// but if you have an mxc link you are already well past that, most
// likely) so we can create and use it here, even if a connection
// to the same homeserver exists already.
auto* c = new Connection(mxcReply);
connect(c, &Connection::homeserverChanged, mxcReply,
[mxcReply, createImplRequest, c](const QUrl& baseUrl) {
mxcReply->setNetworkReply(createImplRequest(baseUrl));
c->deleteLater();
});
// Hack up a minimum "viable" MXID on the target homeserver
// to satisfy resolveServer()
c->resolveServer("@:"_ls % request.url().host());
return mxcReply;
}
qCWarning(NETWORK)
<< "No connection specified, cannot convert mxc request";
return new MxcReply();
return new MxcReply(MxcReply::Error);
}
const auto& baseUrl = d.getBaseUrl(accountId);
if (!baseUrl.isValid()) {
// Strictly speaking, it should be an assert...
qCCritical(NETWORK) << "Homeserver for" << accountId
<< "not found, cannot convert mxc request";
return new MxcReply();
return new MxcReply(MxcReply::Error);
}

// Convert mxc:// URL into normal http(s) for the given homeserver
QNetworkRequest rewrittenRequest(request);
rewrittenRequest.setUrl(DownloadFileJob::makeRequestUrl(baseUrl, url));

auto* implReply = QNetworkAccessManager::createRequest(op, rewrittenRequest);
implReply->ignoreSslErrors(d.getIgnoredSslErrors());
auto* implReply = createImplRequest(baseUrl);
const auto& fileMetadata = FileMetadataMap::lookup(
query.queryItemValue(QStringLiteral("room_id")),
query.queryItemValue(QStringLiteral("event_id")));
Expand Down
4 changes: 3 additions & 1 deletion Quotient/networkaccessmanager.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ namespace Quotient {
class QUOTIENT_API NetworkAccessManager : public QNetworkAccessManager {
Q_OBJECT
public:
using QNetworkAccessManager::QNetworkAccessManager;
static void allowDirectMediaRequests(bool allow = true,
bool permanent = true);
static bool directMediaRequestsAllowed();

static void addBaseUrl(const QString& accountId, const QUrl& homeserver);
static void dropBaseUrl(const QString& accountId);
Expand Down
74 changes: 53 additions & 21 deletions quotest/quotest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ private slots:
[[nodiscard]] bool checkFileSendingOutcome(const TestToken& thisTest,
const QString& txnId,
const QString& fileName);
[[nodiscard]] bool testDownload(const TestToken& thisTest,
const QUrl& mxcUrl);
[[nodiscard]] bool checkRedactionOutcome(const QByteArray& thisTest,
const QString& evtIdToRedact);

Expand Down Expand Up @@ -408,14 +410,16 @@ TEST_IMPL(sendReaction)
return false;
}

static constexpr auto fileContent = "Test";

TEST_IMPL(sendFile)
{
auto* tf = new QTemporaryFile;
if (!tf->open()) {
clog << "Failed to create a temporary file" << endl;
FAIL_TEST();
}
tf->write("Test");
tf->write(fileContent);
tf->close();
QFileInfo tfi { *tf };
// QFileInfo::fileName brings only the file name; QFile::fileName brings
Expand Down Expand Up @@ -462,30 +466,57 @@ struct DownloadRunner {

using result_type = QNetworkReply::NetworkError;

QNetworkReply::NetworkError operator()(int) const
result_type operator()(int) const
{
QEventLoop el;
QScopedPointer<QNetworkReply, QScopedPointerDeleteLater> reply {
const QScopedPointer<QNetworkReply, QScopedPointerDeleteLater> reply {
NetworkAccessManager::instance()->get(QNetworkRequest(url))
};
QObject::connect(
reply.data(), &QNetworkReply::finished, &el, [&el] { el.exit(); },
Qt::QueuedConnection);
el.exec();
return reply->error();
return reply->error() != QNetworkReply::NoError ? reply->error()
: reply->readAll() != fileContent
? QNetworkReply::UnknownContentError
: QNetworkReply::NoError;
}

static QVector<result_type> run(const QUrl& url, int threads)
{
return QtConcurrent::blockingMapped(QVector<int>(threads),
DownloadRunner{ url });
}
};

bool testDownload(const QUrl& url)
bool TestSuite::testDownload(const TestToken& thisTest, const QUrl& mxcUrl)
{
// Move out actual test from the multithreaded code
// to help debugging
auto results = QtConcurrent::blockingMapped(QVector<int> { 1, 2, 3 },
DownloadRunner { url });
return std::all_of(results.cbegin(), results.cend(),
[](QNetworkReply::NetworkError ne) {
return ne == QNetworkReply::NoError;
});
// Testing direct media requests needs explicit allowance
NetworkAccessManager::allowDirectMediaRequests(true, false);
if (const auto result = DownloadRunner::run(mxcUrl, 1);
result.front() != QNetworkReply::NoError) {
clog << "Direct media request to "
<< mxcUrl.toDisplayString().toStdString()
<< " was allowed but failed" << endl;
FAIL_TEST();
}
NetworkAccessManager::allowDirectMediaRequests(false, false);
if (const auto result = DownloadRunner::run(mxcUrl, 1);
result.front() == QNetworkReply::NoError) {
clog << "Direct media request to "
<< mxcUrl.toDisplayString().toStdString()
<< " was disallowed but succeeded" << endl;
FAIL_TEST();
}

static constexpr auto ThreadsCount = 3;
const auto httpUrl = targetRoom->connection()->makeMediaUrl(mxcUrl);
const auto results = DownloadRunner::run(httpUrl, ThreadsCount);
// Move out actual test from the multithreaded code to help debugging
// NB: remove explicit template argument once entirely at Qt 6 or C++23
FINISH_TEST(results
== QVector<QNetworkReply::NetworkError>(ThreadsCount,
QNetworkReply::NoError));
}

bool TestSuite::checkFileSendingOutcome(const TestToken& thisTest,
Expand Down Expand Up @@ -519,14 +550,15 @@ bool TestSuite::checkFileSendingOutcome(const TestToken& thisTest,
*evt,
[&](const RoomMessageEvent& e) {
// TODO: check #366 once #368 is implemented
FINISH_TEST(
!e.id().isEmpty()
&& pendingEvents[size_t(pendingIdx)]->transactionId()
== txnId
&& e.hasFileContent()
&& e.content()->fileInfo()->originalName == fileName
&& testDownload(targetRoom->connection()->makeMediaUrl(
e.content()->fileInfo()->url())));
if (e.id().isEmpty()
|| pendingEvents[size_t(pendingIdx)]->transactionId()
!= txnId
|| !e.hasFileContent()
|| e.content()->fileInfo()->originalName != fileName) {
clog << "Malformed file event";
FAIL_TEST();
}
return testDownload(thisTest, e.content()->fileInfo()->url());
},
[this, thisTest](const RoomEvent&) { FAIL_TEST(); });
});
Expand Down