Skip to content

Commit

Permalink
Merge pull request #375 from WeiPhil/register_loss_fix
Browse files Browse the repository at this point in the history
Added interface to register loss from outside library
  • Loading branch information
Tom94 authored Oct 13, 2023
2 parents bd1f727 + c305956 commit 2121041
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
3 changes: 3 additions & 0 deletions include/tiny-cuda-nn/loss.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,7 @@ std::unique_ptr<Loss<T>> default_loss(const std::string& name) {

std::vector<std::string> builtin_losses();

template <typename T>
void register_loss(const std::string& name, const std::function<Loss<T>*(const json&)>& factory);

}
3 changes: 3 additions & 0 deletions src/loss.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ void register_loss(const std::string& name, const std::function<Loss<T>*(const j
register_loss(loss_factories<T>(), name, factory);
}

template void register_loss<float>(const std::string& name, const std::function<Loss<float>*(const json&)>& factory);
template void register_loss<__half>(const std::string& name, const std::function<Loss<__half>*(const json&)>& factory);

template <typename T>
Loss<T>* create_loss(const json& loss) {
std::string name = loss.value("otype", "RelativeL2");
Expand Down

0 comments on commit 2121041

Please sign in to comment.