add graph-based config and model-- ultragcn #251

Open · wants to merge 1 commit into base: master
601 changes: 601 additions & 0 deletions easy_rec/python/input/graph_input.py

Large diffs are not rendered by default.

122 changes: 122 additions & 0 deletions easy_rec/python/model/ultragcn.py
@@ -0,0 +1,122 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf

from easy_rec.python.layers import dnn
from easy_rec.python.model.easy_rec_model import EasyRecModel

from easy_rec.python.protos.ultragcn_pb2 import ULTRAGCN as ULTRAGCNConfig # NOQA

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class ULTRAGCN(EasyRecModel):

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(ULTRAGCN, self).__init__(model_config, feature_configs, features,
                                   labels, is_training)
    self._model_config = model_config.ultragcn
    assert isinstance(self._model_config, ULTRAGCNConfig)
    self._user_num = self._model_config.user_num
    self._item_num = self._model_config.item_num
    self._emb_dim = self._model_config.output_dim
    self._i2i_weight = self._model_config.i2i_weight
    self._neg_weight = self._model_config.neg_weight
    self._l2_weight = self._model_config.l2_weight
    self._user_emb = None
    self._item_emb = None

    if features.get('features') is not None:
      self._user_ids = features.get('features')[0]
      self._user_degrees = features.get('features')[1]
      self._item_ids = features.get('features')[2]
      self._item_degrees = features.get('features')[3]
      self._nbr_ids = features.get('features')[4]
      self._nbr_weights = features.get('features')[5]
      self._neg_ids = features.get('features')[6]
    else:
      self._user_ids = features.get('id')
Collaborator: Are all of these reading the same input field named 'id'?

      self._user_degrees = None
      self._item_ids = features.get('id')
      self._item_degrees = None
      self._nbr_ids = features.get('id')
      self._nbr_weights = None
      self._neg_ids = features.get('id')

    _user_emb = tf.get_variable("user_emb",
                                [self._user_num, self._emb_dim],
                                trainable=True)
    _item_emb = tf.get_variable("item_emb",
                                [self._item_num, self._emb_dim],
                                trainable=True)

    self._user_emb = tf.convert_to_tensor(_user_emb)
    self._item_emb = tf.convert_to_tensor(_item_emb)

  def build_predict_graph(self):
    user_emb = tf.nn.embedding_lookup(self._user_emb, self._user_ids)
    item_emb = tf.nn.embedding_lookup(self._item_emb, self._item_ids)
    nbr_emb = tf.nn.embedding_lookup(self._item_emb, self._nbr_ids)
    neg_emb = tf.nn.embedding_lookup(self._item_emb, self._neg_ids)
    self._prediction_dict['user_emb'] = user_emb
    self._prediction_dict['item_emb'] = item_emb
    self._prediction_dict['nbr_emb'] = nbr_emb
    self._prediction_dict['neg_emb'] = neg_emb
    self._prediction_dict['user_embedding'] = tf.reduce_join(
        tf.as_string(user_emb), axis=-1, separator=',')
    self._prediction_dict['item_embedding'] = tf.reduce_join(
        tf.as_string(item_emb), axis=-1, separator=',')

    return self._prediction_dict

  def build_loss_graph(self):
    # UltraGCN base u2i loss
    pos_logit = tf.reduce_sum(
        self._prediction_dict['user_emb'] * self._prediction_dict['item_emb'],
        axis=-1)
    true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(pos_logit), logits=pos_logit)
    neg_logit = tf.reduce_sum(
        tf.expand_dims(self._prediction_dict['user_emb'], axis=1) *
        self._prediction_dict['neg_emb'],
        axis=-1)
    negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(neg_logit), logits=neg_logit)
    loss_u2i = tf.reduce_sum(
        true_xent *
        (1 + 1 / tf.sqrt(self._user_degrees * self._item_degrees))) + \
        self._neg_weight * tf.reduce_sum(
            tf.reduce_mean(negative_xent, axis=-1))
    # UltraGCN i2i loss
    nbr_logit = tf.reduce_sum(
        tf.expand_dims(self._prediction_dict['user_emb'], axis=1) *
        self._prediction_dict['nbr_emb'],
        axis=-1)  # [batch_size, nbr_num]
    nbr_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(nbr_logit), logits=nbr_logit)
    loss_i2i = tf.reduce_sum(nbr_xent * (1 + self._nbr_weights))
    # regularization
    loss_l2 = tf.nn.l2_loss(self._prediction_dict['user_emb']) + \
        tf.nn.l2_loss(self._prediction_dict['item_emb']) + \
        tf.nn.l2_loss(self._prediction_dict['nbr_emb']) + \
        tf.nn.l2_loss(self._prediction_dict['neg_emb'])

    loss = loss_u2i + self._i2i_weight * loss_i2i + self._l2_weight * loss_l2
    return {'cross_entropy': loss}
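For reference, a sketch of the objective that build_loss_graph computes; this is a reading of the code above, not text from the PR. Here d_u, d_i are the degrees supplied by the sampler, w_k the i2i neighbor weights, N_u^- the sampled negatives, N(i) the sampled item neighbors, e the looked-up embeddings, and w_neg, w_i2i, w_l2 the config fields neg_weight, i2i_weight, l2_weight:

\mathcal{L} = \sum_{(u,i)} \Bigl(1 + \tfrac{1}{\sqrt{d_u\,d_i}}\Bigr)\,\ell^{+}(e_u^{\top} e_i)
  + w_{neg} \sum_{u} \frac{1}{|N_u^{-}|} \sum_{j \in N_u^{-}} \ell^{-}(e_u^{\top} e_j)
  + w_{i2i} \sum_{(u,i)} \sum_{k \in N(i)} (1 + w_k)\,\ell^{+}(e_u^{\top} e_k)
  + w_{l2} \sum_{e} \tfrac{1}{2}\lVert e \rVert_2^2

where \ell^{+}(x) = \mathrm{softplus}(-x) and \ell^{-}(x) = \mathrm{softplus}(x), i.e. sigmoid cross-entropy against labels 1 and 0, and the last sum runs over all embeddings looked up in the batch (tf.nn.l2_loss carries the factor 1/2).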

  def build_metric_graph(self, eval_config):
    return {}

  def get_outputs(self):
    # emb_1 = tf.reduce_join(tf.as_string(self._prediction_dict['user_embedding']), axis=-1, separator=',')
    # emb_2 = tf.reduce_join(tf.as_string(self._prediction_dict['item_embedding']), axis=-1, separator=',')
    return ['user_embedding', 'item_embedding']
Collaborator: Suggest keeping the output names consistent with the vector-recall models: user_emb, item_emb.



  def build_metric_graph(self, eval_config):
    metric_dict = {}
    for metric in eval_config.metrics_set:
      if metric.WhichOneof('metric') == 'recall_at_topk':
Collaborator: Will this metric actually take effect? Where do the logits come from?

        logits = self._prediction_dict['logits']
        label = tf.zeros_like(logits[:, :1], dtype=tf.int64)
        metric_dict['recall_at_top%d' %
                    metric.recall_at_topk.topk] = metrics.recall_at_k(
                        label, logits, metric.recall_at_topk.topk)
      else:
        raise ValueError('invalid metric type: %s' % str(metric))
    return metric_dict
8 changes: 8 additions & 0 deletions easy_rec/python/protos/data_source.proto
@@ -25,3 +25,11 @@ message BinaryDataInput {
  repeated string dense_path = 2;
  repeated string label_path = 3;
}

message GraphLearnInput {
  optional string user_node_input = 1;
Collaborator: Are these enumerable? Would a generic key-value form (node_name, node_input) be more flexible?

  optional string item_node_input = 2;
  optional string u2i_edge_input = 3;
  optional string i2i_edge_input = 4;
  optional string u2u_edge_input = 5;
}
26 changes: 25 additions & 1 deletion easy_rec/python/protos/dataset.proto
@@ -131,6 +131,28 @@ message HardNegativeSamplerV2 {
  optional string field_delimiter = 12 [default="\001"];
}

// for graph based algorithm: ultragcn.
message UltraGCNSampler {
  optional uint32 user_num = 1 [default=10];
  optional uint32 item_num = 2 [default=10];
  optional uint32 output_dim = 3 [default=10];
  optional uint32 nbr_num = 4 [default=10];
  optional uint32 neg_num = 5 [default=10];
  optional float neg_weight = 6 [default=10];
  optional float i2i_weight = 7 [default=10];
  optional float l2_weight = 8 [default=10];
  optional string neg_sampler = 9 [default = 'random'];
}

message DatasetConfig {
  // mini batch size to use for training and evaluation.
  optional uint32 batch_size = 1 [default = 32];
@@ -218,6 +240,7 @@ message DatasetConfig {
    HiveRTPInput = 17;
    HiveParquetInput = 18;
    CriteoInput = 1001;
    GraphInput = 19;
  }
  required InputType input_type = 10;

@@ -288,6 +311,7 @@ message DatasetConfig {
    NegativeSamplerInMemory negative_sampler_in_memory = 105;
  }
  optional uint32 eval_batch_size = 1001 [default = 4096];

  optional UltraGCNSampler ultra_gcn_sampler = 26;
}
2 changes: 2 additions & 0 deletions easy_rec/python/protos/easy_rec_model.proto
@@ -19,6 +19,7 @@ import "easy_rec/python/protos/dcn.proto";
import "easy_rec/python/protos/cmbf.proto";
import "easy_rec/python/protos/autoint.proto";
import "easy_rec/python/protos/mind.proto";
import "easy_rec/python/protos/ultragcn.proto";
import "easy_rec/python/protos/loss.proto";
import "easy_rec/python/protos/rocket_launching.proto";
import "easy_rec/python/protos/variational_dropout.proto";
@@ -67,6 +68,7 @@ message EasyRecModel {
    MIND mind = 202;
    DropoutNet dropoutnet = 203;
    CoMetricLearningI2I metric_learning = 204;
    ULTRAGCN ultragcn = 205;

    MMoE mmoe = 301;
    ESMM esmm = 302;
2 changes: 2 additions & 0 deletions easy_rec/python/protos/pipeline.proto
@@ -19,13 +19,15 @@ message EasyRecConfig {
    DatahubServer datahub_train_input = 12;
    HiveConfig hive_train_input = 101;
    BinaryDataInput binary_train_input = 102;
    GraphLearnInput graph_train_input_path = 103;
  }
  oneof eval_path {
    string eval_input_path = 3;
    KafkaServer kafka_eval_input = 4;
    DatahubServer datahub_eval_input = 13;
    HiveConfig hive_eval_input = 201;
    BinaryDataInput binary_eval_input = 202;
    GraphLearnInput graph_eval_input_path = 203;
  }
  required string model_dir = 5;

14 changes: 14 additions & 0 deletions easy_rec/python/protos/ultragcn.proto
@@ -0,0 +1,14 @@
syntax = "proto2";
package protos;

message ULTRAGCN {
  optional float l2_regularization = 2 [default=0.0];
  optional uint32 user_num = 3 [default=1];
  optional uint32 item_num = 4 [default=1];
  optional uint32 output_dim = 5 [default=1];
  optional uint32 nbr_num = 6 [default=1];
  optional uint32 neg_num = 7 [default=1];
  optional float neg_weight = 8 [default=0.0];
  optional float i2i_weight = 9 [default=0.0];
  optional float l2_weight = 10 [default=0.0];
}
55 changes: 55 additions & 0 deletions easy_rec/python/utils/graph_utils.py
@@ -0,0 +1,55 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import logging

from easy_rec.python.utils import pai_util
Collaborator: Reuse the graph init logic in core/sampler.py.


def graph_init(graph, tf_config=None):
  if tf_config:
    if isinstance(tf_config, str) or isinstance(tf_config, type(u'')):
      tf_config = json.loads(tf_config)
    if 'ps' in tf_config['cluster']:
      # ps mode
      logging.info('ps mode')
      ps_count = len(tf_config['cluster']['ps'])
      evaluator_cnt = 1
      # evaluator_cnt = 1 if pai_util.has_evaluator() else 0
      # if evaluator_cnt == 0:
      #   logging.warning(
      #       'evaluator is not set as an client in GraphLearn,'
      #       'if you actually set evaluator in TF_CONFIG, please do: export'
      #       ' HAS_EVALUATOR=1.')
      task_count = len(tf_config['cluster']['worker']) + 1 + evaluator_cnt
      cluster = {'server_count': ps_count, 'client_count': task_count}
      if tf_config['task']['type'] in ['chief', 'master']:
        graph.init(cluster=cluster, job_name='client', task_index=0)
      elif tf_config['task']['type'] == 'worker':
        graph.init(
            cluster=cluster,
            job_name='client',
            task_index=tf_config['task']['index'] + 2)
      # TODO(hongsheng.jhs): check cluster has evaluator or not?
      elif tf_config['task']['type'] == 'evaluator':
        graph.init(
            cluster=cluster,
            job_name='client',
            task_index=tf_config['task']['index'] + 1)
      elif tf_config['task']['type'] == 'ps':
        graph.init(
            cluster=cluster,
            job_name='server',
            task_index=tf_config['task']['index'])
    else:
      # worker mode
      logging.info('worker mode')
      task_count = len(tf_config['cluster']['worker']) + 1
      if tf_config['task']['type'] in ['chief', 'master']:
        graph.init(task_index=0, task_count=task_count)
      elif tf_config['task']['type'] == 'worker':
        graph.init(
            task_index=tf_config['task']['index'] + evaluator_cnt,
            task_count=task_count)
  else:
    # local mode
    graph.init()
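A minimal usage sketch, not part of this PR: it assumes GraphLearn is importable as graphlearn and exposes the Graph/Decoder builder API that the existing easy_rec samplers appear to rely on, and that the sample node/edge files are weighted text tables (the decoder settings are assumptions about that format).

import os

import graphlearn as gl

from easy_rec.python.utils import graph_utils

# Build user/item nodes and u2i edges from the sample graph files, then
# initialize GraphLearn for the current TF role; with TF_CONFIG unset this
# falls through to local mode and graph.init() is called without arguments.
g = gl.Graph()
g.node('easy_rec/data/graph_data/gl_user.txt', node_type='u',
       decoder=gl.Decoder(weighted=True))
g.node('easy_rec/data/graph_data/gl_item.txt', node_type='i',
       decoder=gl.Decoder(weighted=True))
g.edge('easy_rec/data/graph_data/gl_train.txt', edge_type=('u', 'i', 'u2i'),
       decoder=gl.Decoder(weighted=True))
graph_utils.graph_init(g, os.environ.get('TF_CONFIG', None))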
15 changes: 11 additions & 4 deletions pai_jobs/run.py
@@ -230,10 +230,10 @@ def main(argv):
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)

  if FLAGS.edit_config_json:
    print('[run.py] edit_config_json = %s' % FLAGS.edit_config_json)
    config_json = yaml.safe_load(FLAGS.edit_config_json)
    config_util.edit_config(pipeline_config, config_json)
  # if FLAGS.edit_config_json:
  #   print('[run.py] edit_config_json = %s' % FLAGS.edit_config_json)
  #   config_json = yaml.safe_load(FLAGS.edit_config_json)
  #   config_util.edit_config(pipeline_config, config_json)

  if FLAGS.model_dir:
    pipeline_config.model_dir = FLAGS.model_dir
@@ -273,6 +273,13 @@ def main(argv):
  print('[run.py] train_tables: %s' % pipeline_config.train_input_path)
  print('[run.py] eval_tables: %s' % pipeline_config.eval_input_path)

  if FLAGS.edit_config_json:
    print('[run.py] edit_config_json = %s' % FLAGS.edit_config_json)
    config_json = yaml.safe_load(FLAGS.edit_config_json)
    config_util.edit_config(pipeline_config, config_json)
    logging.info('edit json complete')
    logging.info(pipeline_config)

  if FLAGS.fine_tune_checkpoint:
    pipeline_config.train_config.fine_tune_checkpoint = FLAGS.fine_tune_checkpoint
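As an aside, a minimal sketch of what the relocated block consumes; the override keys below are hypothetical examples, while yaml.safe_load and config_util.edit_config are the same calls used in the diff above.

import yaml

from easy_rec.python.utils import config_util

# hypothetical --edit_config_json payload overriding two pipeline fields
edit_config_json = '{"train_config.num_steps": 50000, "model_dir": "experiments/ultragcn_v2"}'
config_json = yaml.safe_load(edit_config_json)  # yaml parses the JSON string into a dict
config_util.edit_config(pipeline_config, config_json)  # pipeline_config is loaded earlier in main()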

82 changes: 82 additions & 0 deletions samples/model_config/ultragcn_on_graph.config
@@ -0,0 +1,82 @@
model_dir: "experiments/graph_on_ultargcn_ckpt"

graph_train_input_path {
user_node_input: "easy_rec/data/graph_data/gl_user.txt"
item_node_input: "easy_rec/data/graph_data/gl_item.txt"
u2i_edge_input: "easy_rec/data/graph_data/gl_train.txt"
i2i_edge_input: "easy_rec/data/graph_data/gl_i2i.txt"
}
graph_eval_input_path {
user_node_input: "easy_rec/data/graph_data/gl_user.txt"
item_node_input: "easy_rec/data/graph_data/gl_item.txt"
u2i_edge_input: "easy_rec/data/graph_data/gl_train.txt"
i2i_edge_input: "easy_rec/data/graph_data/gl_i2i.txt"
}

train_config {
log_step_count_steps: 100
optimizer_config: {
adam_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 1e-3
}
}
}
use_moving_average: false
}
save_checkpoints_steps: 2000
save_summary_steps: 100
sync_replicas: true
num_steps: 20000
}

eval_config {
}

data_config {
input_fields {
input_name: 'id'
input_type: INT64
}
ultra_gcn_sampler {
nbr_num: 10
neg_num: 10
neg_sampler: 'random'
}

batch_size: 512
num_epochs: 10
prefetch_size: 5
input_type: GraphInput
}

feature_config: {
features: {
input_names: 'id'
feature_type: IdFeature
embedding_dim: 128
hash_bucket_size: 100000
}
}

model_config:{
model_class: "ULTRAGCN"
ultragcn {
l2_regularization: 1e-6
user_num: 52643
item_num: 91599
output_dim: 128
nbr_num: 10
neg_num: 10
neg_weight: 10
i2i_weight: 2.75
l2_weight: 1e-4
}
loss_type: SOFTMAX_CROSS_ENTROPY
embedding_regularization: 0.0
}

export_config {

}