API improvement of use_softmax and zero_infinity #180

Open · wants to merge 1 commit into master
6 changes: 6 additions & 0 deletions include/ctc.h
@@ -70,6 +70,12 @@ struct ctcOptions {

/// the label value/index that the CTC calculation should use as the blank label
int blank_label;

/// indicates whether to apply a softmax to the input activations first.
bool use_softmax = true;

/// indicates whether to zero infinite losses and their associated gradients.
bool zero_infinity = false;
};
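Both new fields carry defaults, so existing callers compile and behave exactly as before; only code that opts in sees a change. A minimal caller-side sketch (illustrative values; loc, num_threads, and blank_label are the pre-existing ctcOptions fields):

ctcOptions options{};
options.loc = CTC_CPU;        // pre-existing field: use the CPU implementation
options.num_threads = 4;      // pre-existing field: OpenMP thread count
options.blank_label = 0;      // pre-existing field: index of the blank symbol
options.use_softmax = false;  // new: activations are already log probabilities
options.zero_infinity = true; // new: zero infinite losses and their gradients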

/** Compute the connectionist temporal classification loss between
68 changes: 60 additions & 8 deletions include/detail/cpu_ctc.h
@@ -18,10 +18,11 @@ class CpuCTC {
public:
// Noncopyable
CpuCTC(int alphabet_size, int minibatch, void* workspace, int num_threads,
int blank_label, bool use_softmax, bool zero_infinity) :
alphabet_size_(alphabet_size), minibatch_(minibatch),
num_threads_(num_threads), workspace_(workspace),
blank_label_(blank_label), use_softmax_(use_softmax),
zero_infinity_(zero_infinity) {
#if defined(CTC_DISABLE_OMP) || defined(APPLE)
#else
if (num_threads > 0) {
@@ -75,10 +76,15 @@ class CpuCTC {
int num_threads_;
int blank_label_;
void* workspace_;
bool use_softmax_;
bool zero_infinity_;

void softmax(const ProbT* const activations, ProbT* probs,
const int* const input_lengths);

void exp(const ProbT* const activations, ProbT* probs,
const int* const input_lengths);

std::tuple<ProbT, bool>
cost_and_grad_kernel(ProbT *grad, const ProbT* const probs,
const int* const labels, int T, int L,
@@ -181,14 +187,34 @@ CpuCTC<ProbT>::softmax(const ProbT* const activations, ProbT* probs,

for(int r = 0; r < alphabet_size_; ++r) {
probs[r + col_offset] /= denom;
// if (probs[r + col_offset] < min_T) {
// probs[r + col_offset] = min_T;
// }
}
}
}
}

template<typename ProbT>
void
CpuCTC<ProbT>::exp(const ProbT* const activations, ProbT* probs,
const int* const input_lengths) {
// ProbT min_T = std::numeric_limits<ProbT>::min(); // unused while the clamp below stays disabled

#pragma omp parallel for
for (int mb = 0; mb < minibatch_; ++mb) {
for(int c = 0; c < input_lengths[mb]; ++c) {
int col_offset = (mb + minibatch_ * c) * alphabet_size_;
for(int r = 0; r < alphabet_size_; ++r) {
probs[r + col_offset] = std::exp(activations[r + col_offset]);
// if (probs[r + col_offset] < min_T) {
// probs[r + col_offset] = min_T;
// }
}
}
}
}
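Why a plain exponential is enough here: the downstream CTC recursions consume log(probs), so exponentiating first makes the subsequent log a no-op and the activations are interpreted directly as log probabilities. A standalone sketch of that identity (illustrative, not part of the diff):

#include <cassert>
#include <cmath>

int main() {
    double x = -2.3;        // a log probability supplied by the caller
    double p = std::exp(x); // what CpuCTC<ProbT>::exp() stores into probs
    assert(std::fabs(std::log(p) - x) < 1e-12); // log(exp(x)) recovers x
    return 0;
}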

template<typename ProbT>
std::tuple<ProbT, bool>
CpuCTC<ProbT>::cost_and_grad_kernel(ProbT *grad, const ProbT* const probs,
@@ -417,7 +443,12 @@ CpuCTC<ProbT>::cost_and_grad(const ProbT* const activations,
//labels w/blanks, e_inc, s_inc
per_minibatch_bytes += 3 * sizeof(int) * maxS;

if (use_softmax_) {
softmax(activations, probs, input_lengths);
} else {
// the downstream computation takes the log of probs, so exponentiating the
// activations here means they are consumed directly as log probabilities
exp(activations, probs, input_lengths);
}

#pragma omp parallel for
for (int mb = 0; mb < minibatch_; ++mb) {
@@ -432,6 +463,18 @@
flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0),
T, L, mb,
bytes_used + mb * per_minibatch_bytes);
if (zero_infinity_) {
if (costs[mb] == -ctc_helper::neg_inf<ProbT>()) {
for(int c = 0; c < input_lengths[mb]; ++c) {
int col_offset = (mb + minibatch_ * c) * alphabet_size_;
for(int r = 0; r < alphabet_size_; ++r) {
grads[r + col_offset] = ProbT(0);
}
}
costs[mb] = ProbT(0);
}
}

}

return CTC_STATUS_SUCCESS;
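Per minibatch entry, the zeroing above has the semantics of this standalone sketch (hypothetical helper, not the PR's code): a sample whose loss is infinite, for example because the label sequence cannot fit into the available input frames, is neutralized instead of propagating inf or NaN into the update.

#include <algorithm>
#include <cmath>
#include <vector>

template <typename T>
void zero_infinite_sample(std::vector<T>& grad_slice, T& cost) {
    if (std::isinf(cost)) { // loss overflowed to +infinity
        std::fill(grad_slice.begin(), grad_slice.end(), T(0)); // clear gradients
        cost = T(0); // and report a zero loss for this sample
    }
}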
@@ -475,7 +518,12 @@ ctcStatus_t CpuCTC<ProbT>::score_forward(const ProbT* const activations,
//labels w/blanks, e_inc, s_inc
per_minibatch_bytes += 3 * sizeof(int) * maxS;

if (use_softmax_) {
softmax(activations, probs, input_lengths);
} else {
// the downstream computation takes the log of probs, so exponentiating the
// activations here means they are consumed directly as log probabilities
exp(activations, probs, input_lengths);
}

#pragma omp parallel for
for (int mb = 0; mb < minibatch_; ++mb) {
@@ -495,7 +543,11 @@
ctcm.e_inc, ctcm.s_inc, ctcm.labels_w_blanks,
ctcm.alphas);
}

if (zero_infinity_) {
if (costs[mb] == -ctc_helper::neg_inf<ProbT>()) {
costs[mb] = ProbT(0);
}
}
}

return CTC_STATUS_SUCCESS;
36 changes: 29 additions & 7 deletions include/detail/gpu_ctc.h
@@ -11,10 +11,13 @@ class GpuCTC {
int minibatch,
void *workspace,
GPUstream stream,
int blank_label,
bool use_softmax,
bool zero_infinity) :
out_dim_(alphabet_size), minibatch_(minibatch),
gpu_workspace_(workspace), stream_(stream),
blank_label_(blank_label), use_softmax_(use_softmax),
zero_infinity_(zero_infinity) {};

// Noncopyable
GpuCTC(const GpuCTC&) = delete;
@@ -77,6 +80,8 @@ class GpuCTC {

int out_dim_; // Number of characters plus blank
int minibatch_;
bool use_softmax_;
bool zero_infinity_;

int S_;
int T_;
@@ -437,7 +442,7 @@ GpuCTC<ProbT>::compute_probs(const ProbT* const activations) {

prepare_stable_SM_kernel<ProbT, VT> <<< grid_size, NT, 0, stream_>>>
(ctc_helper::identity<ProbT>(), probs_,
denoms_, out_dim_, num_elements, use_softmax_);

// Reduce along columns to calculate denominator
ctc_status =
@@ -449,10 +454,10 @@
// Kernel launch to calculate probabilities
compute_probs_kernel<ProbT, VT><<<grid_size, NT, 0, stream_>>>
(ctc_helper::exponential<ProbT>(), probs_,
denoms_, out_dim_, num_elements, use_softmax_);

// truncate_probs_kernel<ProbT, VT><<<grid_size, NT, 0, stream_>>>
// (probs_, num_elements);

return CTC_STATUS_SUCCESS;
}
@@ -494,7 +499,24 @@ GpuCTC<ProbT>::compute_cost_and_score(const ProbT* const activations,
sizeof(ProbT) * minibatch_,
cudaMemcpyDeviceToHost, stream_);
#endif


if (zero_infinity_) {
// zero out infinite costs and their associated gradients
const int NT = 128;
const int VT = 1;
const int NV = NT * VT;
for (int mb = 0; mb < minibatch_; ++mb) {
if (costs[mb] == -ctc_helper::neg_inf<ProbT>()) {
int loc_T = input_lengths[mb];
int num_elements = out_dim_ * loc_T;
int grid_size = ctc_helper::div_up(num_elements, NV);
zero_infinity_kernel<ProbT, VT><<<grid_size, NT, 0, stream_>>>
(grads, mb, out_dim_, minibatch_, num_elements);
costs[mb] = ProbT(0);
}
}
}
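// Illustrative launch arithmetic for the per-sample relaunch above
// (hypothetical sizes, not taken from this PR): with NT = 128 and VT = 1,
// a sample with out_dim_ = 29 and loc_T = 100 has
// num_elements = 29 * 100 = 2900, so
// grid_size = div_up(2900, 128) = 23 blocks of 128 threads.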

#ifdef __HIPCC__
cuda_status_sync = hipStreamSynchronize(stream_);
#else
62 changes: 50 additions & 12 deletions include/detail/gpu_ctc_kernels.h
@@ -455,18 +455,30 @@ template <typename ProbT, int VT = 1, typename Op>
__global__ void compute_probs_kernel(Op f, ProbT* probs,
const ProbT* const denom,
int alphabet_size,
int count,
bool use_softmax) {

int idx = blockDim.x * blockIdx.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
if (use_softmax) {
#pragma unroll
for(int i = 0; i < VT; i++) {
if (idx < count) {
const int column_idx = idx / alphabet_size;
probs[idx] = f(probs[idx]) / denom[column_idx];
}
idx += stride;
}
} else {
#pragma unroll
for(int i = 0; i < VT; i++) {
if (idx < count) {
probs[idx] = f(probs[idx]);
}
idx += stride;
}
}

}
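A note on how the two kernels compose on the GPU path, as comments (inferred from the calls in compute_probs above, so hedged accordingly):

// Effective pipeline when use_softmax == false: probs_ starts as a copy of
// the activations; prepare_stable_SM_kernel leaves it untouched, the column
// reduction still runs but its denominator is never consumed, and this kernel
// applies f = exponential element-wise. Net effect: probs = exp(activations),
// mirroring CpuCTC<ProbT>::exp() on the CPU path.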

template <typename ProbT, int VT = 1>
@@ -490,16 +502,42 @@ template <typename ProbT, int VT = 1, typename Op>
__global__ void prepare_stable_SM_kernel(Op f, ProbT* probs,
const ProbT* const col_max,
int alphabet_size,
int count,
bool use_softmax) {

int idx = blockDim.x * blockIdx.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;

if (use_softmax) {
#pragma unroll
for(int i = 0; i < VT; i++) {
if (idx < count) {
const int column_idx = idx / alphabet_size;
probs[idx] = f(probs[idx] - col_max[column_idx]);
}
idx += stride;
}
}
}

template <typename ProbT, int VT = 1>
__global__ void zero_infinity_kernel(ProbT* grads,
int mb,
int alphabet_size,
int minibatch,
int count) {

int idx = blockDim.x * blockIdx.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int col = 0;
int row = 0;
#pragma unroll
for(int i = 0; i < VT; i++) {
if (idx < count) {
col = idx / alphabet_size;
row = idx % alphabet_size;
grads[(col * minibatch + mb) * alphabet_size + row] = ProbT(0);
}
idx += stride;
}
}
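The index mapping assumes warp-ctc's time-major gradient layout, where element (t, mb, r) of a (T, minibatch, alphabet_size) tensor lives at offset (t * minibatch + mb) * alphabet_size + r. A worked example with hypothetical sizes:

// alphabet_size = 29, minibatch = 16, zeroing sample mb = 3, thread idx = 61:
//   col = 61 / 29 = 2  -> timestep t = 2 of this sample
//   row = 61 % 29 = 3  -> label index r = 3
//   grads offset = (2 * 16 + 3) * 29 + 3 = 35 * 29 + 3 = 1018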
8 changes: 4 additions & 4 deletions src/ctc_entrypoint.cpp
@@ -59,7 +59,7 @@ ctcStatus_t compute_ctc_loss(const float* const activations,

if (options.loc == CTC_CPU) {
CpuCTC<float> ctc(alphabet_size, minibatch, workspace, options.num_threads,
options.blank_label, options.use_softmax, options.zero_infinity);

if (gradients != NULL)
return ctc.cost_and_grad(activations, gradients,
@@ -72,7 +72,7 @@
} else if (options.loc == CTC_GPU) {
#if (defined(__HIPCC__) || defined(__CUDACC__))
GpuCTC<float> ctc(alphabet_size, minibatch, workspace, options.stream,
options.blank_label, options.use_softmax, options.zero_infinity);

if (gradients != NULL)
return ctc.cost_and_grad(activations, gradients, costs,
@@ -112,7 +112,7 @@ ctcStatus_t compute_ctc_loss_double(const double* const activations,

if (options.loc == CTC_CPU) {
CpuCTC<double> ctc(alphabet_size, minibatch, workspace, options.num_threads,
options.blank_label, options.use_softmax, options.zero_infinity);

if (gradients != NULL)
return ctc.cost_and_grad(activations, gradients,
@@ -125,7 +125,7 @@
} else if (options.loc == CTC_GPU) {
#if (defined(__HIPCC__) || defined(__CUDACC__))
GpuCTC<double> ctc(alphabet_size, minibatch, workspace, options.stream,
options.blank_label, options.use_softmax, options.zero_infinity);

if (gradients != NULL)
return ctc.cost_and_grad(activations, gradients, costs,
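Threading the new options through the public API end to end might look like the following sketch (hedged: get_workspace_size and compute_ctc_loss are called with the library's existing signatures, and all sizes and label values here are hypothetical):

#include <vector>
#include "ctc.h"

int main() {
    int alphabet_size = 29, minibatch = 1; // 28 symbols + blank, one utterance
    std::vector<int> input_lengths = {50}; // 50 time frames
    std::vector<int> label_lengths = {5};
    std::vector<int> flat_labels = {8, 5, 12, 12, 15}; // example non-blank ids

    ctcOptions options{};
    options.loc = CTC_CPU;
    options.num_threads = 1;
    options.blank_label = 0;
    options.use_softmax = false;  // activations are already log probabilities
    options.zero_infinity = true; // neutralize samples with infinite loss

    size_t workspace_bytes;
    get_workspace_size(label_lengths.data(), input_lengths.data(),
                       alphabet_size, minibatch, options, &workspace_bytes);

    std::vector<char> workspace(workspace_bytes);
    std::vector<float> activations(50 * minibatch * alphabet_size); // model output
    std::vector<float> gradients(activations.size());
    float cost;
    compute_ctc_loss(activations.data(), gradients.data(),
                     flat_labels.data(), label_lengths.data(),
                     input_lengths.data(), alphabet_size, minibatch,
                     &cost, workspace.data(), options);
    return 0;
}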