Skip to content

Commit

Permalink
- Isolated torch-related code
Browse files Browse the repository at this point in the history
- Removed componentDisabled
  • Loading branch information
Zhang committed Oct 11, 2024
1 parent 48657cb commit 2a7cbcb
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 71 deletions.
2 changes: 0 additions & 2 deletions src/colvar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,9 +894,7 @@ void colvar::define_component_types()

add_component_type<neuralNetwork>("neural network CV for other CVs", "neuralNetwork");

#ifdef COLVARS_TORCH
add_component_type<torchANN>("CV defined by PyTorch artifical neural network models", "torchANN");
#endif

if (proxy->check_volmaps_available() == COLVARS_OK) {
add_component_type<map_total>("total value of atomic map", "mapTotal");
Expand Down
3 changes: 0 additions & 3 deletions src/colvar.h
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,6 @@ class colvar : public colvarparse, public colvardeps {
// collective variable component base class
class cvc;

// placeholder/stub for unavailable functionality
class componentDisabled;

// list of available collective variable components

// scalar colvar components
Expand Down
17 changes: 2 additions & 15 deletions src/colvarcomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ int colvar::cvc::init(std::string const &conf)
if (period != 0.0) {
if (!is_available(f_cvc_periodic)) {
error_code |=
cvm::error("Error: invalid use of period and/or wrapAround in a \"" +
cvm::error("Error: invalid use of period and/or "
"wrapAround in a \"" +
function_type() + "\" component.\n" + "Period: " + cvm::to_str(period) +
" wrapAround: " + cvm::to_str(wrap_center),
COLVARS_INPUT_ERROR);
Expand Down Expand Up @@ -706,20 +707,6 @@ void colvar::cvc::wrap(colvarvalue &x_unwrapped) const
}



colvar::componentDisabled::componentDisabled() {}

colvar::componentDisabled::~componentDisabled() {}

int colvar::componentDisabled::init(std::string const & /* conf */)
{
return cvm::error("Error: components of type " + function_type() +
" are not enabled in the current build",
COLVARS_NOT_IMPLEMENTED);
}



// Static members

std::vector<colvardeps::feature *> colvar::cvc::cvc_features;
49 changes: 0 additions & 49 deletions src/colvarcomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@
#include "colvar.h"
#include "colvar_geometricpath.h"

#ifdef COLVARS_TORCH
#include <torch/torch.h>
#include <torch/script.h>
#endif


/// \brief Colvar component (base class for collective variables)
///
Expand Down Expand Up @@ -320,16 +315,6 @@ inline colvarvalue const & colvar::cvc::Jacobian_derivative() const
}


/// \brief Colvar component class for a feature not currently available
class colvar::componentDisabled
: public colvar::cvc
{
public:
componentDisabled();
virtual ~componentDisabled();
int init(std::string const & /* conf */);
};


/// \brief Colvar component: distance between the centers of mass of
/// two groups (colvarvalue::type_scalar type, range [0:*))
Expand Down Expand Up @@ -1547,40 +1532,6 @@ class colvar::neuralNetwork
virtual void wrap(colvarvalue &x_unwrapped) const;
};

#ifdef COLVARS_TORCH
// only when LibTorch is available
class colvar::torchANN
: public colvar::linearCombination
{
protected:
torch::jit::script::Module nn;
/// the index of nn output component
size_t m_output_index;
bool use_double_input;
//bool use_gpu;
// 1d tensor, concatenation of values of sub-cvcs
torch::Tensor input_tensor;
torch::Tensor nn_outputs;
torch::Tensor input_grad;
// record the initial index of of sub-cvcs in input_tensor
std::vector<int> cvc_indices;
public:
torchANN();
virtual ~torchANN();
virtual int init(std::string const &conf);
virtual void calc_value();
virtual void calc_gradients();
virtual void apply_force(colvarvalue const &force);
};
#else
class colvar::torchANN
: public colvar::componentDisabled
{
public:
torchANN();
virtual ~torchANN();
};
#endif // COLVARS_TORCH checking

// \brief Colvar component: total value of a scalar map
// (usually implemented as a grid by the simulation engine)
Expand Down
12 changes: 10 additions & 2 deletions src/colvarcomp_torchann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "colvarparse.h"
#include "colvarvalue.h"

#include "colvarcomp_torchann.h"

#ifdef COLVARS_TORCH

Expand Down Expand Up @@ -202,14 +203,21 @@ void colvar::torchANN::apply_force(colvarvalue const &force) {

#else


colvar::torchANN::torchANN()
{
set_function_type("torchANN");
}


colvar::torchANN::~torchANN() {}

int colvar::torchANN::init(std::string const &conf) {

return cvm::error(
"torchANN requires the libtorch library, but it is not enabled during compilation.\n"
"Please refer to the Compilation Notes section of the Colvars manual for more "
"information.\n",
COLVARS_NOT_IMPLEMENTED);

}

#endif
63 changes: 63 additions & 0 deletions src/colvarcomp_torchann.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// -*- c++ -*-

// This file is part of the Collective Variables module (Colvars).
// The original version of Colvars and its updates are located at:
// https://github.com/Colvars/colvars
// Please update all Colvars source files before making any changes.
// If you wish to distribute your changes, please submit them to the
// Colvars repository at GitHub.
//
#ifndef COLVARCOMP_TORCH_H
#define COLVARCOMP_TORCH_H

// Declaration of torchann

#include <memory>

#include "colvar.h"
#include "colvarcomp.h"
#include "colvarmodule.h"

#ifdef COLVARS_TORCH

#include <torch/torch.h>
#include <torch/script.h>

class colvar::torchANN
: public colvar::linearCombination
{
protected:
torch::jit::script::Module nn;
/// the index of nn output component
size_t m_output_index;
bool use_double_input;
//bool use_gpu;
// 1d tensor, concatenation of values of sub-cvcs
torch::Tensor input_tensor;
torch::Tensor nn_outputs;
torch::Tensor input_grad;
// record the initial index of of sub-cvcs in input_tensor
std::vector<int> cvc_indices;
public:
torchANN();
virtual ~torchANN();
virtual int init(std::string const &conf);
virtual void calc_value();
virtual void calc_gradients();
virtual void apply_force(colvarvalue const &force);
};

#else

class colvar::torchANN
: public colvar::cvc
{
public:
torchANN();
virtual ~torchANN();
virtual int init(std::string const &conf);
};
#endif // COLVARS_TORCH checking

#endif

0 comments on commit 2a7cbcb

Please sign in to comment.