Skip to content

Commit

Permalink
[dawn][native] Fixes TSAN race from GetCompilationInfo.
Browse files Browse the repository at this point in the history
- The race happens when multiple threads Complete the
  CompilationInfoEvent. The call to
  OwnedCompilationMessages::GetCompilationInfo was writing
  to the owned copy of CompilationInfo in both threads. Note
  that in practice, the competing threads would have written
  the same things.
- Adds an atomic bool to gate writing to the internal
  CompilationInfo struct that should only ever be written
  once.
- Updates some of the asserts to use the new bool to be
  explicit.

Bug: 373845830
Change-Id: I3998a276bc210bc5bf5bccfe31c44b4e01de4472
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/211947
Reviewed-by: Shrek Shao <[email protected]>
Auto-Submit: Loko Kung <[email protected]>
Reviewed-by: Geoff Lang <[email protected]>
Commit-Queue: Geoff Lang <[email protected]>
  • Loading branch information
lokokung authored and Dawn LUCI CQ committed Nov 1, 2024
1 parent 0eee090 commit aa806c5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
24 changes: 13 additions & 11 deletions src/dawn/native/CompilationMessages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,7 @@ ResultOrError<uint64_t> CountUTF16CodeUnitsFromUTF8String(const std::string_view
return numberOfUTF16CodeUnits;
}

OwnedCompilationMessages::OwnedCompilationMessages() {
mCompilationInfo.nextInChain = 0;
mCompilationInfo.messageCount = 0;
mCompilationInfo.messages = nullptr;
}
OwnedCompilationMessages::OwnedCompilationMessages() = default;

OwnedCompilationMessages::~OwnedCompilationMessages() = default;

Expand Down Expand Up @@ -193,7 +189,7 @@ MaybeError OwnedCompilationMessages::AddMessage(const tint::diag::Diagnostic& di

void OwnedCompilationMessages::AddMessage(const CompilationMessage& message) {
// Cannot add messages after GetCompilationInfo has been called.
DAWN_ASSERT(mCompilationInfo.messages == nullptr);
DAWN_ASSERT(!mCompilationInfo->has_value());

DAWN_ASSERT(message.nextInChain == nullptr);

Expand All @@ -208,7 +204,7 @@ void OwnedCompilationMessages::AddMessage(const CompilationMessage& message) {

MaybeError OwnedCompilationMessages::AddMessages(const tint::diag::List& diagnostics) {
// Cannot add messages after GetCompilationInfo has been called.
DAWN_ASSERT(mCompilationInfo.messages == nullptr);
DAWN_ASSERT(!mCompilationInfo->has_value());

for (const auto& diag : diagnostics) {
DAWN_TRY(AddMessage(diag));
Expand All @@ -221,17 +217,23 @@ MaybeError OwnedCompilationMessages::AddMessages(const tint::diag::List& diagnos

void OwnedCompilationMessages::ClearMessages() {
// Cannot clear messages after GetCompilationInfo has been called.
DAWN_ASSERT(mCompilationInfo.messages == nullptr);
DAWN_ASSERT(!mCompilationInfo->has_value());

mMessageStrings.clear();
mMessages.clear();
}

const CompilationInfo* OwnedCompilationMessages::GetCompilationInfo() {
mCompilationInfo.messageCount = mMessages.size();
mCompilationInfo.messages = mMessages.data();
return mCompilationInfo.Use([&](auto info) {
if (info->has_value()) {
return &info->value();
}

return &mCompilationInfo;
(*info).emplace();
(*info)->messageCount = mMessages.size();
(*info)->messages = mMessages.data();
return &info->value();
});
}

const std::vector<std::string>& OwnedCompilationMessages::GetFormattedTintMessages() const {
Expand Down
6 changes: 3 additions & 3 deletions src/dawn/native/CompilationMessages.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
#include <string>
#include <vector>

#include "dawn/common/MutexProtected.h"
#include "dawn/common/NonCopyable.h"
#include "dawn/native/Error.h"
#include "dawn/native/dawn_platform.h"

#include "dawn/common/NonCopyable.h"

namespace tint::diag {
class Diagnostic;
class List;
Expand Down Expand Up @@ -76,7 +76,7 @@ class OwnedCompilationMessages : public NonCopyable {
void AddMessage(const CompilationMessage& message);
void AddFormattedTintMessages(const tint::diag::List& diagnostics);

CompilationInfo mCompilationInfo;
MutexProtected<std::optional<CompilationInfo>> mCompilationInfo = std::nullopt;
std::vector<std::unique_ptr<std::string>> mMessageStrings;
std::vector<CompilationMessage> mMessages;
std::vector<std::string> mFormattedTintMessages;
Expand Down

0 comments on commit aa806c5

Please sign in to comment.