apply dense layers to query, value and key
Dobiasd committed Dec 29, 2023
1 parent 212d609 commit 118f663
Showing 2 changed files with 42 additions and 14 deletions.
4 changes: 2 additions & 2 deletions include/fdeep/import_model.hpp
@@ -1084,15 +1084,15 @@ inline layer_ptr create_multi_head_attention_layer(
             create_vector<std::size_t, decltype(create_size_t)>, create_size_t),
         get_param(name, "weight_shapes"));
     const auto weight_values = create_vector<float_vec>(decode_floats, get_param(name, "weights"));
-    const auto weights = fplus::zip_with(
+    const auto weights_and_biases = fplus::zip_with(
         [](const std::vector<std::size_t>& shape, const float_vec& values) -> tensor
         {
             return tensor(
                 create_tensor_shape_from_dims(shape),
                 fplus::convert_container<float_vec>(values));
         }, weight_shapes, weight_values);
     return std::make_shared<multi_head_attention_layer>(name,
-        num_heads, key_dim, value_dim, use_bias, attention_axes, weights);
+        num_heads, key_dim, value_dim, use_bias, attention_axes, weights_and_biases);
 }

 inline std::string get_activation_type(const nlohmann::json& data)
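
As context for the next file, here is a minimal sketch (hypothetical helpers, not part of the commit) of the layout the layer code below appears to assume for this weights_and_biases vector: one kernel tensor per dense sub-layer (query, value, key, output), each followed by a bias tensor when use_bias is set, so the kernel of sub-layer i sits at index i * 2 with biases, or at index i without them.

// Hypothetical helpers illustrating the assumed layout of weights_and_biases:
// [W_query, (b_query,) W_value, (b_value,) W_key, (b_key,) W_output, (b_output)]
#include <cstddef>

// Index of the kernel tensor for dense sub-layer `index`
// (0 = query, 1 = value, 2 = key, 3 = output).
std::size_t kernel_index(std::size_t index, bool use_bias)
{
    const std::size_t index_factor = use_bias ? 2 : 1;
    return index_factor * index;
}

// Index of the matching bias tensor; only meaningful when use_bias is true.
std::size_t bias_index(std::size_t index)
{
    return 2 * index + 1;
}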
52 changes: 40 additions & 12 deletions include/fdeep/layers/multi_head_attention_layer.hpp
@@ -7,6 +7,7 @@
 #pragma once

 #include "fdeep/layers/layer.hpp"
+#include "fdeep/layers/dense_layer.hpp"
 #include "fdeep/layers/softmax_layer.hpp"

 #include <string>
@@ -20,41 +21,68 @@ class multi_head_attention_layer : public layer
     explicit multi_head_attention_layer(const std::string& name,
         std::size_t num_heads, std::size_t key_dim, std::size_t value_dim,
         bool use_bias, const std::vector<std::size_t>& attention_axes,
-        const std::vector<tensor>& saved_weights)
+        const std::vector<tensor>& weights_and_biases)
         : layer(name), num_heads_(num_heads), key_dim_(key_dim),
         value_dim_(value_dim), attention_axes_(attention_axes),
-        weights_(extract_weights(saved_weights, use_bias)),
-        biases_(extract_biases(saved_weights, use_bias))
+        query_dense_(create_dense_layer(weights_and_biases, use_bias, 0, name + "_query_dense")),
+        value_dense_(create_dense_layer(weights_and_biases, use_bias, 1, name + "_value_dense")),
+        key_dense_(create_dense_layer(weights_and_biases, use_bias, 2, name + "_key_dense")),
+        output_dense_(create_dense_layer(weights_and_biases, use_bias, 3, name + "_output_dense"))
     {
     }
 private:
-    tensors extract_weights(const tensors& saved_weights, bool use_bias)
+    dense_layer create_dense_layer(
+        const tensors& weights_and_biases, bool use_bias,
+        std::size_t index, const std::string& name)
     {
-        return use_bias ? fplus::unweave(saved_weights).first : saved_weights;
+        const std::size_t index_factor = use_bias ? 2 : 1;
+        const tensor weights = weights_and_biases[index_factor * index];
+        const std::size_t n = weights.shape().width_ * weights.shape().depth_;
+        const tensor biases = use_bias ?
+            weights_and_biases[index_factor * index + 1] :
+            tensor(tensor_shape(n), 1);
+        return dense_layer(name, n, *weights.as_vector(), *biases.as_vector());
     }
     tensors extract_biases(const tensors& saved_weights, bool use_bias)
     {
-        return use_bias ? fplus::unweave(saved_weights).second : tensors(); // todo: create biases with zeroes in right shape
+        return use_bias ? fplus::unweave(saved_weights).second : tensors();
     }
 protected:
     tensors apply_impl(const tensors& input) const override
     {
         assertion(input.size() == 2 || input.size() == 3, "Invalid number of inputs for MultiHeadAttention layer.");
-        //const tensor& query = input[0];
-        //const tensor& value = input[1];
-        //const tensor& key = input.size() > 2 ? input[2] : value;
+        const tensor query_raw = input[0];
+        const tensor value_raw = input[1];
+        const tensor key_raw = input.size() > 2 ? input[2] : value_raw;
+        const tensor query = query_dense_.apply({query_raw}).front();
+        const tensor value = value_dense_.apply({value_raw}).front();
+        const tensor key = key_dense_.apply({key_raw}).front();
+        assertion(
+            query.shape().rank() == 2 &&
+            value.shape().rank() == 2 &&
+            key.shape().rank() == 2 &&
+            query.shape().depth_ == value.shape().depth_ &&
+            query.shape().depth_ == key.shape().depth_ &&
+            value.shape().width_ == key.shape().width_,
+            "Invalid shapes; need a query tensor of shape (B, T, dim) and a value/key tensor of shape (B, S, dim)."
+        );
         // https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853
         // https://dmol.pub/dl/attention.html#multi-head-attention-block
         // https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/attention/multi_head_attention.py
         // https://gist.github.com/sevagh/b71d253a347a9b59c026580625452fc5
-        return input;
+        const tensor scores = dot_product_tensors(query, transpose(key), std::vector<int>({2, 1}), false);
+        const tensor distribution = softmax(scores);
+        const tensor output = dot_product_tensors(distribution, value, std::vector<int>({2, 1}), false);
+        return output_dense_.apply({output});
     }
     std::size_t num_heads_;
     std::size_t key_dim_;
     std::size_t value_dim_;
     std::vector<std::size_t> attention_axes_;
-    std::vector<tensor> weights_;
-    std::vector<tensor> biases_;
+    dense_layer query_dense_;
+    dense_layer value_dense_;
+    dense_layer key_dense_;
+    dense_layer output_dense_;
 };

 } } // namespace fdeep, namespace internal
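
For readers who want to see the attention core in isolation: below is a minimal, self-contained sketch (plain C++ with std::vector matrices, no fdeep types; not part of the commit) of what the new apply_impl computes per call — scores = Q·Kᵀ, a row-wise softmax over the key positions, then the product with V — before the result is fed through the output dense layer. Note that neither this sketch nor the code above applies an explicit 1/sqrt(key_dim) scaling at this stage.

// Sketch of the attention core used in apply_impl: softmax(Q * K^T) * V, single head.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<float>>; // row-major

// a * b^T; a is (T, d), b is (S, d), result is (T, S).
Matrix matmul_bt(const Matrix& a, const Matrix& b)
{
    Matrix out(a.size(), std::vector<float>(b.size(), 0.0f));
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            for (std::size_t k = 0; k < a[i].size(); ++k)
                out[i][j] += a[i][k] * b[j][k];
    return out;
}

// a * b; a is (T, S), b is (S, value_dim), result is (T, value_dim).
Matrix matmul(const Matrix& a, const Matrix& b)
{
    Matrix out(a.size(), std::vector<float>(b[0].size(), 0.0f));
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t k = 0; k < b.size(); ++k)
            for (std::size_t j = 0; j < b[0].size(); ++j)
                out[i][j] += a[i][k] * b[k][j];
    return out;
}

// In-place row-wise softmax, as in the scores -> distribution step.
void softmax_rows(Matrix& m)
{
    for (auto& row : m)
    {
        float max_val = row[0];
        for (float x : row) max_val = std::max(max_val, x);
        float sum = 0.0f;
        for (float& x : row) { x = std::exp(x - max_val); sum += x; }
        for (float& x : row) x /= sum;
    }
}

// attention(Q, K, V) = softmax(Q * K^T) * V
Matrix attention(const Matrix& q, const Matrix& k, const Matrix& v)
{
    Matrix scores = matmul_bt(q, k); // (T, S)
    softmax_rows(scores);            // distribution over key positions
    return matmul(scores, v);        // (T, value_dim)
}

int main()
{
    const Matrix q = {{1.0f, 0.0f}, {0.0f, 1.0f}}; // (T=2, d=2)
    const Matrix k = {{1.0f, 0.0f}, {0.0f, 1.0f}}; // (S=2, d=2)
    const Matrix v = {{1.0f, 2.0f}, {3.0f, 4.0f}}; // (S=2, value_dim=2)
    const Matrix out = attention(q, k, v);
    for (const auto& row : out)
    {
        for (float x : row) std::printf("%f ", x);
        std::printf("\n");
    }
}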
