Skip to content

Commit

Permalink
create dense output layer separately
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Dec 31, 2023
1 parent d896c4c commit 3176f89
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
52 changes: 40 additions & 12 deletions include/fdeep/layers/multi_head_attention_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,30 @@ class multi_head_attention_layer : public layer
query_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 0, name + "_query_dense")),
value_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 2, name + "_value_dense")),
key_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 1, name + "_key_dense")),
output_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 3, name + "_output_dense"))
output_dense_(create_output_dense_layer(weights_and_biases, use_bias, name + "_output_dense"))
{
}
private:
// NOTE(review): this span is a flattened GitHub diff view, not compilable
// source -- several statements appear twice (pre-commit and post-commit
// variants interleaved) and part of the body is hidden behind the
// "Expand All" row below. Comments mark which variant appears to be which;
// confirm against the repository history.
//
// Purpose (post-commit reading): build one dense projection layer per
// attention head for the query (index 0), key (index 1) or value (index 2)
// tensors, slicing the stored weight and bias tensors width-wise into
// per-head pieces.
std::vector<dense_layer> create_dense_layers(
const tensors& weights_and_biases, bool use_bias, const std::size_t num_heads,
const std::size_t index, const std::string& name)
{
// Post-commit: index 3 (the output projection) is no longer handled here;
// see create_output_dense_layer.
assertion(index <= 2, "Invalid dense layer index.");

// With biases the stored tensors alternate weight/bias, so each logical
// slot occupies two entries; without biases, one.
const std::size_t index_factor = use_bias ? 2 : 1;
const tensor weights = weights_and_biases[index_factor * index];

// Pre-commit variant (diff residue): index 3 was still handled here and
// its weights needed a permutation first -- contradicts the assertion
// above, which is how the two variants can be told apart.
tensor weights = weights_and_biases[index_factor * index];
if (index == 3)
weights = permute_tensor(weights, {3, 1, 2});

const std::size_t units = weights.shape().depth_;
// NOTE(review): the bias initializer appears in two variants below (diff
// residue); the fallback builds a zero bias tensor when none were saved.
const tensor biases = use_bias ?

tensor biases = use_bias ?
weights_and_biases[index_factor * index + 1] :
tensor(index == 3 ? tensor_shape(num_heads, 1, units) : tensor_shape(num_heads, units), 0);
// Pre-commit variant (diff residue): index 3 used height slices.
const auto weights_per_head =
index == 3 ? tensor_to_tensors_height_slices(weights) : tensor_to_tensors_width_slices(weights);
const auto biases_per_head =
index == 3 ? tensor_to_tensors_height_slices(biases) : tensor_to_tensors_width_slices(biases);
tensor(index == 3 ? tensor_shape(units) : tensor_shape(num_heads, units), 0);

// Post-commit variant: plain width slicing only -- one slice per head.
const auto weights_per_head = tensor_to_tensors_width_slices(weights);
const auto biases_per_head = tensor_to_tensors_width_slices(biases);
assertion(weights_per_head.size() == num_heads, "Invalid weights for number of heads.");
assertion(biases_per_head.size() == num_heads, "Invalid biases for number of heads.");
// The per-head dense_layer construction is collapsed in this diff view
// (hidden behind the "Expand All" row below).
const std::vector<dense_layer> dense_layers =
Expand All @@ -60,6 +66,23 @@ class multi_head_attention_layer : public layer
fplus::enumerate(fplus::zip(weights_per_head, biases_per_head)));
return dense_layers;
}
// Builds the single dense layer that projects the concatenated per-head
// attention results back to the output dimension.
//
// weights_and_biases: flat list of saved parameter tensors; the output
//     projection's weights sit at slot index_factor * 3, with its biases
//     immediately after when use_bias is true.
// use_bias: whether saved biases exist; if not, a zero bias vector of
//     matching width is synthesized.
// name: base name for the created layer ("_output" is appended).
//
// Fix vs. previous revision: the unused per-head slicing of weights and
// biases (leftover from create_dense_layers) is removed -- this layer is
// shared across heads, so no slicing is needed -- and the locals are const.
dense_layer create_output_dense_layer(
    const tensors& weights_and_biases, bool use_bias, const std::string& name)
{
    // With biases the stored tensors alternate weight/bias, so each logical
    // slot occupies two entries; without biases, one.
    const std::size_t index_factor = use_bias ? 2 : 1;

    const tensor weights = weights_and_biases[index_factor * 3];

    // Output width of the projection.
    const std::size_t units = weights.shape().depth_;

    const tensor biases = use_bias ?
        weights_and_biases[index_factor * 3 + 1] :
        tensor(tensor_shape(units), 0);

    return dense_layer(name + "_output", units, *weights.as_vector(), *biases.as_vector());
}
// Returns only the bias tensors out of the flat list of saved parameters.
// When use_bias is set, fplus::unweave splits the list into two alternating
// sequences and the second one is taken; otherwise an empty list is
// returned. NOTE(review): this assumes the saved tensors alternate
// weight/bias -- confirm against the export/serialization side.
// (Closing brace of this function is hidden by the collapsed diff view.)
tensors extract_biases(const tensors& saved_weights, bool use_bias)
{
return use_bias ? fplus::unweave(saved_weights).second : tensors();
Expand Down Expand Up @@ -89,8 +112,7 @@ class multi_head_attention_layer : public layer
// https://gist.github.com/sevagh/b71d253a347a9b59c026580625452fc5
const tensor scores = dot_product_tensors(query, transpose(key), std::vector<int>({2, 1}), false);
const tensor distribution = softmax(scores);
const tensor output = dot_product_tensors(distribution, value, std::vector<int>({2, 1}), false);
return output_dense_[head_index].apply({output}).front(); // todo
return dot_product_tensors(distribution, value, std::vector<int>({2, 1}), false);
}
protected:
// Forward pass of the multi-head attention layer.
// NOTE(review): flattened diff view -- both the pre-commit return (first
// head only, marked "todo") and the post-commit multi-head body appear
// below, and the start of the function body is hidden behind the
// "Expand All" row.
tensors apply_impl(const tensors& input) const override
Expand All @@ -99,16 +121,22 @@ class multi_head_attention_layer : public layer
// input[0] = query, input[1] = value, optional input[2] = key;
// when no key is given it falls back to the value tensor.
const tensor query_raw = input[0];
const tensor value_raw = input[1];
const tensor key_raw = input.size() > 2 ? input[2] : value_raw;
// Pre-commit line (diff residue): only head 0 was evaluated.
return {apply_head(query_raw, value_raw, key_raw, 0)}; // todo: all
// Post-commit: evaluate every head, concatenate the results along the
// depth axis, then apply the single shared output projection.
const auto outputs = fplus::transform([&](const std::size_t head_idx)
{
return apply_head(query_raw, value_raw, key_raw, head_idx);
}, fplus::numbers<std::size_t>(0, num_heads_));
const tensor merged = concatenate_tensors_depth(outputs);
return output_dense_.apply({merged});
}
// Attention configuration -- the names mirror the Keras MultiHeadAttention
// parameters; presumably set in the constructor (not visible here).
std::size_t num_heads_;
std::size_t key_dim_;
std::size_t value_dim_;
std::vector<std::size_t> attention_axes_;
// todo: store each head as a separate object?
// One projection layer per head for the query/value/key inputs.
std::vector<dense_layer> query_dense_;
std::vector<dense_layer> value_dense_;
std::vector<dense_layer> key_dense_;
// NOTE(review): diff residue -- output_dense_ appears in both its
// pre-commit form (one output layer per head) and its post-commit form
// (a single shared output layer).
std::vector<dense_layer> output_dense_;
dense_layer output_dense_;
};

} } // namespace fdeep, namespace internal
3 changes: 3 additions & 0 deletions keras_export/generate_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,9 @@ def get_test_model_exhaustive():
# Two-input call (query, value) without biases; the layer defaults the key
# to the value input.
outputs.append(MultiHeadAttention(
num_heads=3, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
# Appears to be the case added by this commit: same setup but with
# use_bias=True, covering the biased weight layout.
outputs.append(MultiHeadAttention(
num_heads=3, key_dim=1, value_dim=None,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
# Three-input call (query, value, key) with a single head.
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))
Expand Down

0 comments on commit 3176f89

Please sign in to comment.