Skip to content

Commit

Permalink
Refactor connect to information to its own class
Browse files Browse the repository at this point in the history
  • Loading branch information
Molter73 committed Sep 13, 2023
1 parent 91abee4 commit 045c4c7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 58 deletions.
90 changes: 40 additions & 50 deletions collector/lib/FileDownloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <algorithm>
#include <chrono>
#include <fstream>
#include <optional>
#include <string_view>
#include <unistd.h>
#include <utils.h>
Expand Down Expand Up @@ -94,7 +95,35 @@ int DebugCallback(CURL*, curl_infotype type, char* data, size_t size, void*) {

} // namespace

FileDownloader::FileDownloader() : connect_to_(nullptr) {
ConnectTo::ConnectTo(std::string_view host, std::string_view connect_to) : connect_to_(nullptr, curl_slist_free_all) {
auto marker = connect_to.find(':');
if (marker == std::string_view::npos) {
connect_to_host_ = connect_to;
connect_to_port_ = "";
} else {
connect_to_host_ = connect_to.substr(0, marker);
connect_to_port_ = connect_to.substr(marker + 1);
}

marker = host.find(':');
if (marker == std::string_view::npos) {
host_ = host;
port_ = connect_to_port_;
} else {
host_ = host.substr(0, marker);
port_ = host.substr(marker + 1);
}

std::string entry{host_ + ":" + port_ + ":" + connect_to_host_ + ":" + connect_to_port_};

connect_to_.reset(curl_slist_append(nullptr, entry.c_str()));
if (connect_to_ == nullptr) {
CLOG(WARNING) << "Failed to create connect_to list";
return;
}
}

FileDownloader::FileDownloader() : connect_to_(std::nullopt) {
curl_ = curl_easy_init();

if (curl_) {
Expand All @@ -117,7 +146,7 @@ FileDownloader::~FileDownloader() {

curl_global_cleanup();

curl_slist_free_all(connect_to_);
connect_to_.reset();
}

bool FileDownloader::SetURL(const char* const url) {
Expand Down Expand Up @@ -227,20 +256,10 @@ bool FileDownloader::Key(const char* const path) {
return true;
}

bool FileDownloader::ConnectTo(const std::string& entry) {
return ConnectTo(entry.c_str());
}

bool FileDownloader::ConnectTo(const char* const entry) {
curl_slist* temp = curl_slist_append(connect_to_, entry);

if (temp == nullptr) {
CLOG(WARNING) << "Unable to set option to connect to '" << entry;
return false;
bool FileDownloader::SetConnectTo(const std::string& host, const std::string& target) {
if (!host.empty() && !target.empty() && host != target) {
connect_to_ = ConnectTo(host, target);
}

connect_to_ = temp;

return true;
}

Expand All @@ -256,44 +275,15 @@ void FileDownloader::SetVerboseMode(bool verbose) {
std::string FileDownloader::GetEffectiveURL() {
// For the time being, we can only have a single connect_to_ object,
// if it's not set, the download will go to the set URL.
if (connect_to_ == nullptr || connect_to_->data == nullptr) {
if (!connect_to_) {
return GetURL();
}

// Format for the connect_to field is:
// HOST:PORT:CONNECT-TO-HOST:CONNECT-TO-PORT
std::string_view connect_to{connect_to_->data};

auto marker = connect_to.find(':');
if (marker == std::string_view::npos) {
CLOG(WARNING) << "Malformed connect_to_: " << connect_to_;
if (connect_to_->GetPort() != GetPort() || connect_to_->GetHost() != GetHost()) {
return GetURL();
}
auto host = connect_to.substr(0, marker);
connect_to.remove_prefix(marker + 1);

marker = connect_to.find(':');
if (marker == std::string_view::npos) {
CLOG(WARNING) << "Malformed connect_to_: " << connect_to_;
return GetURL();
}
auto port = connect_to.substr(0, marker);
connect_to.remove_prefix(marker + 1);

marker = connect_to.find(':');
if (marker == std::string_view::npos) {
CLOG(WARNING) << "Malformed connect_to_: " << connect_to_;
return GetURL();
}

auto connect_to_host = connect_to.substr(0, marker);
auto connect_to_port = connect_to.substr(marker + 1);

if (port != GetPort() || host != GetHost()) {
return GetURL();
}

return GetScheme() + "://" + std::string{connect_to_host} + ":" + std::string{connect_to_port} + GetPath();
return GetScheme() + "://" + connect_to_->GetConnectToHost() + ":" + connect_to_->GetConnectToPort() + GetPath();
}

void FileDownloader::ResetCURL() {
Expand All @@ -306,8 +296,7 @@ void FileDownloader::ResetCURL() {

SetDefaultOptions();

curl_slist_free_all(connect_to_);
connect_to_ = nullptr;
connect_to_ = std::nullopt;
}

bool FileDownloader::IsReady() {
Expand All @@ -326,10 +315,11 @@ bool FileDownloader::Download() {
}

if (connect_to_) {
auto result = curl_easy_setopt(curl_, CURLOPT_CONNECT_TO, connect_to_);
auto result = curl_easy_setopt(curl_, CURLOPT_CONNECT_TO, connect_to_->GetList());

if (result != CURLE_OK) {
CLOG(WARNING) << "Unable to set connection host, the download is likely to fail - " << curl_easy_strerror(result);
return false;
}
}

Expand Down
35 changes: 32 additions & 3 deletions collector/lib/FileDownloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

#include <array>
#include <chrono>
#include <memory>
#include <optional>
#include <ostream>
#include <string_view>
#include <utility>

#include <curl/curl.h>
#include <curl/urlapi.h>
Expand All @@ -16,6 +20,32 @@ struct DownloadData {
std::ostream* os;
};

class ConnectTo {
public:
ConnectTo() : connect_to_(nullptr, curl_slist_free_all) {}
ConnectTo(std::string_view host, std::string_view connect_to);
ConnectTo(ConnectTo&) = delete;
ConnectTo(ConnectTo&&) = default;
~ConnectTo() = default;

const curl_slist* GetList() { return connect_to_.get(); };

const std::string& GetHost() { return host_; };
const std::string& GetPort() { return port_; };
const std::string& GetConnectToHost() { return connect_to_host_; };
const std::string& GetConnectToPort() { return connect_to_port_; };

ConnectTo& operator=(const ConnectTo& other) = delete;
ConnectTo& operator=(ConnectTo&& other) = default;

private:
std::unique_ptr<curl_slist, void (*)(curl_slist*)> connect_to_;
std::string host_;
std::string port_;
std::string connect_to_host_;
std::string connect_to_port_;
};

/**
* Wrapper aroung libcurl for downloading files.
* See https://curl.se/libcurl/c/libcurl-easy.html for details about specific
Expand Down Expand Up @@ -43,8 +73,7 @@ class FileDownloader {
bool CACert(const char* const path);
bool Cert(const char* const path);
bool Key(const char* const path);
bool ConnectTo(const std::string& entry);
bool ConnectTo(const char* const entry);
bool SetConnectTo(const std::string& host, const std::string& target);
void SetVerboseMode(bool verbose);

std::string GetURL() { return GetURLPart(CURLUPART_URL); }
Expand All @@ -61,7 +90,7 @@ class FileDownloader {
private:
CURL* curl_;
CURLU* url_;
curl_slist* connect_to_;
std::optional<ConnectTo> connect_to_;
std::string output_path_;
std::string file_path_;
std::array<char, CURL_ERROR_SIZE> error_;
Expand Down
3 changes: 2 additions & 1 deletion collector/lib/GetKernelObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ bool DownloadKernelObjectFromHostname(FileDownloader& downloader, const Json::Va

std::string server_hostname;
if (hostname.compare(0, port_offset, SNI_hostname) != 0) {
downloader.SetConnectTo(SNI_hostname, hostname);

const std::string server_port(hostname.substr(port_offset + 1));
server_hostname = SNI_hostname + ":" + server_port;
downloader.ConnectTo(SNI_hostname + ":" + server_port + ":" + hostname);
} else {
server_hostname = hostname;
}
Expand Down
10 changes: 6 additions & 4 deletions collector/test/FileDownloaderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,25 @@ TEST(FileDownloaderTest, EffectiveURLBasic) {

TEST(FileDownloaderTest, EffectiveURLConnectTo) {
std::string url = "https://sensor.stackrox.svc:443/some-file.o.gz";
std::string target = "sensor.stackrox.svc:443:sensor.rhacs-operator.svc:443";
std::string host = "sensor.stackrox.svc";
std::string connect_to = "sensor.rhacs-operator.svc:443";
std::string expected_url = "https://sensor.rhacs-operator.svc:443/some-file.o.gz";
FileDownloader fd;

fd.SetURL(url);
fd.ConnectTo(target);
fd.SetConnectTo(host, connect_to);

ASSERT_EQ(expected_url, fd.GetEffectiveURL());
}
TEST(FileDownloaderTest, EffectiveURLConnectToNoMatch) {
std::string url = "https://sensor.stackrox.svc:8443/some-file.o.gz";
std::string target = "sensor.stackrox.svc:443:sensor.rhacs-operator.svc:443";
std::string host = "sensor.stackrox.svc";
std::string connect_to = "sensor.rhacs-operator.svc:443";
std::string_view expected_url = url;
FileDownloader fd;

fd.SetURL(url);
fd.ConnectTo(target);
fd.SetConnectTo(host, connect_to);

ASSERT_EQ(expected_url, fd.GetEffectiveURL());
}
Expand Down

0 comments on commit 045c4c7

Please sign in to comment.