From e21226b0e86e50830d18e0075445a07685ec56a2 Mon Sep 17 00:00:00 2001
From: Liuxinman
Date: Thu, 2 Jun 2022 15:09:39 +0800
Subject: [PATCH 01/34] MMoe parquet script; Add a mmoe model draft;

---
 RecommenderSystems/mmoe/README.md             | 152 +++++
 RecommenderSystems/mmoe/mmoe_train_eval.py    | 549 ++++++++++++++++++
 RecommenderSystems/mmoe/tools/mmoe_parquet.py | 138 +++++
 3 files changed, 839 insertions(+)
 create mode 100644 RecommenderSystems/mmoe/README.md
 create mode 100644 RecommenderSystems/mmoe/mmoe_train_eval.py
 create mode 100644 RecommenderSystems/mmoe/tools/mmoe_parquet.py

diff --git a/RecommenderSystems/mmoe/README.md b/RecommenderSystems/mmoe/README.md
new file mode 100644
index 000000000..b68682dc4
--- /dev/null
+++ b/RecommenderSystems/mmoe/README.md
@@ -0,0 +1,152 @@
+# MMoE
+
+[Multi-gate Mixture-of-Experts (MMoE)](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) adapts the Mixture-of-Experts (MoE) structure to multi-task learning by sharing the expert submodels across all tasks, while also having a gating network trained to optimize each task. Its model structure is as follows. Based on this structure, this project uses the OneFlow distributed deep learning framework to train the model in graph mode on the Criteo dataset.
+

+ *(Figure: MMoE model structure)* +


## Directory description

```txt

```

## Arguments description

| Argument Name | Argument Explanation | Default Value |
| -------------------------- | ------------------------------------------------------------ | ------------------------ |
| data_dir | the data file directory | *Required Argument* |
| num_train_samples | the number of train samples | *Required Argument* |
| num_val_samples | the number of validation samples | *Required Argument* |
| num_test_samples | the number of test samples | *Required Argument* |
| model_load_dir | model loading directory | None |
| model_save_dir | model saving directory | None |
| save_best_model | save best model or not | False |
| save_initial_model | save initial model parameters or not | False |
| save_model_after_each_eval | save model after each eval or not | False |
| embedding_vec_size | embedding vector size | 16 |
| dnn | DNN hidden units | 1000,1000,1000,1000,1000 |
| net_dropout | dropout rate of the DNN hidden layers | 0.2 |
| learning_rate | initial learning rate | 0.001 |
| batch_size | training/evaluation batch size | 10000 |
| train_batches | the maximum number of training batches | 75000 |
| loss_print_interval | interval of printing loss | 100 |
| patience | number of epochs with no improvement after which the learning rate will be reduced | 2 |
| min_delta | threshold for measuring the new optimum, to only focus on significant changes | 1.0e-6 |
| table_size_array | embedding table size array for sparse fields | *Required Argument* |
| persistent_path | path for persistent kv store of embedding | *Required Argument* |
| store_type | OneEmbedding persistent kv store type: `device_mem`, `cached_host_mem` or `cached_ssd` | `cached_host_mem` |
| cache_memory_budget_mb | size of cache memory budget on each device in megabytes when `store_type` is `cached_host_mem` or `cached_ssd` | 1024 |
| amp | enable Automatic Mixed Precision (AMP) training or not | False |
| loss_scale_policy | loss scale policy for AMP training: `static` or `dynamic` | `static` |
| disable_early_stop | disable early stop or not | False |


## Getting Started

A hands-on guide to train an MMoE model.

### Environment

1. Install OneFlow by following the steps in the [OneFlow Installation Guide](https://github.com/Oneflow-Inc/oneflow#install-oneflow) or use the command line below.

    ```shell
    python3 -m pip install --pre oneflow -f https://staging.oneflow.info/branch/master/cu102
    ```

2. Install all other dependencies listed below.

    ```txt
    psutil
    petastorm
    pandas
    sklearn
    ```

### Dataset

**Note**:

According to [the DeepFM paper](https://arxiv.org/abs/1703.04247), we treat both categorical and continuous features as sparse features.

> χ may include categorical fields (e.g., gender, location) and continuous fields (e.g., age). Each categorical field is represented as a vector of one-hot encoding, and each continuous field is represented as the value itself, or a vector of one-hot encoding after discretization.

1. Download the [Criteo Kaggle dataset](https://www.kaggle.com/c/criteo-display-ad-challenge) and then split it using [split_criteo_kaggle.py](https://github.com/Oneflow-Inc/models/blob/dev_deepfm_multicol_oneemb/RecommenderSystems/deepfm/tools/split_criteo_kaggle.py).

    Note: Same as [the DeepFM_Criteo_x4_001 experiment](https://github.com/openbenchmark/BARS/tree/master/ctr_prediction/benchmarks/DeepFM/DeepFM_criteo_x4_001) in FuxiCTR, only train.txt is used. The dataset is randomly split into training, validation and test sets with a ratio of 8:1:1, using StratifiedKFold from sklearn.

    ```shell
    python3 split_criteo_kaggle.py --input_dir=/path/to/your/criteo_kaggle --output_dir=/path/to/your/output/dir
    ```

2. Download Spark from https://spark.apache.org/downloads.html and then uncompress the tar file into the directory where you want to install Spark. Ensure the `SPARK_HOME` environment variable points to the directory where Spark is installed.

3. Launch a Spark shell using [launch_spark.sh](https://github.com/Oneflow-Inc/models/blob/dev_deepfm_multicol_oneemb/RecommenderSystems/deepfm/tools/launch_spark.sh).

    - Modify SPARK_LOCAL_DIRS as needed

      ```shell
      export SPARK_LOCAL_DIRS=/path/to/your/spark/
      ```

    - Run `bash launch_spark.sh`

4. Load [deepfm_parquet.scala](https://github.com/Oneflow-Inc/models/blob/dev_deepfm_multicol_oneemb/RecommenderSystems/deepfm/tools/deepfm_parquet.scala) into your Spark shell with `:load deepfm_parquet.scala`.

5. Call the `makeDeepfmDataset(srcDir: String, dstDir: String)` function to generate the dataset.

    ```shell
    makeDeepfmDataset("/path/to/your/src_dir", "/path/to/your/dst_dir")
    ```

    After the parquet dataset is generated, its statistics are printed as well. They include the number of samples and the table size array, both of which are needed for training.

    ```txt
    train samples = 36672493
    validation samples = 4584062
    test samples = 4584062
    table size array:
    649,9364,14746,490,476707,11618,4142,1373,7275,13,169,407,1376
    1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572
    ```

### Start Training with OneFlow

1. Modify [train_deepfm.sh](https://github.com/Oneflow-Inc/models/blob/dev_deepfm_multicol_oneemb/RecommenderSystems/deepfm/train_deepfm.sh) as needed.

    ```shell
    #!/bin/bash
    DEVICE_NUM_PER_NODE=1
    DATA_DIR=/path/to/deepfm_parquet
    PERSISTENT_PATH=/path/to/persistent
    MODEL_SAVE_DIR=/path/to/model/save/dir

    python3 -m oneflow.distributed.launch \
        --nproc_per_node $DEVICE_NUM_PER_NODE \
        --nnodes 1 \
        --node_rank 0 \
        --master_addr 127.0.0.1 \
        deepfm_train_eval.py \
        --data_dir $DATA_DIR \
        --persistent_path $PERSISTENT_PATH \
        --table_size_array "649,9364,14746,490,476707,11618,4142,1373,7275,13,169,407,1376,1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572" \
        --store_type 'cached_host_mem' \
        --cache_memory_budget_mb 1024 \
        --batch_size 10000 \
        --train_batches 75000 \
        --loss_print_interval 100 \
        --dnn "1000,1000,1000,1000,1000" \
        --net_dropout 0.2 \
        --learning_rate 0.001 \
        --embedding_vec_size 16 \
        --num_train_samples 36672493 \
        --num_val_samples 4584062 \
        --num_test_samples 4584062 \
        --model_save_dir $MODEL_SAVE_DIR \
        --save_best_model
    ```

2. Train a DeepFM model with `bash train_deepfm.sh`.
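
For reference, the multi-gate mixture-of-experts computation sketched in the model-structure figure above can be written out in a few lines. This is a minimal NumPy illustration added for clarity, not the training code in this repository; the batch size, layer widths, three experts and two tasks are assumptions chosen only for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, in_dim, expert_dim, num_experts, num_tasks = 8, 16, 4, 3, 2

x = rng.normal(size=(batch, in_dim))                        # shared input features
expert_w = rng.normal(size=(num_experts, in_dim, expert_dim))
gate_w = rng.normal(size=(num_tasks, in_dim, num_experts))
tower_w = rng.normal(size=(num_tasks, expert_dim, 1))

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# every expert sees the same input: (batch, num_experts, expert_dim)
experts = np.stack([x @ expert_w[k] for k in range(num_experts)], axis=1)

task_outputs = []
for t in range(num_tasks):
    gate = softmax(x @ gate_w[t])                           # per-task mixture weights (batch, num_experts)
    mixed = (experts * gate[:, :, None]).sum(axis=1)        # gate-weighted sum of expert outputs
    task_outputs.append(mixed @ tower_w[t])                 # per-task tower -> (batch, 1)

print([o.shape for o in task_outputs])                      # [(8, 1), (8, 1)]
```

In the training script below, each expert is a small MLP rather than a single matrix, the gates are bias-free linear layers followed by a softmax, and each task has its own linear tower; the experts, gates and towers are all trained jointly.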
diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py new file mode 100644 index 000000000..fca75fca0 --- /dev/null +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -0,0 +1,549 @@ +import argparse +import os +import sys +import glob +import time +import math +import numpy as np +import psutil +import oneflow as flow +import oneflow.nn as nn +from petastorm.reader import make_batch_reader + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + + +def get_args(print_args=True): + def int_list(x): + return list(map(int, x.split(","))) + + def str_list(x): + return list(map(str, x.split(","))) + + parser = argparse.ArgumentParser() + + parser.add_argument("--data_dir", type=str, required=True) + parser.add_argument( + "--num_train_samples", type=int, required=True, help="the number of train samples" + ) + parser.add_argument( + "--num_test_samples", type=int, required=True, help="the number of test samples" + ) + + parser.add_argument("--model_load_dir", type=str, default=None, help="model loading directory") + parser.add_argument("--model_save_dir", type=str, default=None, help="model saving directory") + parser.add_argument( + "--save_initial_model", action="store_true", help="save initial model parameters or not" + ) + parser.add_argument( + "--save_model_after_each_eval", + action="store_true", + help="save model after each eval or not", + ) + + parser.add_argument("--embedding_vec_size", type=int, default=16, help="embedding vector size") + parser.add_argument( + "--dnn", type=int_list, default="1000,1000,1000,1000,1000", help="dnn hidden units number" + ) + parser.add_argument("--net_dropout", type=float, default=0.2, help="net dropout rate") + + parser.add_argument("--lr_factor", type=float, default=0.1) + parser.add_argument("--min_lr", type=float, default=1.0e-6) + parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate") + + parser.add_argument( + "--batch_size", type=int, default=10000, help="training/evaluation batch size" + ) + parser.add_argument( + "--train_batches", type=int, default=75000, help="the maximum number of training batches" + ) + parser.add_argument("--loss_print_interval", type=int, default=100, help="") + + parser.add_argument( + "--table_size_array", + type=int_list, + help="embedding table size array for sparse fields", + required=True, + ) + parser.add_argument( + "--persistent_path", type=str, required=True, help="path for persistent kv store" + ) + parser.add_argument( + "--store_type", + type=str, + default="cached_host_mem", + help="OneEmbeddig persistent kv store type: device_mem, cached_host_mem, cached_ssd", + ) + parser.add_argument( + "--cache_memory_budget_mb", + type=int, + default=1024, + help="size of cache memory budget on each device in megabytes when store_type is cached_host_mem or cached_ssd", + ) + + parser.add_argument( + "--amp", action="store_true", help="enable Automatic Mixed Precision(AMP) training or not" + ) + parser.add_argument("--loss_scale_policy", type=str, default="static", help="static or dynamic") + + args = parser.parse_args() + + if print_args and flow.env.get_rank() == 0: + _print_args(args) + return args + + +def _print_args(args): + """Print arguments.""" + print("------------------------ arguments ------------------------", flush=True) + str_list = [] + for arg in vars(args): + dots = "." 
* (48 - len(arg)) + str_list.append(" {} {} {}".format(arg, dots, getattr(args, arg))) + for arg in sorted(str_list, key=lambda x: x.lower()): + print(arg, flush=True) + print("-------------------- end of arguments ---------------------", flush=True) + + +num_dense_fields = 13 +num_sparse_fields = 26 + + +class MMoeDataReader(object): + """A context manager that manages the creation and termination of a + :class:`petastorm.Reader`. + """ + + def __init__( + self, + parquet_file_url_list, + batch_size, + num_epochs=1, + shuffle_row_groups=True, + shard_seed=2019, + shard_count=1, + cur_shard=0, + ): + self.parquet_file_url_list = parquet_file_url_list + self.batch_size = batch_size + self.num_epochs = num_epochs + self.shuffle_row_groups = shuffle_row_groups + self.shard_seed = shard_seed + self.shard_count = shard_count + self.cur_shard = cur_shard + + fields = ["Label"] + fields += [f"I{i+1}" for i in range(num_dense_fields)] + fields += [f"C{i+1}" for i in range(num_sparse_fields)] + self.fields = fields + self.num_fields = len(fields) + + def __enter__(self): + self.reader = make_batch_reader( + self.parquet_file_url_list, + workers_count=2, + shuffle_row_groups=self.shuffle_row_groups, + num_epochs=self.num_epochs, + shard_seed=self.shard_seed, + shard_count=self.shard_count, + cur_shard=self.cur_shard, + ) + self.loader = self.get_batches(self.reader) + return self.loader + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.reader.stop() + self.reader.join() + + def get_batches(self, reader, batch_size=None): + if batch_size is None: + batch_size = self.batch_size + + tail = None + + for rg in reader: + rgdict = rg._asdict() + rglist = [rgdict[field] for field in self.fields] + pos = 0 + if tail is not None: + pos = batch_size - len(tail[0]) + tail = list( + [ + np.concatenate((tail[i], rglist[i][0 : (batch_size - len(tail[i]))])) + for i in range(self.num_fields) + ] + ) + if len(tail[0]) == batch_size: + label = tail[0] + features = tail[1 : self.num_fields] + tail = None + yield label, np.stack(features, axis=-1) + else: + pos = 0 + continue + while (pos + batch_size) <= len(rglist[0]): + label = rglist[0][pos : pos + batch_size] + features = [rglist[j][pos : pos + batch_size] for j in range(1, self.num_fields)] + pos += batch_size + yield label, np.stack(features, axis=-1) + if pos != len(rglist[0]): + tail = [rglist[i][pos:] for i in range(self.num_fields)] + + +def make_census_dataloader(data_path, batch_size, shuffle=True): + """Make a Census-Income Parquet DataLoader. + :return: a context manager when exit the returned context manager, the reader will be closed. 
+ """ + files = ["file://" + name for name in glob.glob(f"{data_path}/*.parquet")] + files.sort() + + world_size = flow.env.get_world_size() + batch_size_per_proc = batch_size // world_size + + return MMoeDataReader( + files, + batch_size_per_proc, + None, # TODO: iterate over all eval dataset + shuffle_row_groups=shuffle, + shard_seed=2019, + shard_count=world_size, + cur_shard=flow.env.get_rank(), + ) + + +class OneEmbedding(nn.Module): + def __init__( + self, + table_name, + embedding_vec_size, + persistent_path, + table_size_array, + store_type, + cache_memory_budget_mb, + size_factor, + ): + assert table_size_array is not None + vocab_size = sum(table_size_array) + + tables = [ + flow.one_embedding.make_table( + flow.one_embedding.make_normal_initializer(mean=0.0, std=1e-4) + ) + for _ in range(len(table_size_array)) + ] + if store_type == "device_mem": + store_options = flow.one_embedding.make_device_mem_store_options( + persistent_path=persistent_path, + capacity=vocab_size, + size_factor=size_factor, + ) + elif store_type == "cached_host_mem": + assert cache_memory_budget_mb > 0 + store_options = flow.one_embedding.make_cached_host_mem_store_options( + cache_budget_mb=cache_memory_budget_mb, + persistent_path=persistent_path, + capacity=vocab_size, + size_factor=size_factor, + ) + elif store_type == "cached_ssd": + assert cache_memory_budget_mb > 0 + store_options = flow.one_embedding.make_cached_ssd_store_options( + cache_budget_mb=cache_memory_budget_mb, + persistent_path=persistent_path, + capacity=vocab_size, + size_factor=size_factor, + ) + else: + raise NotImplementedError("not support", store_type) + + super(OneEmbedding, self).__init__() + self.one_embedding = flow.one_embedding.MultiTableEmbedding( + name=table_name, + embedding_dim=embedding_vec_size, + dtype=flow.float, + key_type=flow.int64, + tables=tables, + store_options=store_options, + ) + + def forward(self, ids): + return self.one_embedding.forward(ids) + + +class DNN(nn.Module): + def __init__( + self, in_features, hidden_units, out_features, skip_final_activation=False, dropout=0.0 + ) -> None: + super(DNN, self).__init__() + denses = [] + dropout_rates = [dropout] * len(hidden_units) + [0.0] + use_relu = [True] * len(hidden_units) + [not skip_final_activation] + hidden_units = [in_features] + hidden_units + [out_features] + for idx in range(len(hidden_units) - 1): + denses.append(nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=True)) + if use_relu[idx]: + denses.append(nn.ReLU()) + if dropout_rates[idx] > 0: + denses.append(nn.Dropout(p=dropout_rates[idx])) + self.linear_layers = nn.Sequential(*denses) + + for name, param in self.linear_layers.named_parameters(): + if "weight" in name: + nn.init.xavier_normal_(param) + elif "bias" in name: + param.data.fill_(0.0) + + def forward(self, x: flow.Tensor) -> flow.Tensor: + return self.linear_layers(x) + + +class MMoeModule(nn.Module): + def __init__( + self, + embedding_vec_size=128, + dnn=[1024, 1024, 512, 256], + persistent_path=None, + table_size_array=None, + one_embedding_store_type="cached_host_mem", + cache_memory_budget_mb=8192, + dropout=0.2, + ): + super(MMoeModule, self).__init__() + + def forward(self, inputs) -> flow.Tensor: + pass + + +def make_mmoe_module(args): + model = MMoeModule( + embedding_vec_size=args.embedding_vec_size, + dnn=args.dnn, + persistent_path=args.persistent_path, + table_size_array=args.table_size_array, + one_embedding_store_type=args.store_type, + cache_memory_budget_mb=args.cache_memory_budget_mb, + 
dropout=args.net_dropout, + ) + return model + + +class MMoeValGraph(flow.nn.Graph): + def __init__(self, mmoe_module, amp=False): + super(MMoeValGraph, self).__init__() + self.module = mmoe_module + if amp: + self.config.enable_amp(True) + + def build(self, features): + predicts = self.module(features.to("cuda")) + return predicts.sigmoid() + + +class MMoeTrainGraph(flow.nn.Graph): + def __init__( + self, mmoe_module, loss, optimizer, grad_scaler=None, amp=False, lr_scheduler=None, + ): + super(MMoeTrainGraph, self).__init__() + self.module = mmoe_module + self.loss = loss + self.add_optimizer(optimizer, lr_sch=lr_scheduler) + self.config.allow_fuse_model_update_ops(True) + self.config.allow_fuse_add_to_output(True) + self.config.allow_fuse_cast_scale(True) + if amp: + self.config.enable_amp(True) + self.set_grad_scaler(grad_scaler) + + def build(self, labels, features): + logits = self.module(features.to("cuda")) + loss = self.loss(logits, labels.to("cuda")) + loss.backward() + return loss.to("cpu") + + +def make_lr_scheduler(args, optimizer): + batches_per_epoch = math.ceil(args.num_train_samples / args.batch_size) + milestones = [ + batches_per_epoch * (i + 1) + for i in range(math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor))) + ] + multistep_lr = flow.optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, milestones=milestones, gamma=args.lr_factor, + ) + + return multistep_lr + + +def train(args): + rank = flow.env.get_rank() + + mmoe_module = make_mmoe_module(args) + mmoe_module.to_global(flow.env.all_device_placement("cuda"), flow.sbp.broadcast) + + def load_model(dir): + if rank == 0: + print(f"Loading model from {dir}") + if os.path.exists(dir): + state_dict = flow.load(dir, global_src_rank=0) + mmoe_module.load_state_dict(state_dict, strict=False) + else: + if rank == 0: + print(f"Loading model from {dir} failed: invalid path") + + if args.model_load_dir: + load_model(args.model_load_dir) + + def save_model(subdir): + if not args.model_save_dir: + return + save_path = os.path.join(args.model_save_dir, subdir) + if rank == 0: + print(f"Saving model to {save_path}") + state_dict = mmoe_module.state_dict() + flow.save(state_dict, save_path, global_dst_rank=0) + + if args.save_initial_model: + save_model("initial_checkpoint") + + # TODO: clip gradient norm + opt = flow.optim.Adam(mmoe_module.parameters(), lr=args.learning_rate) + lr_scheduler = make_lr_scheduler(args, opt) + loss = flow.nn.BCEWithLogitsLoss(reduction="mean").to("cuda") + + if args.loss_scale_policy == "static": + grad_scaler = flow.amp.StaticGradScaler(1024) + else: + grad_scaler = flow.amp.GradScaler( + init_scale=1073741824, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, + ) + + eval_graph = MmoeValGraph(mmoe_module, args.amp) + train_graph = MMoeTrainGraph( + mmoe_module, loss, opt, grad_scaler, args.amp, lr_scheduler=lr_scheduler + ) + + batches_per_epoch = math.ceil(args.num_train_samples / args.batch_size) + + cached_eval_batches = prefetch_eval_batches( + f"{args.data_dir}/test", args.batch_size, math.ceil(args.num_test_samples / args.batch_size) + ) + + mmoe_module.train() + epoch = 0 + with make_census_dataloader(f"{args.data_dir}/train", args.batch_size) as loader: + step, last_step, last_time = -1, 0, time.time() + for step in range(1, args.train_batches + 1): + labels, features = batch_to_global(*next(loader)) + loss = train_graph(labels, features) + if step % args.loss_print_interval == 0: + loss = loss.numpy() + if rank == 0: + latency = (time.time() - last_time) / 
(step - last_step) + throughput = args.batch_size / latency + last_step, last_time = step, time.time() + strtime = time.strftime("%Y-%m-%d %H:%M:%S") + print( + f"Rank[{rank}], Step {step}, Loss {loss:0.4f}, " + + f"Latency {(latency * 1000):0.3f} ms, Throughput {throughput:0.1f}, {strtime}" + ) + + if step % batches_per_epoch == 0: + epoch += 1 + auc = eval( + args, + eval_graph, + cur_step=step, + epoch=epoch, + cached_eval_batches=cached_eval_batches, + ) + if args.save_model_after_each_eval: + save_model(f"step_{step}_val_auc_{auc:0.5f}") + + mmoe_module.train() + last_time = time.time() + + if step % batches_per_epoch != 0: + auc = eval( + args, + eval_graph, + cur_step=step, + epoch=epoch, + cached_eval_batches=cached_eval_batches, + ) + if args.save_model_after_each_eval: + save_model(f"step_{step}_val_auc_{auc:0.5f}") + + +def np_to_global(np): + t = flow.from_numpy(np) + return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) + + +def batch_to_global(np_label, np_features, is_train=True): + labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) + features = np_to_global(np_features) + return labels, features + + +def prefetch_eval_batches(data_dir, batch_size, num_batches): + cached_eval_batches = [] + with make_census_dataloader(data_dir, batch_size, shuffle=False) as loader: + for _ in range(num_batches): + label, features = batch_to_global(*next(loader), is_train=False) + cached_eval_batches.append((label, features)) + return cached_eval_batches + + +def eval(args, eval_graph, cur_step=0, epoch=0, cached_eval_batches=None): + batches_per_epoch = math.ceil(args.num_test_samples / args.batch_size) + + eval_graph.module.eval() + labels, preds = [], [] + eval_start_time = time.time() + + for i in range(batches_per_epoch): + label, features = cached_eval_batches[i] + pred = eval_graph(features) + labels.append(label) + preds.append(pred.to_local()) + + labels = ( + np_to_global(np.concatenate(labels, axis=0)).to_global(sbp=flow.sbp.broadcast()).to_local() + ) + preds = ( + flow.cat(preds, dim=0) + .to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) + .to_global(sbp=flow.sbp.broadcast()) + .to_local() + ) + + flow.comm.barrier() + eval_time = time.time() - eval_start_time + + rank = flow.env.get_rank() + + metrics_start_time = time.time() + auc = flow.roc_auc_score(labels, preds).numpy()[0] + logloss = flow._C.binary_cross_entropy_loss(preds, labels, weight=None, reduction="mean") + metrics_time = time.time() - metrics_start_time + + if rank == 0: + host_mem_mb = psutil.Process().memory_info().rss // (1024 * 1024) + stream = os.popen("nvidia-smi --query-gpu=memory.used --format=csv") + device_mem_str = stream.read().split("\n")[rank + 1] + + strtime = time.strftime("%Y-%m-%d %H:%M:%S") + print( + f"Rank[{rank}], Epoch {epoch}, Step {cur_step}, AUC {auc:0.6f}, LogLoss {logloss:0.6f}, " + + f"Eval_time {eval_time:0.2f} s, Metrics_time {metrics_time:0.2f} s, Eval_samples {labels.shape[0]}, " + + f"GPU_Memory {device_mem_str}, Host_Memory {host_mem_mb} MiB, {strtime}" + ) + + return auc + + +if __name__ == "__main__": + os.system(sys.executable + " -m oneflow --doctor") + flow.boxing.nccl.enable_all_to_all(True) + args = get_args() + train(args) diff --git a/RecommenderSystems/mmoe/tools/mmoe_parquet.py b/RecommenderSystems/mmoe/tools/mmoe_parquet.py new file mode 100644 index 000000000..57582fa1b --- /dev/null +++ b/RecommenderSystems/mmoe/tools/mmoe_parquet.py @@ -0,0 +1,138 @@ +""" 
+Copyright 2020 The OneFlow Authors. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import time +import argparse + +import pandas as pd +from sklearn.metrics import roc_auc_score +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder, MinMaxScaler + +from pyspark.sql import SparkSession +from pyspark.conf import SparkConf +from pyspark.sql.functions import rand, udf, lit, xxhash64 +from pyspark.sql.types import FloatType, LongType + +column_names = ['age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour', 'hs_college', + 'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', + 'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses', 'stock_dividends', + 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', + 'instance_weight', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt', + 'num_emp', 'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship', + 'own_or_self', 'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k'] + +sparse_features = ['class_worker', 'det_ind_code', 'det_occ_code', 'education', 'hs_college', 'major_ind_code', + 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', 'unemp_reason', + 'full_or_part_emp', 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat', + 'det_hh_summ', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt', + 'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship', + 'vet_question'] + +def make_mmoe_parquet( + spark, input_files, output_dir, part_num=None, shuffle=False +): + + data = pd.read_csv(input_files, header=None, names=column_names) + + data['label_income'] = data['income_50k'].map({' - 50000.': 0, ' 50000+.': 1}) + data['label_marital'] = data['marital_stat'].apply(lambda x: 1 if x == ' Never married' else 0) + data.drop(labels=['income_50k', 'marital_stat'], axis=1, inplace=True) + + columns = data.columns.values.tolist() + + dense_features = [col for col in columns if + col not in sparse_features and col not in ['label_income', 'label_marital']] + + data[sparse_features] = data[sparse_features].fillna('-1', ) + data[dense_features] = data[dense_features].fillna(0, ) + mms = MinMaxScaler(feature_range=(0, 1)) + data[dense_features] = mms.fit_transform(data[dense_features]) + + start = time.time() + + df = spark.createDataFrame(data) + columns_new = dense_features + sparse_features + ["label_income", "label_marital"] + df = df.select(columns_new) + + make_label = udf(lambda s: float(s), FloatType()) + label_cols = [make_label(field).alias(field) for field in ["label_income", "label_marital"]] + + sparse_cols = [xxhash64(field, lit(i)).alias(field) for i, field in enumerate(sparse_features)] + + make_dense = udf(lambda s: float(s), FloatType()) + dense_cols = [make_dense(field).alias(field) for field in dense_features] + + 
cols = dense_cols + sparse_cols + label_cols + df = df.select(cols) + + if shuffle: + df = df.orderBy(rand()) + if part_num: + df = df.repartition(part_num) + + df.write.mode("overwrite").parquet(output_dir) + num_examples = spark.read.parquet(output_dir).count() + print(output_dir, num_examples, f"time elapsed: {time.time()-start:0.1f}") + return num_examples, columns_new + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="Path to downloaded and unzipd criteo terabyte datasets: day_0, day_1, ..., day_23", + ) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--spark_tmp_dir", type=str, default=None) + parser.add_argument("--spark_driver_memory_gb", type=int, default=360) + parser.add_argument( + "--export_dataset_info", action="store_true", help="export dataset infomation or not" + ) + args = parser.parse_args() + + test_csv = os.path.join(args.input_dir, "census-income.test") + train_csv = os.path.join(args.input_dir, "census-income.sample") + + # start spark session + conf = SparkConf() + conf.set("spark.driver.memory", f"{args.spark_driver_memory_gb}g") + conf.set("spark.local.dir", args.spark_tmp_dir) + spark = SparkSession.builder.config(conf=conf).master("local[*]").getOrCreate() + + # create test dataset + test_output_dir = os.path.join(args.output_dir, "test") + test_count, _ = make_mmoe_parquet( + spark, test_csv, test_output_dir, part_num=32 + ) + + # create train dataset + train_output_dir = os.path.join(args.output_dir, "train") + train_count, columns = make_mmoe_parquet( + spark, train_csv, train_output_dir, part_num=64, shuffle=True + ) + + if args.export_dataset_info: + df = spark.read.parquet(train_output_dir, test_output_dir) + table_size_array = [df.select(field).distinct().count() for field in sparse_features] + print(table_size_array) + with open(os.path.join(args.output_dir, "README.md"), "w") as f: + f.write("## number of examples:\n") + f.write(f"train: {train_count}\n") + f.write(f"test: {test_count}\n") + f.write("## table size array\n") + f.write("table_size_array = [") + f.write(", ".join([str(i) for i in table_size_array])) + f.write("]\n") From aa5d3568a0ad78cce70ca4b50223bfc2bd97ab85 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Thu, 2 Jun 2022 17:15:32 +0800 Subject: [PATCH 02/34] dlrm profile --- .../dlrm/dlrm_prefetch_train.py | 629 ++++++++++++++++++ RecommenderSystems/dlrm/dlrm_profile.py | 41 ++ 2 files changed, 670 insertions(+) create mode 100644 RecommenderSystems/dlrm/dlrm_prefetch_train.py create mode 100644 RecommenderSystems/dlrm/dlrm_profile.py diff --git a/RecommenderSystems/dlrm/dlrm_prefetch_train.py b/RecommenderSystems/dlrm/dlrm_prefetch_train.py new file mode 100644 index 000000000..88ef64963 --- /dev/null +++ b/RecommenderSystems/dlrm/dlrm_prefetch_train.py @@ -0,0 +1,629 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import argparse +import os +import sys +import glob +import time +import numpy as np +import psutil +import warnings +import oneflow as flow +import oneflow.nn as nn + +warnings.filterwarnings("ignore", category=FutureWarning) +from petastorm.reader import make_batch_reader + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + + +def get_args(print_args=True): + def int_list(x): + return list(map(int, x.split(","))) + + def str_list(x): + return list(map(str, x.split(","))) + + parser = argparse.ArgumentParser() + + parser.add_argument("--disable_fusedmlp", action="store_true", help="disable fused MLP or not") + parser.add_argument("--embedding_vec_size", type=int, default=128) + parser.add_argument("--bottom_mlp", type=int_list, default="512,256,128") + parser.add_argument("--top_mlp", type=int_list, default="1024,1024,512,256") + parser.add_argument( + "--disable_interaction_padding", + action="store_true", + help="disable interaction padding or not", + ) + parser.add_argument( + "--interaction_itself", action="store_true", help="interaction itself or not" + ) + parser.add_argument("--model_load_dir", type=str, default=None) + parser.add_argument("--model_save_dir", type=str, default=None) + parser.add_argument( + "--save_initial_model", action="store_true", help="save initial model parameters or not.", + ) + parser.add_argument( + "--save_model_after_each_eval", action="store_true", help="save model after each eval.", + ) + parser.add_argument("--data_dir", type=str, default="/data/criteo1t/criteo1t_dlrm_parquet_40M") + parser.add_argument("--eval_batches", type=int, default=1612, help="number of eval batches") + parser.add_argument("--eval_batch_size", type=int, default=55296) + parser.add_argument("--eval_interval", type=int, default=10000) + parser.add_argument("--train_batch_size", type=int, default=55296) + parser.add_argument("--learning_rate", type=float, default=24) + parser.add_argument("--warmup_batches", type=int, default=2750) + parser.add_argument("--decay_batches", type=int, default=27772) + parser.add_argument("--decay_start", type=int, default=49315) + parser.add_argument("--train_batches", type=int, default=300) + parser.add_argument("--loss_print_interval", type=int, default=1000) + parser.add_argument( + "--table_size_array", + type=int_list, + default="39884406,39043,17289,7420,20263,3,7120,1543,63,38532951,2953546,403346,10,2208,11938,155,4,976,14,39979771,25641295,39664984,585935,12972,108,36", + help="Embedding table size array for sparse fields", + ) + parser.add_argument( + "--persistent_path", type=str, required=True, help="path for persistent kv store", + ) + parser.add_argument("--store_type", type=str, default="cached_host_mem") + parser.add_argument("--cache_memory_budget_mb", type=int, default=8192) + parser.add_argument("--amp", action="store_true", help="Run model with amp") + parser.add_argument("--loss_scale_policy", type=str, default="static", help="static or dynamic") + + args = parser.parse_args() + + if print_args and flow.env.get_rank() == 0: + _print_args(args) + return args + + +def _print_args(args): + """Print arguments.""" + print("------------------------ arguments ------------------------", flush=True) + str_list = [] + for arg in vars(args): + dots = "." 
* (48 - len(arg)) + str_list.append(" {} {} {}".format(arg, dots, getattr(args, arg))) + for arg in sorted(str_list, key=lambda x: x.lower()): + print(arg, flush=True) + print("-------------------- end of arguments ---------------------", flush=True) + + +num_dense_fields = 13 +num_sparse_fields = 26 + + +class DLRMDataReader(object): + """A context manager that manages the creation and termination of a + :class:`petastorm.Reader`. + """ + + def __init__( + self, + parquet_file_url_list, + batch_size, + num_epochs, + shuffle_row_groups=True, + shard_seed=1234, + shard_count=1, + cur_shard=0, + ): + self.parquet_file_url_list = parquet_file_url_list + self.batch_size = batch_size + self.num_epochs = num_epochs + self.shuffle_row_groups = shuffle_row_groups + self.shard_seed = shard_seed + self.shard_count = shard_count + self.cur_shard = cur_shard + + fields = ["label"] + fields += [f"I{i+1}" for i in range(num_dense_fields)] + self.I_end = len(fields) + fields += [f"C{i+1}" for i in range(num_sparse_fields)] + self.C_end = len(fields) + self.fields = fields + + def __enter__(self): + self.reader = make_batch_reader( + self.parquet_file_url_list, + workers_count=1, + shuffle_row_groups=self.shuffle_row_groups, + num_epochs=self.num_epochs, + shard_seed=self.shard_seed, + shard_count=self.shard_count, + cur_shard=self.cur_shard, + ) + self.loader = self.get_batches(self.reader) + return self.loader + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.reader.stop() + self.reader.join() + + def get_batches(self, reader, batch_size=None): + if batch_size is None: + batch_size = self.batch_size + tail = None + for rg in reader: + rgdict = rg._asdict() + rglist = [rgdict[field] for field in self.fields] + pos = 0 + if tail is not None: + pos = batch_size - len(tail[0]) + tail = list( + [ + np.concatenate((tail[i], rglist[i][0 : (batch_size - len(tail[i]))])) + for i in range(self.C_end) + ] + ) + if len(tail[0]) == batch_size: + label = tail[0] + dense = tail[1 : self.I_end] + sparse = tail[self.I_end : self.C_end] + tail = None + yield label, np.stack(dense, axis=-1), np.stack(sparse, axis=-1) + else: + pos = 0 + continue + while (pos + batch_size) <= len(rglist[0]): + label = rglist[0][pos : pos + batch_size] + dense = [rglist[j][pos : pos + batch_size] for j in range(1, self.I_end)] + sparse = [rglist[j][pos : pos + batch_size] for j in range(self.I_end, self.C_end)] + pos += batch_size + yield label, np.stack(dense, axis=-1), np.stack(sparse, axis=-1) + if pos != len(rglist[0]): + tail = [rglist[i][pos:] for i in range(self.C_end)] + + +class Dense(nn.Module): + def __init__(self, in_features: int, out_features: int, relu=True) -> None: + super(Dense, self).__init__() + self.features = ( + nn.Sequential(nn.Linear(in_features, out_features), nn.ReLU(inplace=True)) + if relu + else nn.Linear(in_features, out_features) + ) + + def forward(self, x: flow.Tensor) -> flow.Tensor: + return self.features(x) + + +class MLP(nn.Module): + def __init__( + self, in_features: int, hidden_units, skip_final_activation=False, fused=True + ) -> None: + super(MLP, self).__init__() + if fused: + self.linear_layers = nn.FusedMLP( + in_features, + hidden_units[:-1], + hidden_units[-1], + skip_final_activation=skip_final_activation, + ) + else: + units = [in_features] + hidden_units + num_layers = len(hidden_units) + denses = [ + Dense(units[i], units[i + 1], not skip_final_activation or (i + 1) < num_layers) + for i in range(num_layers) + ] + self.linear_layers = nn.Sequential(*denses) + + for name, 
param in self.linear_layers.named_parameters(): + if "weight" in name: + nn.init.normal_(param, 0.0, np.sqrt(2 / sum(param.shape))) + elif "bias" in name: + nn.init.normal_(param, 0.0, np.sqrt(1 / param.shape[0])) + + def forward(self, x: flow.Tensor) -> flow.Tensor: + return self.linear_layers(x) + + +class Interaction(nn.Module): + def __init__( + self, + dense_feature_size, + num_embedding_fields, + interaction_itself=False, + interaction_padding=True, + ): + super(Interaction, self).__init__() + self.interaction_itself = interaction_itself + n_cols = num_embedding_fields + 2 if self.interaction_itself else num_embedding_fields + 1 + output_size = dense_feature_size + sum(range(n_cols)) + self.output_size = ((output_size + 8 - 1) // 8 * 8) if interaction_padding else output_size + self.output_padding = self.output_size - output_size + + def forward(self, x: flow.Tensor, y: flow.Tensor) -> flow.Tensor: + (bsz, d) = x.shape + return flow._C.fused_dot_feature_interaction( + [x.view(bsz, 1, d), y], + output_concat=x, + self_interaction=self.interaction_itself, + output_padding=self.output_padding, + ) + + +class OneEmbedding(nn.Module): + def __init__( + self, + embedding_vec_size, + persistent_path, + table_size_array, + store_type, + cache_memory_budget_mb, + ): + assert table_size_array is not None + vocab_size = sum(table_size_array) + + scales = np.sqrt(1 / np.array(table_size_array)) + tables = [ + flow.one_embedding.make_table( + flow.one_embedding.make_uniform_initializer(low=-scale, high=scale) + ) + for scale in scales + ] + if store_type == "device_mem": + store_options = flow.one_embedding.make_device_mem_store_options( + persistent_path=persistent_path, capacity=vocab_size + ) + elif store_type == "cached_host_mem": + assert cache_memory_budget_mb > 0 + store_options = flow.one_embedding.make_cached_host_mem_store_options( + cache_budget_mb=cache_memory_budget_mb, + persistent_path=persistent_path, + capacity=vocab_size, + ) + elif store_type == "cached_ssd": + assert cache_memory_budget_mb > 0 + store_options = flow.one_embedding.make_cached_ssd_store_options( + cache_budget_mb=cache_memory_budget_mb, + persistent_path=persistent_path, + capacity=vocab_size, + ) + else: + raise NotImplementedError("not support", store_type) + + super(OneEmbedding, self).__init__() + self.one_embedding = flow.one_embedding.MultiTableEmbedding( + "sparse_embedding", + embedding_dim=embedding_vec_size, + dtype=flow.float, + key_type=flow.int64, + tables=tables, + store_options=store_options, + ) + + def forward(self, ids): + return self.one_embedding.forward(ids) + + +class DLRMModule(nn.Module): + def __init__( + self, + embedding_vec_size=128, + bottom_mlp=[512, 256, 128], + top_mlp=[1024, 1024, 512, 256], + use_fusedmlp=True, + persistent_path=None, + table_size_array=None, + one_embedding_store_type="cached_host_mem", + cache_memory_budget_mb=8192, + interaction_itself=True, + interaction_padding=True, + ): + super(DLRMModule, self).__init__() + assert ( + embedding_vec_size == bottom_mlp[-1] + ), "Embedding vector size must equle to bottom MLP output size" + self.bottom_mlp = MLP(num_dense_fields, bottom_mlp, fused=use_fusedmlp) + self.embedding = OneEmbedding( + embedding_vec_size, + persistent_path, + table_size_array, + one_embedding_store_type, + cache_memory_budget_mb, + ) + self.interaction = Interaction( + bottom_mlp[-1], + num_sparse_fields, + interaction_itself, + interaction_padding=interaction_padding, + ) + self.top_mlp = MLP( + self.interaction.output_size, + top_mlp + [1], + 
skip_final_activation=True, + fused=use_fusedmlp, + ) + + def forward(self, dense_fields, sparse_fields) -> flow.Tensor: + dense_fields = flow.log(dense_fields + 1.0) + dense_fields = self.bottom_mlp(dense_fields) + embedding = self.embedding(sparse_fields) + features = self.interaction(dense_fields, embedding) + return self.top_mlp(features) + + +def make_dlrm_module(args): + model = DLRMModule( + embedding_vec_size=args.embedding_vec_size, + bottom_mlp=args.bottom_mlp, + top_mlp=args.top_mlp, + use_fusedmlp=not args.disable_fusedmlp, + persistent_path=args.persistent_path, + table_size_array=args.table_size_array, + one_embedding_store_type=args.store_type, + cache_memory_budget_mb=args.cache_memory_budget_mb, + interaction_itself=args.interaction_itself, + interaction_padding=not args.disable_interaction_padding, + ) + return model + + +def make_criteo_dataloader(data_path, batch_size, shuffle=True): + """Make a Criteo Parquet DataLoader. + :return: a context manager when exit the returned context manager, the reader will be closed. + """ + files = ["file://" + name for name in glob.glob(f"{data_path}/*.parquet")] + files.sort() + + world_size = flow.env.get_world_size() + batch_size_per_proc = batch_size // world_size + + return DLRMDataReader( + files, + batch_size_per_proc, + None, # TODO: iterate over all eval dataset + shuffle_row_groups=shuffle, + shard_seed=1234, + shard_count=world_size, + cur_shard=flow.env.get_rank(), + ) + + +def make_lr_scheduler(args, optimizer): + warmup_lr = flow.optim.lr_scheduler.LinearLR( + optimizer, start_factor=0, total_iters=args.warmup_batches, + ) + poly_decay_lr = flow.optim.lr_scheduler.PolynomialLR( + optimizer, steps=args.decay_batches, end_learning_rate=0, power=2.0, cycle=False, + ) + sequential_lr = flow.optim.lr_scheduler.SequentialLR( + optimizer=optimizer, + schedulers=[warmup_lr, poly_decay_lr], + milestones=[args.decay_start], + interval_rescaling=True, + ) + return sequential_lr + + +class DLRMValGraph(flow.nn.Graph): + def __init__(self, dlrm_module, amp=False): + super(DLRMValGraph, self).__init__() + self.module = dlrm_module + if amp: + self.config.enable_amp(True) + + def build(self, dense_fields, sparse_fields): + predicts = self.module(dense_fields.to("cuda"), sparse_fields.to("cuda")) + return predicts.sigmoid() + + +class DLRMTrainGraph(flow.nn.Graph): + def __init__( + self, dlrm_module, loss, optimizer, lr_scheduler=None, grad_scaler=None, amp=False, + ): + super(DLRMTrainGraph, self).__init__() + self.module = dlrm_module + self.loss = loss + self.add_optimizer(optimizer, lr_sch=lr_scheduler) + self.config.allow_fuse_model_update_ops(True) + self.config.allow_fuse_add_to_output(True) + self.config.allow_fuse_cast_scale(True) + if amp: + self.config.enable_amp(True) + self.set_grad_scaler(grad_scaler) + + def build(self, labels, dense_fields, sparse_fields): + logits = self.module(dense_fields.to("cuda"), sparse_fields.to("cuda")) + loss = self.loss(logits, labels.to("cuda")) + reduce_loss = flow.mean(loss) + reduce_loss.backward() + return reduce_loss.to("cpu") + + +def prefetch_eval_batches(data_dir, batch_size, num_batches): + cached_eval_batches = [] + with make_criteo_dataloader(data_dir, batch_size, shuffle=False) as loader: + for _ in range(num_batches): + label, dense_fields, sparse_fields = batch_to_global(*next(loader), is_train=False) + cached_eval_batches.append((label, dense_fields, sparse_fields)) + return cached_eval_batches + + +def train(args): + rank = flow.env.get_rank() + + dlrm_module = 
make_dlrm_module(args) + dlrm_module.to_global(flow.env.all_device_placement("cuda"), flow.sbp.broadcast) + + if args.model_load_dir: + print(f"Loading model from {args.model_load_dir}") + state_dict = flow.load(args.model_load_dir, global_src_rank=0) + dlrm_module.load_state_dict(state_dict, strict=False) + + def save_model(subdir): + if not args.model_save_dir: + return + save_path = os.path.join(args.model_save_dir, subdir) + if rank == 0: + print(f"Saving model to {save_path}") + state_dict = dlrm_module.state_dict() + flow.save(state_dict, save_path, global_dst_rank=0) + + if args.save_initial_model: + save_model("initial_checkpoint") + + opt = flow.optim.SGD(dlrm_module.parameters(), lr=args.learning_rate) + lr_scheduler = make_lr_scheduler(args, opt) + loss = flow.nn.BCEWithLogitsLoss(reduction="none").to("cuda") + + if args.loss_scale_policy == "static": + grad_scaler = flow.amp.StaticGradScaler(1024) + else: + grad_scaler = flow.amp.GradScaler( + init_scale=1073741824, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, + ) + + eval_graph = DLRMValGraph(dlrm_module, args.amp) + train_graph = DLRMTrainGraph(dlrm_module, loss, opt, lr_scheduler, grad_scaler, args.amp) + + cached_eval_batches = prefetch_eval_batches( + f"{args.data_dir}/test", args.eval_batch_size, args.eval_batches + ) + + dlrm_module.train() + # with make_criteo_dataloader(f"{args.data_dir}/train", args.train_batch_size) as loader: + # ts = [] + # #labels_0, dense_fields_0, sparse_fields_0 = next(loader) + # for i in range(4000): + # labels, dense_fields, sparse_fields = next(loader) + # #labels, dense_fields, sparse_fields = batch_to_global(*next(loader)) + # #labels, dense_fields, sparse_fields = batch_to_global(labels_0, dense_fields_0, sparse_fields_0) + # + # ts.append(time.time()) + # if rank == 0: + # for t in ts: + # print(t) + # exit() + + with make_criteo_dataloader(f"{args.data_dir}/train", args.train_batch_size) as loader: + print('start prefetch training data...') + cached_batches = [batch_to_global(*next(loader)) for _ in range(args.train_batches)] + print('start training ..') + step, last_step, last_time = 0, 0, time.time() + for labels, dense_fields, sparse_fields in cached_batches: + loss = train_graph(labels, dense_fields, sparse_fields) + step += 1 + if step % args.loss_print_interval == 0: + loss = loss.numpy() + if rank == 0: + latency = (time.time() - last_time) / (step - last_step) + throughput = args.train_batch_size / latency + last_step, last_time = step, time.time() + strtime = time.strftime("%Y-%m-%d %H:%M:%S") + print( + f"Rank[{rank}], Step {step}, Loss {loss:0.4f}, Latency " + + f"{(latency * 1000):0.3f} ms, Throughput {throughput:0.1f}, {strtime}" + ) + exit() + + with make_criteo_dataloader(f"{args.data_dir}/train", args.train_batch_size) as loader: + step, last_step, last_time = -1, 0, time.time() + for step in range(1, args.train_batches + 1): + labels, dense_fields, sparse_fields = batch_to_global(*next(loader)) + loss = train_graph(labels, dense_fields, sparse_fields) + if step % args.loss_print_interval == 0: + loss = loss.numpy() + if rank == 0: + latency = (time.time() - last_time) / (step - last_step) + throughput = args.train_batch_size / latency + last_step, last_time = step, time.time() + strtime = time.strftime("%Y-%m-%d %H:%M:%S") + print( + f"Rank[{rank}], Step {step}, Loss {loss:0.4f}, Latency " + + f"{(latency * 1000):0.3f} ms, Throughput {throughput:0.1f}, {strtime}" + ) + + if args.eval_interval > 0 and step % args.eval_interval == 0: + auc = 
eval(cached_eval_batches, eval_graph, step) + if args.save_model_after_each_eval: + save_model(f"step_{step}_val_auc_{auc:0.5f}") + dlrm_module.train() + last_time = time.time() + + if args.eval_interval > 0 and step % args.eval_interval != 0: + auc = eval(cached_eval_batches, eval_graph, step) + if args.save_model_after_each_eval: + save_model(f"step_{step}_val_auc_{auc:0.5f}") + + +def np_to_global(np): + t = flow.from_numpy(np) + return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) + + +def batch_to_global(np_label, np_dense, np_sparse, is_train=True): + labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) + dense_fields = np_to_global(np_dense) + sparse_fields = np_to_global(np_sparse) + return labels, dense_fields, sparse_fields + + +def eval(cached_eval_batches, eval_graph, cur_step=0): + num_eval_batches = len(cached_eval_batches) + if num_eval_batches <= 0: + return + eval_graph.module.eval() + labels, preds = [], [] + eval_start_time = time.time() + for i in range(num_eval_batches): + label, dense_fields, sparse_fields = cached_eval_batches[i] + pred = eval_graph(dense_fields, sparse_fields) + labels.append(label) + preds.append(pred.to_local()) + + labels = ( + np_to_global(np.concatenate(labels, axis=0)).to_global(sbp=flow.sbp.broadcast()).to_local() + ) + preds = ( + flow.cat(preds, dim=0) + .to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) + .to_global(sbp=flow.sbp.broadcast()) + .to_local() + ) + flow.comm.barrier() + eval_time = time.time() - eval_start_time + + rank = flow.env.get_rank() + auc = 0 + if rank == 0: + auc_start_time = time.time() + auc = flow.roc_auc_score(labels, preds).numpy()[0] + auc_time = time.time() - auc_start_time + host_mem_mb = psutil.Process().memory_info().rss // (1024 * 1024) + stream = os.popen("nvidia-smi --query-gpu=memory.used --format=csv") + device_mem_str = stream.read().split("\n")[rank + 1] + + strtime = time.strftime("%Y-%m-%d %H:%M:%S") + print( + f"Rank[{rank}], Step {cur_step}, AUC {auc:0.5f}, Eval_time {eval_time:0.2f} s, " + + f"AUC_time {auc_time:0.2f} s, Eval_samples {labels.shape[0]}, " + + f"GPU_Memory {device_mem_str}, Host_Memory {host_mem_mb} MiB, {strtime}" + ) + + return auc + + +if __name__ == "__main__": + os.system(sys.executable + " -m oneflow --doctor") + os.system("env") + flow.boxing.nccl.enable_all_to_all(True) + args = get_args() + + train(args) diff --git a/RecommenderSystems/dlrm/dlrm_profile.py b/RecommenderSystems/dlrm/dlrm_profile.py new file mode 100644 index 000000000..fee5f4f2c --- /dev/null +++ b/RecommenderSystems/dlrm/dlrm_profile.py @@ -0,0 +1,41 @@ +import os +import sys + +test_name = "dlrm_profile" +nsys = '/usr/local/cuda-11.6/bin/nsys profile --stats=true ' +#nsys = '/usr/local/cuda-11.5/bin/nsys profile --stats=true ' + +data_dir = "/data/criteo1t/criteo1t_dlrm_parquet" +persistent_path = './persistent' +script_path = 'dlrm_train_eval.py' +#script_path = 'dlrm_prefetch_train.py' + +env = '' +#env += "NCCL_DEBUG=INFO " +#env += "ONEFLOW_DEBUG_MODE=INFO " +env += "ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1 " +env += "ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE=1 " +env += "ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 " + +dl = sys.executable + " -m oneflow.distributed.launch " +dl += "--nproc_per_node 4 " +dl += "--nnodes 1 " +dl += "--node_rank 0 " +dl += "--master_addr 127.0.0.1 " +dl += f"{script_path} " + +cfg = "" +cfg += "--train_batches 300 " +cfg += "--eval_interval 0 " +cfg += 
f"--persistent_path {persistent_path} " +cfg += f"--data_dir {data_dir} " +cfg += "--store_type device_mem " +cfg += "--amp " + + +cmd = dl + cfg +cmd = nsys + f"-o {test_name} " + dl + cfg +os.system(f'rm -rf {persistent_path}/*') +os.system("echo " + env + cmd + f" | tee {test_name}.log") +os.system(env + cmd + f" | tee {test_name}.log") + From a628a8215e44d5845188a86c3239cb89ea4a69c1 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Thu, 2 Jun 2022 17:18:03 +0800 Subject: [PATCH 03/34] update --- RecommenderSystems/dlrm/dlrm_profile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/RecommenderSystems/dlrm/dlrm_profile.py b/RecommenderSystems/dlrm/dlrm_profile.py index fee5f4f2c..c08f464b0 100644 --- a/RecommenderSystems/dlrm/dlrm_profile.py +++ b/RecommenderSystems/dlrm/dlrm_profile.py @@ -7,8 +7,8 @@ data_dir = "/data/criteo1t/criteo1t_dlrm_parquet" persistent_path = './persistent' -script_path = 'dlrm_train_eval.py' -#script_path = 'dlrm_prefetch_train.py' +#script_path = 'dlrm_train_eval.py' +script_path = 'dlrm_prefetch_train.py' env = '' #env += "NCCL_DEBUG=INFO " From c430e5f4243d8966792e555b00178739cf47999d Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Wed, 8 Jun 2022 12:36:53 +0800 Subject: [PATCH 04/34] Add Mmoe dataloader; Add MmoeModule; --- RecommenderSystems/mmoe/mmoe_train_eval.py | 331 ++++++++++++++++----- 1 file changed, 251 insertions(+), 80 deletions(-) diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py index fca75fca0..49c615ddd 100644 --- a/RecommenderSystems/mmoe/mmoe_train_eval.py +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -105,11 +105,11 @@ def _print_args(args): print("-------------------- end of arguments ---------------------", flush=True) -num_dense_fields = 13 -num_sparse_fields = 26 +num_dense_fields = 11 +num_sparse_fields = 29 -class MMoeDataReader(object): +class MmoeDataReader(object): """A context manager that manages the creation and termination of a :class:`petastorm.Reader`. 
""" @@ -132,11 +132,93 @@ def __init__( self.shard_count = shard_count self.cur_shard = cur_shard - fields = ["Label"] - fields += [f"I{i+1}" for i in range(num_dense_fields)] - fields += [f"C{i+1}" for i in range(num_sparse_fields)] - self.fields = fields - self.num_fields = len(fields) + column_names = [ + "age", + "class_worker", + "det_ind_code", + "det_occ_code", + "education", + "wage_per_hour", + "hs_college", + "marital_stat", + "major_ind_code", + "major_occ_code", + "race", + "hisp_origin", + "sex", + "union_member", + "unemp_reason", + "full_or_part_emp", + "capital_gains", + "capital_losses", + "stock_dividends", + "tax_filer_stat", + "region_prev_res", + "state_prev_res", + "det_hh_fam_stat", + "det_hh_summ", + "instance_weight", + "mig_chg_msa", + "mig_chg_reg", + "mig_move_reg", + "mig_same", + "mig_prev_sunbelt", + "num_emp", + "fam_under_18", + "country_father", + "country_mother", + "country_self", + "citizenship", + "own_or_self", + "vet_question", + "vet_benefits", + "weeks_worked", + "year", + "income_50k", + ] + + sparse_features = [ + "class_worker", + "det_ind_code", + "det_occ_code", + "education", + "hs_college", + "major_ind_code", + "major_occ_code", + "race", + "hisp_origin", + "sex", + "union_member", + "unemp_reason", + "full_or_part_emp", + "tax_filer_stat", + "region_prev_res", + "state_prev_res", + "det_hh_fam_stat", + "det_hh_summ", + "mig_chg_msa", + "mig_chg_reg", + "mig_move_reg", + "mig_same", + "mig_prev_sunbelt", + "fam_under_18", + "country_father", + "country_mother", + "country_self", + "citizenship", + "vet_question", + ] + + dense_features = [ + col + for col in column_names + if col not in sparse_features and col not in ["income_50k", "marital_stat"] + ] + + self.fields = dense_features + sparse_features + ["label_income", "label_marital"] + self.num_fields = len(self.fields) + self.dense_end = len(dense_features) + self.sparse_end = len(dense_features + sparse_features) def __enter__(self): self.reader = make_batch_reader( @@ -174,18 +256,28 @@ def get_batches(self, reader, batch_size=None): ] ) if len(tail[0]) == batch_size: - label = tail[0] - features = tail[1 : self.num_fields] + dense = tail[0 : self.dense_end] + sparse = tail[self.dense_end : self.sparse_end] + label = tail[self.sparse_end :] tail = None - yield label, np.stack(features, axis=-1) + yield np.stack(label, axis=-1), np.stack(dense, axis=-1), np.stack( + sparse, axis=-1 + ) else: pos = 0 continue while (pos + batch_size) <= len(rglist[0]): - label = rglist[0][pos : pos + batch_size] - features = [rglist[j][pos : pos + batch_size] for j in range(1, self.num_fields)] + dense = [rglist[j][pos : pos + batch_size] for j in range(0, self.dense_end)] + sparse = [ + rglist[j][pos : pos + batch_size] + for j in range(self.dense_end, self.sparse_end) + ] + label = [ + rglist[j][pos : pos + batch_size] + for j in range(self.sparse_end, self.num_fields) + ] pos += batch_size - yield label, np.stack(features, axis=-1) + yield np.stack(label, axis=-1), np.stack(dense, axis=-1), np.stack(sparse, axis=-1) if pos != len(rglist[0]): tail = [rglist[i][pos:] for i in range(self.num_fields)] @@ -200,7 +292,7 @@ def make_census_dataloader(data_path, batch_size, shuffle=True): world_size = flow.env.get_world_size() batch_size_per_proc = batch_size // world_size - return MMoeDataReader( + return MmoeDataReader( files, batch_size_per_proc, None, # TODO: iterate over all eval dataset @@ -233,9 +325,7 @@ def __init__( ] if store_type == "device_mem": store_options = 
flow.one_embedding.make_device_mem_store_options( - persistent_path=persistent_path, - capacity=vocab_size, - size_factor=size_factor, + persistent_path=persistent_path, capacity=vocab_size, size_factor=size_factor, ) elif store_type == "cached_host_mem": assert cache_memory_budget_mb > 0 @@ -297,53 +387,113 @@ def forward(self, x: flow.Tensor) -> flow.Tensor: return self.linear_layers(x) -class MMoeModule(nn.Module): +class MmoeModule(nn.Module): def __init__( self, - embedding_vec_size=128, - dnn=[1024, 1024, 512, 256], + num_tasks=2, + num_experts=3, + embedding_vec_size=4, + expert_dnn=[256, 128], persistent_path=None, table_size_array=None, one_embedding_store_type="cached_host_mem", cache_memory_budget_mb=8192, - dropout=0.2, ): - super(MMoeModule, self).__init__() + super(MmoeModule, self).__init__() + + self.num_experts = num_experts + self.num_tasks = num_tasks + + self.embedding_layer = OneEmbedding( + table_name="sparse_embedding", + embedding_vec_size=embedding_vec_size, + persistent_path=persistent_path, + table_size_array=table_size_array, + store_type=one_embedding_store_type, + cache_memory_budget_mb=cache_memory_budget_mb, + size_factor=3, + ) + + self.expert = DNN( + in_features=embedding_vec_size * num_sparse_fields + num_dense_fields, + hidden_units=expert_dnn[:-1], + out_features=expert_dnn[-1], + skip_final_activation=True, + dropout=0.0, + ) + + self.experts = nn.ModuleList([]) + for _ in range(num_experts): + expert_net = DNN( + in_features=embedding_vec_size * num_sparse_fields + num_dense_fields, + hidden_units=expert_dnn[:-1], + out_features=expert_dnn[-1], + skip_final_activation=True, + dropout=0.0, + ) + self.experts.append(expert_net) + + self.gates = nn.ModuleList([]) + self.towers = nn.ModuleList([]) + for _ in range(num_tasks): + gate_net = nn.Linear( + in_features=embedding_vec_size * num_sparse_fields + num_dense_fields, + out_features=num_experts, + bias=False, + ) + self.gates.append(gate_net) + + tower_net = nn.Linear(in_features=expert_dnn[-1], out_features=1,) + self.towers.append(tower_net) + + def forward(self, dense_inputs, sparse_inputs) -> flow.Tensor: + sparse_emb = self.embedding_layer(sparse_inputs) + inputs = flow.cat([sparse_emb.flatten(start_dim=1), dense_inputs], dim=1) + # print("inputs: ", inputs.shape) - def forward(self, inputs) -> flow.Tensor: - pass + expert_outs = [] + for expert in self.experts: + expert_outs.append(expert(inputs)) + expert_concat = flow.stack(expert_outs, dim=1) + # print("expert_concat: ", expert_concat.shape) + + mmoe_outs = [] + for i in range(self.num_tasks): + gate_out = self.gates[i](inputs).softmax() + # print("gate: ", gate_out.shape) + gate_out = gate_out.reshape([-1, self.num_experts, 1]) + # print("gate: ", gate_out.shape) + gate_mul_expert = flow.mul(expert_concat, gate_out.expand_as(expert_concat)) + # print("gate_mul_expert: ", gate_mul_expert.shape) + gate_mul_expert = flow.sum(gate_mul_expert, dim=1) + # print("gate_mul_expert: ", gate_mul_expert.shape) + + tower_out = self.towers[i](gate_mul_expert) + # print("tower: ", tower_out.shape) + mmoe_outs.append(tower_out) + + return mmoe_outs def make_mmoe_module(args): - model = MMoeModule( - embedding_vec_size=args.embedding_vec_size, - dnn=args.dnn, + model = MmoeModule( + num_tasks=2, + num_experts=3, + embedding_vec_size=4, + expert_dnn=[256, 128], persistent_path=args.persistent_path, table_size_array=args.table_size_array, one_embedding_store_type=args.store_type, cache_memory_budget_mb=args.cache_memory_budget_mb, - dropout=args.net_dropout, ) 
return model -class MMoeValGraph(flow.nn.Graph): - def __init__(self, mmoe_module, amp=False): - super(MMoeValGraph, self).__init__() - self.module = mmoe_module - if amp: - self.config.enable_amp(True) - - def build(self, features): - predicts = self.module(features.to("cuda")) - return predicts.sigmoid() - - -class MMoeTrainGraph(flow.nn.Graph): +class MmoeTrainGraph(flow.nn.Graph): def __init__( self, mmoe_module, loss, optimizer, grad_scaler=None, amp=False, lr_scheduler=None, ): - super(MMoeTrainGraph, self).__init__() + super(MmoeTrainGraph, self).__init__() self.module = mmoe_module self.loss = loss self.add_optimizer(optimizer, lr_sch=lr_scheduler) @@ -354,13 +504,29 @@ def __init__( self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - def build(self, labels, features): - logits = self.module(features.to("cuda")) - loss = self.loss(logits, labels.to("cuda")) + def build(self, labels, dense_fields, sparse_fields): + logits = self.module(dense_fields.to("cuda"), sparse_fields.to("cuda")) + label_income = labels[:, 0].unsqueeze(1) + label_marital = labels[:, 1].unsqueeze(1) + loss_income = self.loss(logits[0], label_income.to("cuda")) + loss_marital = self.loss(logits[1], label_marital.to("cuda")) + loss = loss_income + loss_marital loss.backward() return loss.to("cpu") +class MmoeValGraph(flow.nn.Graph): + def __init__(self, mmoe_module, amp=False): + super(MmoeValGraph, self).__init__() + self.module = mmoe_module + if amp: + self.config.enable_amp(True) + + def build(self, features): + predicts = self.module(features.to("cuda")) + return predicts.sigmoid() + + def make_lr_scheduler(args, optimizer): batches_per_epoch = math.ceil(args.num_train_samples / args.batch_size) milestones = [ @@ -418,7 +584,7 @@ def save_model(subdir): ) eval_graph = MmoeValGraph(mmoe_module, args.amp) - train_graph = MMoeTrainGraph( + train_graph = MmoeTrainGraph( mmoe_module, loss, opt, grad_scaler, args.amp, lr_scheduler=lr_scheduler ) @@ -433,8 +599,11 @@ def save_model(subdir): with make_census_dataloader(f"{args.data_dir}/train", args.batch_size) as loader: step, last_step, last_time = -1, 0, time.time() for step in range(1, args.train_batches + 1): - labels, features = batch_to_global(*next(loader)) - loss = train_graph(labels, features) + labels, dense_fields, sparse_fields = batch_to_global(*next(loader)) + # print("label: ", labels.shape) + # print("dense: ", dense_fields.shape) + # print("sparse: ", sparse_fields.shape) + loss = train_graph(labels, dense_fields, sparse_fields) if step % args.loss_print_interval == 0: loss = loss.numpy() if rank == 0: @@ -447,31 +616,31 @@ def save_model(subdir): + f"Latency {(latency * 1000):0.3f} ms, Throughput {throughput:0.1f}, {strtime}" ) - if step % batches_per_epoch == 0: - epoch += 1 - auc = eval( - args, - eval_graph, - cur_step=step, - epoch=epoch, - cached_eval_batches=cached_eval_batches, - ) - if args.save_model_after_each_eval: - save_model(f"step_{step}_val_auc_{auc:0.5f}") - - mmoe_module.train() - last_time = time.time() - - if step % batches_per_epoch != 0: - auc = eval( - args, - eval_graph, - cur_step=step, - epoch=epoch, - cached_eval_batches=cached_eval_batches, - ) - if args.save_model_after_each_eval: - save_model(f"step_{step}_val_auc_{auc:0.5f}") + # if step % batches_per_epoch == 0: + # epoch += 1 + # auc = eval( + # args, + # eval_graph, + # cur_step=step, + # epoch=epoch, + # cached_eval_batches=cached_eval_batches, + # ) + # if args.save_model_after_each_eval: + # save_model(f"step_{step}_val_auc_{auc:0.5f}") + + # 
mmoe_module.train() + # last_time = time.time() + + # if step % batches_per_epoch != 0: + # auc = eval( + # args, + # eval_graph, + # cur_step=step, + # epoch=epoch, + # cached_eval_batches=cached_eval_batches, + # ) + # if args.save_model_after_each_eval: + # save_model(f"step_{step}_val_auc_{auc:0.5f}") def np_to_global(np): @@ -479,18 +648,20 @@ def np_to_global(np): return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) -def batch_to_global(np_label, np_features, is_train=True): - labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) - features = np_to_global(np_features) - return labels, features +def batch_to_global(np_label, np_dense, np_sparse, is_train=True): + labels = np_to_global(np_label) if is_train else np_label + np_dense = np_to_global(np_dense) + np_sparse = np_to_global(np_sparse) + + return labels, np_dense, np_sparse def prefetch_eval_batches(data_dir, batch_size, num_batches): cached_eval_batches = [] with make_census_dataloader(data_dir, batch_size, shuffle=False) as loader: for _ in range(num_batches): - label, features = batch_to_global(*next(loader), is_train=False) - cached_eval_batches.append((label, features)) + labels, dense_fields, sparse_fields = batch_to_global(*next(loader)) + cached_eval_batches.append((labels, dense_fields, sparse_fields)) return cached_eval_batches From 9e2c614843adbab61624d3092fafd1204890e388 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Wed, 8 Jun 2022 15:14:05 +0800 Subject: [PATCH 05/34] Add mmoe eval part; Remove useless code; --- RecommenderSystems/mmoe/mmoe_train_eval.py | 241 ++++++++++----------- 1 file changed, 111 insertions(+), 130 deletions(-) diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py index 49c615ddd..09e163fed 100644 --- a/RecommenderSystems/mmoe/mmoe_train_eval.py +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -132,51 +132,6 @@ def __init__( self.shard_count = shard_count self.cur_shard = cur_shard - column_names = [ - "age", - "class_worker", - "det_ind_code", - "det_occ_code", - "education", - "wage_per_hour", - "hs_college", - "marital_stat", - "major_ind_code", - "major_occ_code", - "race", - "hisp_origin", - "sex", - "union_member", - "unemp_reason", - "full_or_part_emp", - "capital_gains", - "capital_losses", - "stock_dividends", - "tax_filer_stat", - "region_prev_res", - "state_prev_res", - "det_hh_fam_stat", - "det_hh_summ", - "instance_weight", - "mig_chg_msa", - "mig_chg_reg", - "mig_move_reg", - "mig_same", - "mig_prev_sunbelt", - "num_emp", - "fam_under_18", - "country_father", - "country_mother", - "country_self", - "citizenship", - "own_or_self", - "vet_question", - "vet_benefits", - "weeks_worked", - "year", - "income_50k", - ] - sparse_features = [ "class_worker", "det_ind_code", @@ -210,15 +165,24 @@ def __init__( ] dense_features = [ - col - for col in column_names - if col not in sparse_features and col not in ["income_50k", "marital_stat"] + "age", + "wage_per_hour", + "capital_gains", + "capital_losses", + "stock_dividends", + "instance_weight", + "num_emp", + "own_or_self", + "vet_benefits", + "weeks_worked", + "year", ] self.fields = dense_features + sparse_features + ["label_income", "label_marital"] - self.num_fields = len(self.fields) + self.dense_end = len(dense_features) self.sparse_end = len(dense_features + sparse_features) + self.num_fields = len(self.fields) def __enter__(self): self.reader = make_batch_reader( @@ -258,9 +222,10 @@ def get_batches(self, 
reader, batch_size=None): if len(tail[0]) == batch_size: dense = tail[0 : self.dense_end] sparse = tail[self.dense_end : self.sparse_end] - label = tail[self.sparse_end :] + label_income = tail[self.sparse_end] + label_marital = tail[self.sparse_end + 1] tail = None - yield np.stack(label, axis=-1), np.stack(dense, axis=-1), np.stack( + yield label_income, label_marital, np.stack(dense, axis=-1), np.stack( sparse, axis=-1 ) else: @@ -272,12 +237,12 @@ def get_batches(self, reader, batch_size=None): rglist[j][pos : pos + batch_size] for j in range(self.dense_end, self.sparse_end) ] - label = [ - rglist[j][pos : pos + batch_size] - for j in range(self.sparse_end, self.num_fields) - ] + label_income = rglist[self.sparse_end][pos : pos + batch_size] + label_marital = rglist[self.sparse_end + 1][pos : pos + batch_size] pos += batch_size - yield np.stack(label, axis=-1), np.stack(dense, axis=-1), np.stack(sparse, axis=-1) + yield label_income, label_marital, np.stack(dense, axis=-1), np.stack( + sparse, axis=-1 + ) if pos != len(rglist[0]): tail = [rglist[i][pos:] for i in range(self.num_fields)] @@ -414,14 +379,6 @@ def __init__( size_factor=3, ) - self.expert = DNN( - in_features=embedding_vec_size * num_sparse_fields + num_dense_fields, - hidden_units=expert_dnn[:-1], - out_features=expert_dnn[-1], - skip_final_activation=True, - dropout=0.0, - ) - self.experts = nn.ModuleList([]) for _ in range(num_experts): expert_net = DNN( @@ -449,27 +406,21 @@ def __init__( def forward(self, dense_inputs, sparse_inputs) -> flow.Tensor: sparse_emb = self.embedding_layer(sparse_inputs) inputs = flow.cat([sparse_emb.flatten(start_dim=1), dense_inputs], dim=1) - # print("inputs: ", inputs.shape) expert_outs = [] for expert in self.experts: expert_outs.append(expert(inputs)) expert_concat = flow.stack(expert_outs, dim=1) - # print("expert_concat: ", expert_concat.shape) mmoe_outs = [] for i in range(self.num_tasks): gate_out = self.gates[i](inputs).softmax() - # print("gate: ", gate_out.shape) gate_out = gate_out.reshape([-1, self.num_experts, 1]) - # print("gate: ", gate_out.shape) + gate_mul_expert = flow.mul(expert_concat, gate_out.expand_as(expert_concat)) - # print("gate_mul_expert: ", gate_mul_expert.shape) gate_mul_expert = flow.sum(gate_mul_expert, dim=1) - # print("gate_mul_expert: ", gate_mul_expert.shape) tower_out = self.towers[i](gate_mul_expert) - # print("tower: ", tower_out.shape) mmoe_outs.append(tower_out) return mmoe_outs @@ -504,15 +455,13 @@ def __init__( self.config.enable_amp(True) self.set_grad_scaler(grad_scaler) - def build(self, labels, dense_fields, sparse_fields): + def build(self, label_income, label_marital, dense_fields, sparse_fields): logits = self.module(dense_fields.to("cuda"), sparse_fields.to("cuda")) - label_income = labels[:, 0].unsqueeze(1) - label_marital = labels[:, 1].unsqueeze(1) loss_income = self.loss(logits[0], label_income.to("cuda")) loss_marital = self.loss(logits[1], label_marital.to("cuda")) loss = loss_income + loss_marital loss.backward() - return loss.to("cpu") + return loss_income.to("cpu"), loss_marital.to("cpu") class MmoeValGraph(flow.nn.Graph): @@ -522,9 +471,9 @@ def __init__(self, mmoe_module, amp=False): if amp: self.config.enable_amp(True) - def build(self, features): - predicts = self.module(features.to("cuda")) - return predicts.sigmoid() + def build(self, dense_fields, sparse_fields): + preds = self.module(dense_fields.to("cuda"), sparse_fields.to("cuda")) + return preds[0].sigmoid(), preds[1].sigmoid() def make_lr_scheduler(args, optimizer): 
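# A standalone NumPy sketch of the per-task mixing step implemented in
# MmoeModule.forward above: each task's gate yields softmax weights over the
# shared experts, and the task tower consumes the gate-weighted sum of expert
# outputs. The sizes below are made up purely for illustration and are not
# taken from the patch.
import numpy as np

batch, num_experts, expert_out = 4, 3, 128
expert_concat = np.random.rand(batch, num_experts, expert_out)  # stacked expert outputs
gate_logits = np.random.rand(batch, num_experts)                # logits from one task's gate

# softmax over the expert axis, then the gate-weighted sum of expert outputs
gate = np.exp(gate_logits) / np.exp(gate_logits).sum(axis=1, keepdims=True)
mixed = (expert_concat * gate[:, :, None]).sum(axis=1)          # shape (batch, expert_out)
assert mixed.shape == (batch, expert_out)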
@@ -599,48 +548,48 @@ def save_model(subdir): with make_census_dataloader(f"{args.data_dir}/train", args.batch_size) as loader: step, last_step, last_time = -1, 0, time.time() for step in range(1, args.train_batches + 1): - labels, dense_fields, sparse_fields = batch_to_global(*next(loader)) - # print("label: ", labels.shape) - # print("dense: ", dense_fields.shape) - # print("sparse: ", sparse_fields.shape) - loss = train_graph(labels, dense_fields, sparse_fields) + label_income, label_marital, dense_fields, sparse_fields = batch_to_global( + *next(loader) + ) + loss_income, loss_marital = train_graph( + label_income, label_marital, dense_fields, sparse_fields + ) if step % args.loss_print_interval == 0: - loss = loss.numpy() + loss_income = loss_income.numpy() + loss_marital = loss_marital.numpy() if rank == 0: latency = (time.time() - last_time) / (step - last_step) throughput = args.batch_size / latency last_step, last_time = step, time.time() strtime = time.strftime("%Y-%m-%d %H:%M:%S") print( - f"Rank[{rank}], Step {step}, Loss {loss:0.4f}, " + f"Rank[{rank}], Step {step}, Loss_income {loss_income:0.4f}, Loss_marital {loss_marital:0.4f}, " + f"Latency {(latency * 1000):0.3f} ms, Throughput {throughput:0.1f}, {strtime}" ) - # if step % batches_per_epoch == 0: - # epoch += 1 - # auc = eval( - # args, - # eval_graph, - # cur_step=step, - # epoch=epoch, - # cached_eval_batches=cached_eval_batches, - # ) - # if args.save_model_after_each_eval: - # save_model(f"step_{step}_val_auc_{auc:0.5f}") - - # mmoe_module.train() - # last_time = time.time() - - # if step % batches_per_epoch != 0: - # auc = eval( - # args, - # eval_graph, - # cur_step=step, - # epoch=epoch, - # cached_eval_batches=cached_eval_batches, - # ) - # if args.save_model_after_each_eval: - # save_model(f"step_{step}_val_auc_{auc:0.5f}") + if step % batches_per_epoch == 0: + epoch += 1 + auc_income, auc_marital = eval( + args, + eval_graph, + cur_step=step, + epoch=epoch, + cached_eval_batches=cached_eval_batches, + ) + if args.save_model_after_each_eval: + save_model( + f"step_{step}_val_auc_income_{auc_income:0.5f}_marital_{auc_marital:0.5f}" + ) + + mmoe_module.train() + last_time = time.time() + + if step % batches_per_epoch != 0: + auc_income, auc_marital = eval( + args, eval_graph, cur_step=step, epoch=epoch, cached_eval_batches=cached_eval_batches, + ) + if args.save_model_after_each_eval: + save_model(f"step_{step}_val_auc_income_{auc_income:0.5f}_marital_{auc_marital:0.5f}") def np_to_global(np): @@ -648,20 +597,29 @@ def np_to_global(np): return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) -def batch_to_global(np_label, np_dense, np_sparse, is_train=True): - labels = np_to_global(np_label) if is_train else np_label +def batch_to_global(np_label_income, np_label_marital, np_dense, np_sparse, is_train=True): + label_income = ( + np_to_global(np_label_income.reshape(-1, 1)) if is_train else np_label_income.reshape(-1, 1) + ) + label_marital = ( + np_to_global(np_label_marital.reshape(-1, 1)) + if is_train + else np_label_marital.reshape(-1, 1) + ) np_dense = np_to_global(np_dense) np_sparse = np_to_global(np_sparse) - return labels, np_dense, np_sparse + return label_income, label_marital, np_dense, np_sparse def prefetch_eval_batches(data_dir, batch_size, num_batches): cached_eval_batches = [] with make_census_dataloader(data_dir, batch_size, shuffle=False) as loader: for _ in range(num_batches): - labels, dense_fields, sparse_fields = batch_to_global(*next(loader)) - 
cached_eval_batches.append((labels, dense_fields, sparse_fields)) + label_income, label_marital, dense_fields, sparse_fields = batch_to_global( + *next(loader) + ) + cached_eval_batches.append((label_income, label_marital, dense_fields, sparse_fields)) return cached_eval_batches @@ -669,20 +627,36 @@ def eval(args, eval_graph, cur_step=0, epoch=0, cached_eval_batches=None): batches_per_epoch = math.ceil(args.num_test_samples / args.batch_size) eval_graph.module.eval() - labels, preds = [], [] + label_incomes, label_maritals = [], [] + pred_incomes, pred_maritals = [], [] eval_start_time = time.time() for i in range(batches_per_epoch): - label, features = cached_eval_batches[i] - pred = eval_graph(features) - labels.append(label) - preds.append(pred.to_local()) - - labels = ( - np_to_global(np.concatenate(labels, axis=0)).to_global(sbp=flow.sbp.broadcast()).to_local() + label_income, label_marital, dense_fields, sparse_fields = cached_eval_batches[i] + pred_income, pred_marital = eval_graph(dense_fields, sparse_fields) + label_incomes.append(label_income) + label_maritals.append(label_marital) + pred_incomes.append(pred_income.to_local()) + pred_maritals.append(pred_marital.to_local()) + + label_incomes = ( + np_to_global(np.concatenate(label_incomes, axis=0)) + .to_global(sbp=flow.sbp.broadcast()) + .to_local() + ) + label_maritals = ( + np_to_global(np.concatenate(label_maritals, axis=0)) + .to_global(sbp=flow.sbp.broadcast()) + .to_local() + ) + pred_incomes = ( + flow.cat(pred_incomes, dim=0) + .to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) + .to_global(sbp=flow.sbp.broadcast()) + .to_local() ) - preds = ( - flow.cat(preds, dim=0) + pred_maritals = ( + flow.cat(pred_maritals, dim=0) .to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) .to_global(sbp=flow.sbp.broadcast()) .to_local() @@ -694,8 +668,14 @@ def eval(args, eval_graph, cur_step=0, epoch=0, cached_eval_batches=None): rank = flow.env.get_rank() metrics_start_time = time.time() - auc = flow.roc_auc_score(labels, preds).numpy()[0] - logloss = flow._C.binary_cross_entropy_loss(preds, labels, weight=None, reduction="mean") + auc_income = flow.roc_auc_score(label_incomes, pred_incomes).numpy()[0] + auc_marital = flow.roc_auc_score(label_maritals, pred_maritals).numpy()[0] + loss_income = flow._C.binary_cross_entropy_loss( + pred_incomes, label_incomes, weight=None, reduction="mean" + ) + loss_marital = flow._C.binary_cross_entropy_loss( + pred_maritals, label_maritals, weight=None, reduction="mean" + ) metrics_time = time.time() - metrics_start_time if rank == 0: @@ -705,12 +685,13 @@ def eval(args, eval_graph, cur_step=0, epoch=0, cached_eval_batches=None): strtime = time.strftime("%Y-%m-%d %H:%M:%S") print( - f"Rank[{rank}], Epoch {epoch}, Step {cur_step}, AUC {auc:0.6f}, LogLoss {logloss:0.6f}, " - + f"Eval_time {eval_time:0.2f} s, Metrics_time {metrics_time:0.2f} s, Eval_samples {labels.shape[0]}, " + f"Rank[{rank}], Epoch {epoch}, Step {cur_step}, AUC_income {auc_income:0.6f}, AUC_marital {auc_marital:0.6f}, " + + f"Loss_income {loss_income:0.6f}, Loss_marital {loss_marital:0.6f}, " + + f"Eval_time {eval_time:0.2f} s, Metrics_time {metrics_time:0.2f} s, Eval_samples {label_incomes.shape[0]}, " + f"GPU_Memory {device_mem_str}, Host_Memory {host_mem_mb} MiB, {strtime}" ) - return auc + return auc_income, auc_marital if __name__ == "__main__": From 21df60f6a5b829f3ab683d4b7063b249793277ce Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Wed, 8 Jun 2022 15:30:50 +0800 
Subject: [PATCH 06/34] Update args --- RecommenderSystems/mmoe/mmoe_train_eval.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py index 09e163fed..803d068b8 100644 --- a/RecommenderSystems/mmoe/mmoe_train_eval.py +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -41,9 +41,11 @@ def str_list(x): help="save model after each eval or not", ) - parser.add_argument("--embedding_vec_size", type=int, default=16, help="embedding vector size") + parser.add_argument("--num_experts", type=int, default=3, help="the number of experts") + parser.add_argument("--num_tasks", type=int, default=2, help="the number of tasks") + parser.add_argument("--embedding_vec_size", type=int, default=4, help="embedding vector size") parser.add_argument( - "--dnn", type=int_list, default="1000,1000,1000,1000,1000", help="dnn hidden units number" + "--expert_dnn", type=int_list, default="256, 128", help="dnn hidden units number" ) parser.add_argument("--net_dropout", type=float, default=0.2, help="net dropout rate") @@ -52,10 +54,10 @@ def str_list(x): parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate") parser.add_argument( - "--batch_size", type=int, default=10000, help="training/evaluation batch size" + "--batch_size", type=int, default=256, help="training/evaluation batch size" ) parser.add_argument( - "--train_batches", type=int, default=75000, help="the maximum number of training batches" + "--train_batches", type=int, default=16000, help="the maximum number of training batches" ) parser.add_argument("--loss_print_interval", type=int, default=100, help="") @@ -428,10 +430,10 @@ def forward(self, dense_inputs, sparse_inputs) -> flow.Tensor: def make_mmoe_module(args): model = MmoeModule( - num_tasks=2, - num_experts=3, - embedding_vec_size=4, - expert_dnn=[256, 128], + num_tasks=args.num_tasks, + num_experts=args.num_experts, + embedding_vec_size=args.embedding_vec_size, + expert_dnn=args.expert_dnn, persistent_path=args.persistent_path, table_size_array=args.table_size_array, one_embedding_store_type=args.store_type, From e955c0884148184c378ad4a8d34114b0352ad447 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Wed, 8 Jun 2022 15:31:20 +0800 Subject: [PATCH 07/34] Add sh script --- RecommenderSystems/mmoe/train_mmoe.sh | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 RecommenderSystems/mmoe/train_mmoe.sh diff --git a/RecommenderSystems/mmoe/train_mmoe.sh b/RecommenderSystems/mmoe/train_mmoe.sh new file mode 100644 index 000000000..1b7411211 --- /dev/null +++ b/RecommenderSystems/mmoe/train_mmoe.sh @@ -0,0 +1,26 @@ +#!/bin/bash +DEVICE_NUM_PER_NODE=1 +DATA_DIR=/path/to/mmoe_parquet +PERSISTENT_PATH=/path/to/persistent +MODEL_SAVE_DIR=/path/to/model/save/dir + +python3 -m oneflow.distributed.launch \ + --nproc_per_node $DEVICE_NUM_PER_NODE \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 127.0.0.1 \ + mmoe_train_eval.py \ + --data_dir $DATA_DIR \ + --persistent_path $PERSISTENT_PATH \ + --table_size_array "8, 38, 36, 16, 3, 21, 14, 5, 8, 2, 3, 5, 7, 6, 5, 16, 13, 7, 6, 8, 8, 3, 4, 4, 15, 16, 14, 5, 3" \ + --store_type 'cached_host_mem' \ + --cache_memory_budget_mb 1024 \ + --batch_size 256 \ + --train_batches 16000 \ + --loss_print_interval 100 \ + --learning_rate 0.001 \ + --embedding_vec_size 4 \ + --expert_dnn "256, 128" \ + --num_train_samples 199523 \ + --num_test_samples 99762 \ + --model_save_dir $MODEL_SAVE_DIR From 
966b03dd3ee6f72f687ef1ef11ecb2e0d1937e7f Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Thu, 9 Jun 2022 14:56:19 +0800 Subject: [PATCH 08/34] Fix bugs in parallel --- RecommenderSystems/mmoe/mmoe_train_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py index 803d068b8..3ac9f2b55 100644 --- a/RecommenderSystems/mmoe/mmoe_train_eval.py +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -619,7 +619,7 @@ def prefetch_eval_batches(data_dir, batch_size, num_batches): with make_census_dataloader(data_dir, batch_size, shuffle=False) as loader: for _ in range(num_batches): label_income, label_marital, dense_fields, sparse_fields = batch_to_global( - *next(loader) + *next(loader), is_train=False ) cached_eval_batches.append((label_income, label_marital, dense_fields, sparse_fields)) return cached_eval_batches From 503780a7c117f16e29895a172b6b24bf9f2068f8 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Thu, 9 Jun 2022 14:57:14 +0800 Subject: [PATCH 09/34] Replace table size array; --- RecommenderSystems/mmoe/train_mmoe.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RecommenderSystems/mmoe/train_mmoe.sh b/RecommenderSystems/mmoe/train_mmoe.sh index 1b7411211..a556784b1 100644 --- a/RecommenderSystems/mmoe/train_mmoe.sh +++ b/RecommenderSystems/mmoe/train_mmoe.sh @@ -12,7 +12,7 @@ python3 -m oneflow.distributed.launch \ mmoe_train_eval.py \ --data_dir $DATA_DIR \ --persistent_path $PERSISTENT_PATH \ - --table_size_array "8, 38, 36, 16, 3, 21, 14, 5, 8, 2, 3, 5, 7, 6, 5, 16, 13, 7, 6, 8, 8, 3, 4, 4, 15, 16, 14, 5, 3" \ + --table_size_array "9, 52, 47, 17, 3, 24, 15, 5, 10, 2, 3, 6, 8, 6, 6, 51, 38, 8, 10, 9, 10, 3, 4, 5, 43, 43, 43, 5, 3" \ --store_type 'cached_host_mem' \ --cache_memory_budget_mb 1024 \ --batch_size 256 \ From 0875be5b1c04881c0936a575d53e401ca57f76e2 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Thu, 9 Jun 2022 15:23:07 +0800 Subject: [PATCH 10/34] Update readme; Update args; --- RecommenderSystems/mmoe/README.md | 157 +++++++-------------- RecommenderSystems/mmoe/mmoe_train_eval.py | 6 +- 2 files changed, 58 insertions(+), 105 deletions(-) diff --git a/RecommenderSystems/mmoe/README.md b/RecommenderSystems/mmoe/README.md index b68682dc4..0408424d1 100644 --- a/RecommenderSystems/mmoe/README.md +++ b/RecommenderSystems/mmoe/README.md @@ -1,48 +1,47 @@ -# MMoe +# MMoE [Multi-gate Mixture-of-Experts (MMoE)](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) adapts the Mixture-of- Experts (MoE) structure to multi-task learning by sharing the expert submodels across all tasks, while also having a gating network trained to optimize each task. Its model structure is as follows. Based on this structure, this project uses OneFlow distributed deep learning framework to realize training the model in graph mode on the Criteo data set.
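In symbols (a loose paraphrase of Eqs. 2–3 of the MMoE paper, not a verbatim quote): with $n$ shared experts $f_i$ and one gating network $g^k$ per task $k$, the model computes

$$
y_k = h^k\left(\sum_{i=1}^{n} g^k(x)_i \, f_i(x)\right), \qquad g^k(x) = \mathrm{softmax}\left(W_{gk}\, x\right),
$$

where $h^k$ is the task-specific tower; the figure below shows the same structure.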

Screen Shot 2022-04-01 at 4 45 22 PM

- - ## Directory description ```txt - +. +├── mmoe_train_eval.py # OneFlow DeepFM train/val/test scripts with OneEmbedding module +├── README.md # Documentation +├── tools +│ ├── mmoe_parquet.py # Read census-income data and export it as parquet data format +└── train_mmoe.sh # MMoE training shell script ``` ## Arguments description -| Argument Name | Argument Explanation | Default Value | -| -------------------------- | ------------------------------------------------------------ | ------------------------ | -| data_dir | the data file directory | *Required Argument* | -| num_train_samples | the number of train samples | *Required Argument* | -| num_val_samples | the number of validation samples | *Required Argument* | -| num_test_samples | the number of test samples | *Required Argument* | -| model_load_dir | model loading directory | None | -| model_save_dir | model saving directory | None | -| save_best_model | save best model or not | False | -| save_initial_model | save initial model parameters or not | False | -| save_model_after_each_eval | save model after each eval or not | False | -| embedding_vec_size | embedding vector size | 16 | -| dnn | dnn hidden units number | 1000,1000,1000,1000,1000 | -| net_dropout | number of minibatch training interations | 0.2 | -| embedding_vec_size | embedding vector size | 16 | -| learning_rate | initial learning rate | 0.001 | -| batch_size | training/evaluation batch size | 10000 | -| train_batches | the maximum number of training batches | 75000 | -| loss_print_interval | interval of printing loss | 100 | -| patience | Number of epochs with no improvement after which learning rate will be reduced | 2 | -| min_delta | threshold for measuring the new optimum, to only focus on significant changes | 1.0e-6 | -| table_size_array | embedding table size array for sparse fields | *Required Argument* | -| persistent_path | path for persistent kv store of embedding | *Required Argument* | -| store_type | OneEmbeddig persistent kv store type: `device_mem`, `cached_host_mem` or `cached_ssd` | `cached_host_mem` | -| cache_memory_budget_mb | size of cache memory budget on each device in megabytes when `store_type` is `cached_host_mem` or `cached_ssd` | 1024 | -| amp | enable Automatic Mixed Precision(AMP) training or not | False | -| loss_scale_policy | loss scale policy for AMP training: `static` or `dynamic` | `static` | -| disable_early_stop | disable early stop or not | False | +| Argument Name | Argument Explanation | Default Value | +| -------------------------- | ------------------------------------------------------------ | ------------------- | +| data_dir | the data file directory | *Required Argument* | +| num_train_samples | the number of train samples | *Required Argument* | +| num_test_samples | the number of test samples | *Required Argument* | +| model_load_dir | model loading directory | None | +| model_save_dir | model saving directory | None | +| save_initial_model | save initial model parameters or not | False | +| save_model_after_each_eval | save model after each eval or not | False | +| num_experts | the number of experts | 3 | +| num_tasks | the number of tasks | 2 | +| embedding_vec_size | embedding vector size | 16 | +| expert_dnn | expert dnn hidden units number | 256, 128 | +| net_dropout | net dropout rate | 0.0 | +| learning_rate | initial learning rate | 0.001 | +| batch_size | training/evaluation batch size | 256 | +| train_batches | the maximum number of training batches | 16000 | +| loss_print_interval | interval of 
printing loss | 100 | +| table_size_array | embedding table size array for sparse fields | *Required Argument* | +| persistent_path | path for persistent kv store of embedding | *Required Argument* | +| store_type | OneEmbeddig persistent kv store type: `device_mem`, `cached_host_mem` or `cached_ssd` | `cached_host_mem` | +| cache_memory_budget_mb | size of cache memory budget on each device in megabytes when `store_type` is `cached_host_mem` or `cached_ssd` | 1024 | +| amp | enable Automatic Mixed Precision(AMP) training or not | False | +| loss_scale_policy | loss scale policy for AMP training: `static` or `dynamic` | `static` | ## Getting Started @@ -68,85 +67,37 @@ A hands-on guide to train a MMoe model. ### Dataset -**Note**: - -According to [the DeepFM paper](https://arxiv.org/abs/1703.04247), we treat both categorical and continuous features as sparse features. - -> χ may include categorical fields (e.g., gender, location) and continuous fields (e.g., age). Each categorical field is represented as a vec- tor of one-hot encoding, and each continuous field is repre- sented as the value itself, or a vector of one-hot encoding after discretization. - -1. Download the [Criteo Kaggle dataset](https://www.kaggle.com/c/criteo-display-ad-challenge) and then split it using [split_criteo_kaggle.py](https://github.com/Oneflow-Inc/models/blob/dev_deepfm_multicol_oneemb/RecommenderSystems/deepfm/tools/split_criteo_kaggle.py). - - Note: Same as [the DeepFM_Criteo_x4_001 experiment](https://github.com/openbenchmark/BARS/tree/master/ctr_prediction/benchmarks/DeepFM/DeepFM_criteo_x4_001) in FuxiCTR, only train.txt is used. Also, the dataset is randomly spllitted into 8:1:1 as training set, validation set and test set. The dataset is splitted using StratifiedKFold in sklearn. - - ```shell - python3 split_criteo_kaggle.py --input_dir=/path/to/your/criteo_kaggle --output_dir=/path/to/your/output/dir - ``` - -2. Download spark from https://spark.apache.org/downloads.html and then uncompress the tar file into the directory where you want to install Spark. Ensure the `SPARK_HOME` environment variable points to the directory where the spark is. - -3. launch a spark shell using [launch_spark.sh](https://github.com/Oneflow-Inc/models/blob/dev_deepfm_multicol_oneemb/RecommenderSystems/deepfm/tools/launch_spark.sh). - - - Modify the SPARK_LOCAL_DIRS as needed - - ```shell - export SPARK_LOCAL_DIRS=/path/to/your/spark/ - ``` - - - Run `bash launch_spark.sh` - -4. load [deepfm_parquet.scala](https://github.com/Oneflow-Inc/models/blob/dev_deepfm_multicol_oneemb/RecommenderSystems/deepfm/tools/deepfm_parquet.scala) to your spark shell by `:load deepfm_parquet.scala`. - -5. call the `makeDeepfmDataset(srcDir: String, dstDir:String)` function to generate the dataset. - - ```shell - makeDeepfmDataset("/path/to/your/src_dir", "/path/to/your/dst_dir") - ``` - - After generating parquet dataset, dataset information will also be printed. It contains the information about the number of samples and table size array, which is needed when training. - - ```txt - train samples = 36672493 - validation samples = 4584062 - test samples = 4584062 - table size array: - 649,9364,14746,490,476707,11618,4142,1373,7275,13,169,407,1376 - 1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572 - ``` - ### Start Training by Oneflow -1. 
Modify the [train_deepfm.sh](https://github.com/Oneflow-Inc/models/blob/dev_deepfm_multicol_oneemb/RecommenderSystems/deepfm/train_deepfm.sh) as needed. +1. Modify the **train_mmoe.sh** as needed. ```shell #!/bin/bash DEVICE_NUM_PER_NODE=1 - DATA_DIR=/path/to/deepfm_parquet + DATA_DIR=/path/to/mmoe_parquet PERSISTENT_PATH=/path/to/persistent MODEL_SAVE_DIR=/path/to/model/save/dir python3 -m oneflow.distributed.launch \ - --nproc_per_node $DEVICE_NUM_PER_NODE \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr 127.0.0.1 \ - deepfm_train_eval.py \ - --data_dir $DATA_DIR \ - --persistent_path $PERSISTENT_PATH \ - --table_size_array "649,9364,14746,490,476707,11618,4142,1373,7275,13,169,407,1376,1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572" \ - --store_type 'cached_host_mem' \ - --cache_memory_budget_mb 1024 \ - --batch_size 10000 \ - --train_batches 75000 \ - --loss_print_interval 100 \ - --dnn "1000,1000,1000,1000,1000" \ - --net_dropout 0.2 \ - --learning_rate 0.001 \ - --embedding_vec_size 16 \ - --num_train_samples 36672493 \ - --num_val_samples 4584062 \ - --num_test_samples 4584062 \ - --model_save_dir $MODEL_SAVE_DIR \ - --save_best_model + --nproc_per_node $DEVICE_NUM_PER_NODE \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 127.0.0.1 \ + mmoe_train_eval.py \ + --data_dir $DATA_DIR \ + --persistent_path $PERSISTENT_PATH \ + --table_size_array "9, 52, 47, 17, 3, 24, 15, 5, 10, 2, 3, 6, 8, 6, 6, 51, 38, 8, 10, 9, 10, 3, 4, 5, 43, 43, 43, 5, 3" \ + --store_type 'cached_host_mem' \ + --cache_memory_budget_mb 1024 \ + --batch_size 256 \ + --train_batches 16000 \ + --loss_print_interval 100 \ + --learning_rate 0.001 \ + --embedding_vec_size 4 \ + --expert_dnn "256, 128" \ + --num_train_samples 199523 \ + --num_test_samples 99762 \ + --model_save_dir $MODEL_SAVE_DIR ``` - -2. train a DeepFM model by `bash train_deepfm.sh`. + +2. train a MMoE model by `bash train_mmoe.sh`. 
\ No newline at end of file diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py index 3ac9f2b55..e5d405368 100644 --- a/RecommenderSystems/mmoe/mmoe_train_eval.py +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -47,7 +47,7 @@ def str_list(x): parser.add_argument( "--expert_dnn", type=int_list, default="256, 128", help="dnn hidden units number" ) - parser.add_argument("--net_dropout", type=float, default=0.2, help="net dropout rate") + parser.add_argument("--net_dropout", type=float, default=0.0, help="net dropout rate") parser.add_argument("--lr_factor", type=float, default=0.1) parser.add_argument("--min_lr", type=float, default=1.0e-6) @@ -361,6 +361,7 @@ def __init__( num_experts=3, embedding_vec_size=4, expert_dnn=[256, 128], + net_dropout=0.0, persistent_path=None, table_size_array=None, one_embedding_store_type="cached_host_mem", @@ -388,7 +389,7 @@ def __init__( hidden_units=expert_dnn[:-1], out_features=expert_dnn[-1], skip_final_activation=True, - dropout=0.0, + dropout=net_dropout, ) self.experts.append(expert_net) @@ -434,6 +435,7 @@ def make_mmoe_module(args): num_experts=args.num_experts, embedding_vec_size=args.embedding_vec_size, expert_dnn=args.expert_dnn, + net_dropout=args.net_dropout, persistent_path=args.persistent_path, table_size_array=args.table_size_array, one_embedding_store_type=args.store_type, From ac2cd7ebf7389c3256e75b2895fc6abfda61fa7f Mon Sep 17 00:00:00 2001 From: Xinman Liu Date: Thu, 9 Jun 2022 15:28:50 +0800 Subject: [PATCH 11/34] Update README.md --- RecommenderSystems/mmoe/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/RecommenderSystems/mmoe/README.md b/RecommenderSystems/mmoe/README.md index 0408424d1..241975612 100644 --- a/RecommenderSystems/mmoe/README.md +++ b/RecommenderSystems/mmoe/README.md @@ -1,10 +1,10 @@ # MMoE [Multi-gate Mixture-of-Experts (MMoE)](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) adapts the Mixture-of- Experts (MoE) structure to multi-task learning by sharing the expert submodels across all tasks, while also having a gating network trained to optimize each task. Its model structure is as follows. Based on this structure, this project uses OneFlow distributed deep learning framework to realize training the model in graph mode on the Criteo data set. -

- Screen Shot 2022-04-01 at 4 45 22 PM + mmoe

+ ## Directory description ```txt @@ -100,4 +100,4 @@ A hands-on guide to train a MMoe model. --model_save_dir $MODEL_SAVE_DIR ``` -2. train a MMoE model by `bash train_mmoe.sh`. \ No newline at end of file +2. train a MMoE model by `bash train_mmoe.sh`. From 315fe7de94bf20056deab2e4f4cd2426fb486356 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Wed, 15 Jun 2022 16:49:48 +0800 Subject: [PATCH 12/34] Change gate and tower to dnn --- RecommenderSystems/mmoe/README.md | 4 +- RecommenderSystems/mmoe/mmoe_train_eval.py | 43 +++++++++++++++++----- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/RecommenderSystems/mmoe/README.md b/RecommenderSystems/mmoe/README.md index 241975612..ffd040b97 100644 --- a/RecommenderSystems/mmoe/README.md +++ b/RecommenderSystems/mmoe/README.md @@ -30,7 +30,9 @@ | num_experts | the number of experts | 3 | | num_tasks | the number of tasks | 2 | | embedding_vec_size | embedding vector size | 16 | -| expert_dnn | expert dnn hidden units number | 256, 128 | +| expert_dnn | expert dnn hidden units number | [256, 128] | +| gate_dnn | gate dnn hidden units number | [] | +| tower_dnn | tower dnn hidden units number | [] | | net_dropout | net dropout rate | 0.0 | | learning_rate | initial learning rate | 0.001 | | batch_size | training/evaluation batch size | 256 | diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py index e5d405368..dcd0e65f7 100644 --- a/RecommenderSystems/mmoe/mmoe_train_eval.py +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -15,6 +15,8 @@ def get_args(print_args=True): def int_list(x): + if x == "": + return [] return list(map(int, x.split(","))) def str_list(x): @@ -45,7 +47,11 @@ def str_list(x): parser.add_argument("--num_tasks", type=int, default=2, help="the number of tasks") parser.add_argument("--embedding_vec_size", type=int, default=4, help="embedding vector size") parser.add_argument( - "--expert_dnn", type=int_list, default="256, 128", help="dnn hidden units number" + "--expert_dnn", type=int_list, default="256, 128", help="expert dnn hidden units number" + ) + parser.add_argument("--gate_dnn", type=int_list, default="", help="gate hidden units number") + parser.add_argument( + "--tower_dnn", type=int_list, default="", help="tower dnn hidden units number" ) parser.add_argument("--net_dropout", type=float, default=0.0, help="net dropout rate") @@ -329,15 +335,22 @@ def forward(self, ids): class DNN(nn.Module): def __init__( - self, in_features, hidden_units, out_features, skip_final_activation=False, dropout=0.0 + self, + in_features, + hidden_units, + out_features, + skip_final_activation=False, + dropout=0.0, + use_final_bias=True, ) -> None: super(DNN, self).__init__() denses = [] dropout_rates = [dropout] * len(hidden_units) + [0.0] use_relu = [True] * len(hidden_units) + [not skip_final_activation] + use_bias = [True] * len(hidden_units) + [use_final_bias] hidden_units = [in_features] + hidden_units + [out_features] for idx in range(len(hidden_units) - 1): - denses.append(nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=True)) + denses.append(nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=use_bias[idx])) if use_relu[idx]: denses.append(nn.ReLU()) if dropout_rates[idx] > 0: @@ -361,6 +374,8 @@ def __init__( num_experts=3, embedding_vec_size=4, expert_dnn=[256, 128], + gate_dnn=[], + tower_dnn=[], net_dropout=0.0, persistent_path=None, table_size_array=None, @@ -396,14 +411,24 @@ def __init__( self.gates = nn.ModuleList([]) self.towers = 
nn.ModuleList([]) for _ in range(num_tasks): - gate_net = nn.Linear( + gate_net = DNN( in_features=embedding_vec_size * num_sparse_fields + num_dense_fields, + hidden_units=gate_dnn, out_features=num_experts, - bias=False, + skip_final_activation=True, + dropout=net_dropout, + use_final_bias=False, ) self.gates.append(gate_net) - tower_net = nn.Linear(in_features=expert_dnn[-1], out_features=1,) + tower_net = DNN( + in_features=expert_dnn[-1], + hidden_units=tower_dnn, + out_features=1, + skip_final_activation=True, + dropout=net_dropout, + use_final_bias=False, + ) self.towers.append(tower_net) def forward(self, dense_inputs, sparse_inputs) -> flow.Tensor: @@ -435,6 +460,8 @@ def make_mmoe_module(args): num_experts=args.num_experts, embedding_vec_size=args.embedding_vec_size, expert_dnn=args.expert_dnn, + gate_dnn=args.gate_dnn, + tower_dnn=args.tower_dnn, net_dropout=args.net_dropout, persistent_path=args.persistent_path, table_size_array=args.table_size_array, @@ -537,9 +564,7 @@ def save_model(subdir): ) eval_graph = MmoeValGraph(mmoe_module, args.amp) - train_graph = MmoeTrainGraph( - mmoe_module, loss, opt, grad_scaler, args.amp, lr_scheduler=lr_scheduler - ) + train_graph = MmoeTrainGraph(mmoe_module, loss, opt, grad_scaler, args.amp, lr_scheduler=None) batches_per_epoch = math.ceil(args.num_train_samples / args.batch_size) From 4c82e90e716c05cdfea6ee47c5c293965c9e917b Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Thu, 16 Jun 2022 13:10:11 +0800 Subject: [PATCH 13/34] add oneembedding key_type --- RecommenderSystems/dlrm/dlrm_train_eval.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/RecommenderSystems/dlrm/dlrm_train_eval.py b/RecommenderSystems/dlrm/dlrm_train_eval.py index 1f5bff94a..e70b0a33c 100644 --- a/RecommenderSystems/dlrm/dlrm_train_eval.py +++ b/RecommenderSystems/dlrm/dlrm_train_eval.py @@ -41,6 +41,7 @@ def str_list(x): parser.add_argument("--disable_fusedmlp", action="store_true", help="disable fused MLP or not") parser.add_argument("--embedding_vec_size", type=int, default=128) + parser.add_argument("--one_embedding_key_type", type=str, default="int64", help="OneEmbedding key type: int32, int64") parser.add_argument("--bottom_mlp", type=int_list, default="512,256,128") parser.add_argument("--top_mlp", type=int_list, default="1024,1024,512,256") parser.add_argument( @@ -266,6 +267,7 @@ def __init__( table_size_array, store_type, cache_memory_budget_mb, + key_type, ): assert table_size_array is not None vocab_size = sum(table_size_array) @@ -303,7 +305,7 @@ def __init__( "sparse_embedding", embedding_dim=embedding_vec_size, dtype=flow.float, - key_type=flow.int64, + key_type=getattr(flow, key_type), tables=tables, store_options=store_options, ) @@ -322,6 +324,7 @@ def __init__( persistent_path=None, table_size_array=None, one_embedding_store_type="cached_host_mem", + one_embedding_key_type="int64", cache_memory_budget_mb=8192, interaction_itself=True, interaction_padding=True, @@ -337,6 +340,7 @@ def __init__( table_size_array, one_embedding_store_type, cache_memory_budget_mb, + one_embedding_key_type, ) self.interaction = Interaction( bottom_mlp[-1], @@ -368,6 +372,7 @@ def make_dlrm_module(args): persistent_path=args.persistent_path, table_size_array=args.table_size_array, one_embedding_store_type=args.store_type, + one_embedding_key_type=args.one_embedding_key_type, cache_memory_budget_mb=args.cache_memory_budget_mb, interaction_itself=args.interaction_itself, interaction_padding=not args.disable_interaction_padding, From 
f44b50fc0406df16335830dca6151db6c52f8de3 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Thu, 16 Jun 2022 13:50:36 +0800 Subject: [PATCH 14/34] pad dense input --- RecommenderSystems/dlrm/dlrm_train_eval.py | 27 ++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/RecommenderSystems/dlrm/dlrm_train_eval.py b/RecommenderSystems/dlrm/dlrm_train_eval.py index e70b0a33c..0be0aaf9c 100644 --- a/RecommenderSystems/dlrm/dlrm_train_eval.py +++ b/RecommenderSystems/dlrm/dlrm_train_eval.py @@ -41,7 +41,12 @@ def str_list(x): parser.add_argument("--disable_fusedmlp", action="store_true", help="disable fused MLP or not") parser.add_argument("--embedding_vec_size", type=int, default=128) - parser.add_argument("--one_embedding_key_type", type=str, default="int64", help="OneEmbedding key type: int32, int64") + parser.add_argument( + "--one_embedding_key_type", + type=str, + default="int64", + help="OneEmbedding key type: int32, int64", + ) parser.add_argument("--bottom_mlp", type=int_list, default="512,256,128") parser.add_argument("--top_mlp", type=int_list, default="1024,1024,512,256") parser.add_argument( @@ -49,6 +54,11 @@ def str_list(x): action="store_true", help="disable interaction padding or not", ) + parser.add_argument( + "--disable_dense_input_padding", + action="store_true", + help="disable dense input padding or not", + ) parser.add_argument( "--interaction_itself", action="store_true", help="interaction itself or not" ) @@ -328,12 +338,22 @@ def __init__( cache_memory_budget_mb=8192, interaction_itself=True, interaction_padding=True, + dense_input_padding=True, ): super(DLRMModule, self).__init__() assert ( embedding_vec_size == bottom_mlp[-1] ), "Embedding vector size must equle to bottom MLP output size" - self.bottom_mlp = MLP(num_dense_fields, bottom_mlp, fused=use_fusedmlp) + self.num_dense_fields = ( + ((num_dense_fields + 8 - 1) // 8 * 8) if dense_input_padding else num_dense_fields + ) + self.pad = ( + [0, self.num_dense_fields - num_dense_fields] + if self.num_dense_fields > num_dense_fields + else None + ) + + self.bottom_mlp = MLP(self.num_dense_fields, bottom_mlp, fused=use_fusedmlp) self.embedding = OneEmbedding( embedding_vec_size, persistent_path, @@ -356,6 +376,8 @@ def __init__( ) def forward(self, dense_fields, sparse_fields) -> flow.Tensor: + if self.pad: + dense_fields = flow.nn.functional.pad(dense_fields, self.pad, "constant") dense_fields = flow.log(dense_fields + 1.0) dense_fields = self.bottom_mlp(dense_fields) embedding = self.embedding(sparse_fields) @@ -376,6 +398,7 @@ def make_dlrm_module(args): cache_memory_budget_mb=args.cache_memory_budget_mb, interaction_itself=args.interaction_itself, interaction_padding=not args.disable_interaction_padding, + dense_input_padding=not args.disable_dense_input_padding, ) return model From f2dbdf1f80b0d5aecb2def3ef7fccb8c08f89465 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Thu, 16 Jun 2022 14:06:55 +0800 Subject: [PATCH 15/34] add padding in prefetch --- .../dlrm/dlrm_prefetch_train.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/RecommenderSystems/dlrm/dlrm_prefetch_train.py b/RecommenderSystems/dlrm/dlrm_prefetch_train.py index 88ef64963..3f8f63ae5 100644 --- a/RecommenderSystems/dlrm/dlrm_prefetch_train.py +++ b/RecommenderSystems/dlrm/dlrm_prefetch_train.py @@ -41,6 +41,12 @@ def str_list(x): parser.add_argument("--disable_fusedmlp", action="store_true", help="disable fused MLP or not") parser.add_argument("--embedding_vec_size", type=int, 
default=128) + parser.add_argument( + "--one_embedding_key_type", + type=str, + default="int64", + help="OneEmbedding key type: int32, int64", + ) parser.add_argument("--bottom_mlp", type=int_list, default="512,256,128") parser.add_argument("--top_mlp", type=int_list, default="1024,1024,512,256") parser.add_argument( @@ -48,6 +54,11 @@ def str_list(x): action="store_true", help="disable interaction padding or not", ) + parser.add_argument( + "--disable_dense_input_padding", + action="store_true", + help="disable dense input padding or not", + ) parser.add_argument( "--interaction_itself", action="store_true", help="interaction itself or not" ) @@ -266,6 +277,7 @@ def __init__( table_size_array, store_type, cache_memory_budget_mb, + key_type, ): assert table_size_array is not None vocab_size = sum(table_size_array) @@ -303,7 +315,7 @@ def __init__( "sparse_embedding", embedding_dim=embedding_vec_size, dtype=flow.float, - key_type=flow.int64, + key_type=getattr(flow, key_type), tables=tables, store_options=store_options, ) @@ -322,21 +334,33 @@ def __init__( persistent_path=None, table_size_array=None, one_embedding_store_type="cached_host_mem", + one_embedding_key_type="int64", cache_memory_budget_mb=8192, interaction_itself=True, interaction_padding=True, + dense_input_padding=True, ): super(DLRMModule, self).__init__() assert ( embedding_vec_size == bottom_mlp[-1] ), "Embedding vector size must equle to bottom MLP output size" - self.bottom_mlp = MLP(num_dense_fields, bottom_mlp, fused=use_fusedmlp) + self.num_dense_fields = ( + ((num_dense_fields + 8 - 1) // 8 * 8) if dense_input_padding else num_dense_fields + ) + self.pad = ( + [0, self.num_dense_fields - num_dense_fields] + if self.num_dense_fields > num_dense_fields + else None + ) + + self.bottom_mlp = MLP(self.num_dense_fields, bottom_mlp, fused=use_fusedmlp) self.embedding = OneEmbedding( embedding_vec_size, persistent_path, table_size_array, one_embedding_store_type, cache_memory_budget_mb, + one_embedding_key_type, ) self.interaction = Interaction( bottom_mlp[-1], @@ -352,6 +376,8 @@ def __init__( ) def forward(self, dense_fields, sparse_fields) -> flow.Tensor: + if self.pad: + dense_fields = flow.nn.functional.pad(dense_fields, self.pad, "constant") dense_fields = flow.log(dense_fields + 1.0) dense_fields = self.bottom_mlp(dense_fields) embedding = self.embedding(sparse_fields) @@ -368,9 +394,11 @@ def make_dlrm_module(args): persistent_path=args.persistent_path, table_size_array=args.table_size_array, one_embedding_store_type=args.store_type, + one_embedding_key_type=args.one_embedding_key_type, cache_memory_budget_mb=args.cache_memory_budget_mb, interaction_itself=args.interaction_itself, interaction_padding=not args.disable_interaction_padding, + dense_input_padding=not args.disable_dense_input_padding, ) return model @@ -622,7 +650,6 @@ def eval(cached_eval_batches, eval_graph, cur_step=0): if __name__ == "__main__": os.system(sys.executable + " -m oneflow --doctor") - os.system("env") flow.boxing.nccl.enable_all_to_all(True) args = get_args() From cc59c3d25ebe9deed5409895dba5227f56c78fe7 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Thu, 16 Jun 2022 14:07:55 +0800 Subject: [PATCH 16/34] add sh --- RecommenderSystems/dlrm/kaggle_nsys.sh | 32 +++++++++++++++++ RecommenderSystems/dlrm/kaggle_train.sh | 46 +++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100755 RecommenderSystems/dlrm/kaggle_nsys.sh create mode 100755 RecommenderSystems/dlrm/kaggle_train.sh diff --git 
a/RecommenderSystems/dlrm/kaggle_nsys.sh b/RecommenderSystems/dlrm/kaggle_nsys.sh new file mode 100755 index 000000000..4300ad3cf --- /dev/null +++ b/RecommenderSystems/dlrm/kaggle_nsys.sh @@ -0,0 +1,32 @@ +persistent=./persistent +rm -rf $persistent/* + +export CUDA_VISIBLE_DEVICES=1 +export ONEFLOW_FUSE_MODEL_UPDATE_CAST=1 +export ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE=1 +export ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1 +export ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE=1 +#export ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 +export ONEFLOW_ONE_EMBEDDING_USE_SYSTEM_GATHER=0 +#export ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM=1 + +column_size_array='1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572' + +/usr/local/cuda-11.6/bin/nsys profile --stats=true -o of24_1gpu_bsz6912_no-sys-gather_pad-dense-input \ +python3 -m oneflow.distributed.launch \ + --nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 127.0.0.1 \ + dlrm_prefetch_train.py \ + --data_dir /data/criteo_kaggle/dlrm_parquet_int32 \ + --persistent_path $persistent \ + --store_type device_mem \ + --train_batches 300 \ + --train_batch_size 6912 \ + --learning_rate 3 \ + --table_size_array $column_size_array \ + --one_embedding_key_type int32 \ + --disable_dense_input_padding \ + --amp + #--train_batches 300 \ diff --git a/RecommenderSystems/dlrm/kaggle_train.sh b/RecommenderSystems/dlrm/kaggle_train.sh new file mode 100755 index 000000000..c32cf4c5d --- /dev/null +++ b/RecommenderSystems/dlrm/kaggle_train.sh @@ -0,0 +1,46 @@ +persistent=./persistent +rm -rf $persistent/* + +export CUDA_VISIBLE_DEVICES=1 +export ONEFLOW_FUSE_MODEL_UPDATE_CAST=1 +export ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE=1 +export ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1 +export ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE=1 +#export ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 +export ONEFLOW_ONE_EMBEDDING_USE_SYSTEM_GATHER=0 +#export ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM=1 + +column_size_array='1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572' + +python3 -m oneflow.distributed.launch \ + --nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 127.0.0.1 \ + dlrm_prefetch_train.py \ + --data_dir /data/criteo_kaggle/dlrm_parquet_int32 \ + --persistent_path $persistent \ + --store_type device_mem \ + --train_batches 10000 \ + --train_batch_size 6912 \ + --learning_rate 3 \ + --table_size_array $column_size_array \ + --one_embedding_key_type int32 \ + --amp + +python3 -m oneflow.distributed.launch \ + --nproc_per_node 1 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 127.0.0.1 \ + dlrm_prefetch_train.py \ + --data_dir /data/criteo_kaggle/dlrm_parquet_int32 \ + --persistent_path $persistent \ + --store_type device_mem \ + --train_batches 10000 \ + --train_batch_size 6912 \ + --learning_rate 3 \ + --table_size_array $column_size_array \ + --one_embedding_key_type int32 \ + --disable_dense_input_padding \ + --amp From a5e25c9e80bda89c417f70e22dff1114ec8dd075 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Fri, 17 Jun 2022 15:53:49 +0800 Subject: [PATCH 17/34] udpate --- RecommenderSystems/dlrm/kaggle_nsys.sh | 9 +++++---- RecommenderSystems/dlrm/kaggle_train.sh | 18 +----------------- 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/RecommenderSystems/dlrm/kaggle_nsys.sh b/RecommenderSystems/dlrm/kaggle_nsys.sh index 
4300ad3cf..147668e6d 100755 --- a/RecommenderSystems/dlrm/kaggle_nsys.sh +++ b/RecommenderSystems/dlrm/kaggle_nsys.sh @@ -1,7 +1,9 @@ +prefix=${1:-of24_1gpu_bsz6912} + persistent=./persistent -rm -rf $persistent/* +rm -rf ${prefix}.* $persistent/* -export CUDA_VISIBLE_DEVICES=1 +#export CUDA_VISIBLE_DEVICES=1 export ONEFLOW_FUSE_MODEL_UPDATE_CAST=1 export ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE=1 export ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1 @@ -12,7 +14,7 @@ export ONEFLOW_ONE_EMBEDDING_USE_SYSTEM_GATHER=0 column_size_array='1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572' -/usr/local/cuda-11.6/bin/nsys profile --stats=true -o of24_1gpu_bsz6912_no-sys-gather_pad-dense-input \ +/usr/local/cuda-11.6/bin/nsys profile --stats=true -o $prefix \ python3 -m oneflow.distributed.launch \ --nproc_per_node 1 \ --nnodes 1 \ @@ -27,6 +29,5 @@ python3 -m oneflow.distributed.launch \ --learning_rate 3 \ --table_size_array $column_size_array \ --one_embedding_key_type int32 \ - --disable_dense_input_padding \ --amp #--train_batches 300 \ diff --git a/RecommenderSystems/dlrm/kaggle_train.sh b/RecommenderSystems/dlrm/kaggle_train.sh index c32cf4c5d..bb88353e7 100755 --- a/RecommenderSystems/dlrm/kaggle_train.sh +++ b/RecommenderSystems/dlrm/kaggle_train.sh @@ -1,7 +1,7 @@ persistent=./persistent rm -rf $persistent/* -export CUDA_VISIBLE_DEVICES=1 +#export CUDA_VISIBLE_DEVICES=1 export ONEFLOW_FUSE_MODEL_UPDATE_CAST=1 export ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE=1 export ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1 @@ -28,19 +28,3 @@ python3 -m oneflow.distributed.launch \ --one_embedding_key_type int32 \ --amp -python3 -m oneflow.distributed.launch \ - --nproc_per_node 1 \ - --nnodes 1 \ - --node_rank 0 \ - --master_addr 127.0.0.1 \ - dlrm_prefetch_train.py \ - --data_dir /data/criteo_kaggle/dlrm_parquet_int32 \ - --persistent_path $persistent \ - --store_type device_mem \ - --train_batches 10000 \ - --train_batch_size 6912 \ - --learning_rate 3 \ - --table_size_array $column_size_array \ - --one_embedding_key_type int32 \ - --disable_dense_input_padding \ - --amp From 7126d88902fc7031b82269615751e3a88511d14b Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Thu, 23 Jun 2022 16:05:22 +0800 Subject: [PATCH 18/34] update --- RecommenderSystems/dlrm/dlrm_prefetch_train.py | 4 ++-- RecommenderSystems/dlrm/dlrm_train_eval.py | 2 ++ RecommenderSystems/dlrm/kaggle_nsys.sh | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/RecommenderSystems/dlrm/dlrm_prefetch_train.py b/RecommenderSystems/dlrm/dlrm_prefetch_train.py index 3f8f63ae5..41b31f653 100644 --- a/RecommenderSystems/dlrm/dlrm_prefetch_train.py +++ b/RecommenderSystems/dlrm/dlrm_prefetch_train.py @@ -376,9 +376,9 @@ def __init__( ) def forward(self, dense_fields, sparse_fields) -> flow.Tensor: + dense_fields = flow.log(dense_fields + 1.0) if self.pad: dense_fields = flow.nn.functional.pad(dense_fields, self.pad, "constant") - dense_fields = flow.log(dense_fields + 1.0) dense_fields = self.bottom_mlp(dense_fields) embedding = self.embedding(sparse_fields) features = self.interaction(dense_fields, embedding) @@ -597,9 +597,9 @@ def np_to_global(np): def batch_to_global(np_label, np_dense, np_sparse, is_train=True): - labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) dense_fields = np_to_global(np_dense) sparse_fields = np_to_global(np_sparse) + labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) 
return labels, dense_fields, sparse_fields diff --git a/RecommenderSystems/dlrm/dlrm_train_eval.py b/RecommenderSystems/dlrm/dlrm_train_eval.py index 0be0aaf9c..6ac8376a7 100644 --- a/RecommenderSystems/dlrm/dlrm_train_eval.py +++ b/RecommenderSystems/dlrm/dlrm_train_eval.py @@ -542,6 +542,8 @@ def save_model(subdir): f"Rank[{rank}], Step {step}, Loss {loss:0.4f}, Latency " + f"{(latency * 1000):0.3f} ms, Throughput {throughput:0.1f}, {strtime}" ) + if np.isnan(loss): + exit(1) if args.eval_interval > 0 and step % args.eval_interval == 0: auc = eval(cached_eval_batches, eval_graph, step) diff --git a/RecommenderSystems/dlrm/kaggle_nsys.sh b/RecommenderSystems/dlrm/kaggle_nsys.sh index 147668e6d..50a845fa2 100755 --- a/RecommenderSystems/dlrm/kaggle_nsys.sh +++ b/RecommenderSystems/dlrm/kaggle_nsys.sh @@ -11,6 +11,7 @@ export ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE=1 #export ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 export ONEFLOW_ONE_EMBEDDING_USE_SYSTEM_GATHER=0 #export ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM=1 +export ONEFLOW_PROFILER_KERNEL_PROFILE_KERNEL_FORWARD_RANGE=1 column_size_array='1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572' From d7d34796c9411673ee9e2de6baf3c8d33195f7b5 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Thu, 23 Jun 2022 17:02:54 +0800 Subject: [PATCH 19/34] fix typo in mmoe_parquet.py; remove used import --- RecommenderSystems/mmoe/tools/mmoe_parquet.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/RecommenderSystems/mmoe/tools/mmoe_parquet.py b/RecommenderSystems/mmoe/tools/mmoe_parquet.py index 57582fa1b..b168e6347 100644 --- a/RecommenderSystems/mmoe/tools/mmoe_parquet.py +++ b/RecommenderSystems/mmoe/tools/mmoe_parquet.py @@ -15,14 +15,12 @@ import argparse import pandas as pd -from sklearn.metrics import roc_auc_score -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import LabelEncoder, MinMaxScaler +from sklearn.preprocessing import MinMaxScaler from pyspark.sql import SparkSession from pyspark.conf import SparkConf from pyspark.sql.functions import rand, udf, lit, xxhash64 -from pyspark.sql.types import FloatType, LongType +from pyspark.sql.types import FloatType column_names = ['age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour', 'hs_college', 'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', @@ -104,7 +102,7 @@ def make_mmoe_parquet( args = parser.parse_args() test_csv = os.path.join(args.input_dir, "census-income.test") - train_csv = os.path.join(args.input_dir, "census-income.sample") + train_csv = os.path.join(args.input_dir, "census-income.data") # start spark session conf = SparkConf() From cbbcea1fa792b533b008c48a00015032c9be5bef Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Wed, 29 Jun 2022 10:22:05 +0800 Subject: [PATCH 20/34] eval steps --- RecommenderSystems/dlrm/dlrm_train_eval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/RecommenderSystems/dlrm/dlrm_train_eval.py b/RecommenderSystems/dlrm/dlrm_train_eval.py index 376f5532d..c6b558f19 100644 --- a/RecommenderSystems/dlrm/dlrm_train_eval.py +++ b/RecommenderSystems/dlrm/dlrm_train_eval.py @@ -74,6 +74,7 @@ def str_list(x): parser.add_argument("--eval_batches", type=int, default=1612, help="number of eval batches") parser.add_argument("--eval_batch_size", type=int, default=55296) 
parser.add_argument("--eval_interval", type=int, default=10000) + parser.add_argument("--eval_steps", type=int_list, default="58000,59000") parser.add_argument("--train_batch_size", type=int, default=55296) parser.add_argument("--learning_rate", type=float, default=24) parser.add_argument("--warmup_batches", type=int, default=2750) @@ -545,7 +546,7 @@ def save_model(subdir): if np.isnan(loss): exit(1) - if args.eval_interval > 0 and step % args.eval_interval == 0: + if (args.eval_interval > 0 and step % args.eval_interval == 0) or (step in args.eval_steps): auc = eval(cached_eval_batches, eval_graph, step) if args.save_model_after_each_eval: save_model(f"step_{step}_val_auc_{auc:0.5f}") From b8896dc67abbc2cc3604352168730531f0776553 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Wed, 29 Jun 2022 12:08:48 +0800 Subject: [PATCH 21/34] nsys 4gpus --- RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh | 32 +++++++++++++++++++ .../dlrm/dlrm_prefetch_train.py | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100755 RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh diff --git a/RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh b/RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh new file mode 100755 index 000000000..c70e14442 --- /dev/null +++ b/RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh @@ -0,0 +1,32 @@ +prefix=${1:-of24_1gpu_bsz6912} + +persistent=./persistent +rm -rf ${prefix}.* $persistent/* + +#export CUDA_VISIBLE_DEVICES=1 +export ONEFLOW_FUSE_MODEL_UPDATE_CAST=1 +export ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE=1 +export ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1 +export ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE=1 +#export ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=1 +export ONEFLOW_ONE_EMBEDDING_USE_SYSTEM_GATHER=0 +#export ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM=1 +export ONEFLOW_PROFILER_KERNEL_PROFILE_KERNEL_FORWARD_RANGE=1 + + +/usr/local/cuda-11.6/bin/nsys profile --stats=true -o $prefix \ +python3 -m oneflow.distributed.launch \ + --nproc_per_node 4 \ + --nnodes 1 \ + --node_rank 0 \ + --master_addr 127.0.0.1 \ + dlrm_prefetch_train.py \ + --data_dir /RAID0/xiexuan/dlrm_parquet_int32 \ + --persistent_path $persistent \ + --store_type device_mem \ + --train_batches 300 \ + --train_batch_size 27648 \ + --learning_rate 3 \ + --one_embedding_key_type int32 \ + --amp + #--train_batches 300 \ diff --git a/RecommenderSystems/dlrm/dlrm_prefetch_train.py b/RecommenderSystems/dlrm/dlrm_prefetch_train.py index 41b31f653..56b650066 100644 --- a/RecommenderSystems/dlrm/dlrm_prefetch_train.py +++ b/RecommenderSystems/dlrm/dlrm_prefetch_train.py @@ -429,7 +429,7 @@ def make_lr_scheduler(args, optimizer): optimizer, start_factor=0, total_iters=args.warmup_batches, ) poly_decay_lr = flow.optim.lr_scheduler.PolynomialLR( - optimizer, steps=args.decay_batches, end_learning_rate=0, power=2.0, cycle=False, + optimizer, decay_batch=args.decay_batches, end_learning_rate=0, power=2.0, cycle=False, ) sequential_lr = flow.optim.lr_scheduler.SequentialLR( optimizer=optimizer, From 49fae88f3cdfc014b8bc9208cc819f07d3142f36 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Wed, 29 Jun 2022 12:11:11 +0800 Subject: [PATCH 22/34] update default file name --- RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh b/RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh index c70e14442..560cacbdf 100755 --- a/RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh +++ b/RecommenderSystems/dlrm/criteo1t_nsys_4gpu.sh @@ 
-1,4 +1,4 @@ -prefix=${1:-of24_1gpu_bsz6912} +prefix=${1:-4gpu_bsz27648} persistent=./persistent rm -rf ${prefix}.* $persistent/* From 9e0deeefc9feedbde0b856488b525d977368c243 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Wed, 29 Jun 2022 15:58:32 +0800 Subject: [PATCH 23/34] Update README.md (dataset); Update mmoe_train_eval.py to deal with empty str args; --- RecommenderSystems/mmoe/README.md | 22 +++++++++++++++++++++- RecommenderSystems/mmoe/mmoe_train_eval.py | 2 ++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/RecommenderSystems/mmoe/README.md b/RecommenderSystems/mmoe/README.md index ffd040b97..1572fbf61 100644 --- a/RecommenderSystems/mmoe/README.md +++ b/RecommenderSystems/mmoe/README.md @@ -1,10 +1,12 @@ # MMoE [Multi-gate Mixture-of-Experts (MMoE)](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007) adapts the Mixture-of- Experts (MoE) structure to multi-task learning by sharing the expert submodels across all tasks, while also having a gating network trained to optimize each task. Its model structure is as follows. Based on this structure, this project uses OneFlow distributed deep learning framework to realize training the model in graph mode on the Criteo data set. +
<p align="center">
+  <img alt="mmoe" src="...">
+</p>
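To make the gating idea above concrete, below is a minimal NumPy sketch of the mixture computation. The feature dimension, expert size, and variable names are made up for illustration; this is only a simplification of the OneFlow module trained by `mmoe_train_eval.py` (no tower networks or training loop).

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

batch, in_dim, expert_dim, num_experts, num_tasks = 4, 16, 8, 3, 2
rng = np.random.default_rng(0)
x = rng.normal(size=(batch, in_dim))  # stands in for the concatenated dense + embedded sparse features

# shared experts: every task sees the same expert outputs
experts = [rng.normal(size=(in_dim, expert_dim)) for _ in range(num_experts)]
expert_out = np.stack([x @ w for w in experts], axis=1)      # (batch, num_experts, expert_dim)

# one gate per task: softmax over experts, then a weighted sum of the expert outputs
for task in range(num_tasks):
    gate_w = rng.normal(size=(in_dim, num_experts))
    weights = softmax(x @ gate_w)                            # (batch, num_experts)
    mixed = (weights[:, :, None] * expert_out).sum(axis=1)   # (batch, expert_dim)
    print(f"task {task}: tower input shape {mixed.shape}")
```

Each task would then feed its mixed representation into its own tower network to produce the task-specific prediction.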
+ ## Directory description ```txt @@ -69,6 +71,24 @@ A hands-on guide to train a MMoe model. ### Dataset +1. Download the [Census-Income dataset](https://archive.ics.uci.edu/ml/machine-learning-databases/census-income-mld/census.tar.gz). We directly treat the 199,523 samples in the census-income.data as training examples and the 99,762 samples in the census-income.test as test examples. + + ```shell + wget https://archive.ics.uci.edu/ml/machine-learning-databases/census-income-mld/census.tar.gz + ``` + +2. Download spark from https://spark.apache.org/downloads.html and then uncompress the tar file into the directory where you want to install Spark. Ensure the `SPARK_HOME` environment variable points to the directory where the spark is. + +3. Run ./tools/mmoe_parquet.py to generate the dataset. + + ```shell + python3 ./tools/mmoe_parquet.py \ + --input_dir /path/to/census_income \ + --output_dir /path/to/mmoe_parquet \ + --spark_tmp_dir /path/to/spark_tmp_dir \ + --export_dataset_info + ``` + ### Start Training by Oneflow 1. Modify the **train_mmoe.sh** as needed. @@ -101,5 +121,5 @@ A hands-on guide to train a MMoe model. --num_test_samples 99762 \ --model_save_dir $MODEL_SAVE_DIR ``` - + 2. train a MMoE model by `bash train_mmoe.sh`. diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py index dcd0e65f7..2ff2c3727 100644 --- a/RecommenderSystems/mmoe/mmoe_train_eval.py +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -20,6 +20,8 @@ def int_list(x): return list(map(int, x.split(","))) def str_list(x): + if x == "": + return [] return list(map(str, x.split(","))) parser = argparse.ArgumentParser() From cb4e51a715aeb9cad21dd39d77c187a8bfc320c5 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Thu, 30 Jun 2022 17:02:54 +0800 Subject: [PATCH 24/34] Remove sklearn and pandas dependency in mmoe_parquet.py --- RecommenderSystems/mmoe/tools/mmoe_parquet.py | 57 ++++++++++++------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/RecommenderSystems/mmoe/tools/mmoe_parquet.py b/RecommenderSystems/mmoe/tools/mmoe_parquet.py index b168e6347..39f171e23 100644 --- a/RecommenderSystems/mmoe/tools/mmoe_parquet.py +++ b/RecommenderSystems/mmoe/tools/mmoe_parquet.py @@ -15,12 +15,15 @@ import argparse import pandas as pd -from sklearn.preprocessing import MinMaxScaler +# from sklearn.preprocessing import MinMaxScaler from pyspark.sql import SparkSession from pyspark.conf import SparkConf -from pyspark.sql.functions import rand, udf, lit, xxhash64 +from pyspark.sql.functions import rand, udf, lit, xxhash64, col from pyspark.sql.types import FloatType +from pyspark.ml import Pipeline +from pyspark.ml.feature import MinMaxScaler +from pyspark.ml.linalg import VectorAssembler column_names = ['age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour', 'hs_college', 'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', @@ -40,28 +43,40 @@ def make_mmoe_parquet( spark, input_files, output_dir, part_num=None, shuffle=False ): + data = spark.read.format("csv").option("header","false").load(input_files).toDF('age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour', 'hs_college', + 'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', + 'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses', 'stock_dividends', + 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', + 
'instance_weight', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt', + 'num_emp', 'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship', + 'own_or_self', 'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k') - data = pd.read_csv(input_files, header=None, names=column_names) + # transform label + data.withColumn("label_income", (col("income_50k")==" 50000+.").cast("int")) + data.withColumn("label_marital", (col("marital_stat")==" Never married").cast("int")) + data.drop(col("income_50k")) + data.drop(col("marital_stat")) - data['label_income'] = data['income_50k'].map({' - 50000.': 0, ' 50000+.': 1}) - data['label_marital'] = data['marital_stat'].apply(lambda x: 1 if x == ' Never married' else 0) - data.drop(labels=['income_50k', 'marital_stat'], axis=1, inplace=True) - - columns = data.columns.values.tolist() + columns = data.columns dense_features = [col for col in columns if col not in sparse_features and col not in ['label_income', 'label_marital']] - data[sparse_features] = data[sparse_features].fillna('-1', ) - data[dense_features] = data[dense_features].fillna(0, ) - mms = MinMaxScaler(feature_range=(0, 1)) - data[dense_features] = mms.fit_transform(data[dense_features]) + # deal with na value + data[sparse_features] = data[sparse_features].fillna('-1') + data[dense_features] = data[dense_features].fillna(0) + + # scale dense features + assemblers = [VectorAssembler(inputCols=[col], outputCol=col + "_vec") for col in dense_features] + scalers = [MinMaxScaler(inputCol=col + "_vec", outputCol=col + "_scaled") for col in dense_features] + pipeline = Pipeline(stages=assemblers + scalers) + scalerModel = pipeline.fit(data) + data = scalerModel.transform(data) + + dense_names = {x + "_scaled": x for x in dense_features} + data = data.select([f.col(c).alias(dense_names[c]) for c in dense_names.keys()] + sparse_features + ["label_income", "label_marital"]) start = time.time() - - df = spark.createDataFrame(data) - columns_new = dense_features + sparse_features + ["label_income", "label_marital"] - df = df.select(columns_new) make_label = udf(lambda s: float(s), FloatType()) label_cols = [make_label(field).alias(field) for field in ["label_income", "label_marital"]] @@ -72,17 +87,17 @@ def make_mmoe_parquet( dense_cols = [make_dense(field).alias(field) for field in dense_features] cols = dense_cols + sparse_cols + label_cols - df = df.select(cols) + data = data.select(cols) if shuffle: - df = df.orderBy(rand()) + data = data.orderBy(rand()) if part_num: - df = df.repartition(part_num) + data = data.repartition(part_num) - df.write.mode("overwrite").parquet(output_dir) + data.write.mode("overwrite").parquet(output_dir) num_examples = spark.read.parquet(output_dir).count() print(output_dir, num_examples, f"time elapsed: {time.time()-start:0.1f}") - return num_examples, columns_new + return num_examples if __name__ == "__main__": From 30e65b818e0c9e426de7d55d877d575efbeb0ef1 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Thu, 30 Jun 2022 18:06:18 +0800 Subject: [PATCH 25/34] Fix bugs in mmoe_parquet.py --- RecommenderSystems/mmoe/tools/mmoe_parquet.py | 56 +++++++++---------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/RecommenderSystems/mmoe/tools/mmoe_parquet.py b/RecommenderSystems/mmoe/tools/mmoe_parquet.py index 39f171e23..d94f191e1 100644 --- a/RecommenderSystems/mmoe/tools/mmoe_parquet.py +++ b/RecommenderSystems/mmoe/tools/mmoe_parquet.py @@ -14,16 +14,12 @@ import time import argparse -import 
pandas as pd -# from sklearn.preprocessing import MinMaxScaler - from pyspark.sql import SparkSession from pyspark.conf import SparkConf from pyspark.sql.functions import rand, udf, lit, xxhash64, col from pyspark.sql.types import FloatType from pyspark.ml import Pipeline -from pyspark.ml.feature import MinMaxScaler -from pyspark.ml.linalg import VectorAssembler +from pyspark.ml.feature import MinMaxScaler, VectorAssembler column_names = ['age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour', 'hs_college', 'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', @@ -43,6 +39,8 @@ def make_mmoe_parquet( spark, input_files, output_dir, part_num=None, shuffle=False ): + start = time.time() + data = spark.read.format("csv").option("header","false").load(input_files).toDF('age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour', 'hs_college', 'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', 'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses', 'stock_dividends', @@ -52,19 +50,27 @@ def make_mmoe_parquet( 'own_or_self', 'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k') # transform label - data.withColumn("label_income", (col("income_50k")==" 50000+.").cast("int")) - data.withColumn("label_marital", (col("marital_stat")==" Never married").cast("int")) - data.drop(col("income_50k")) - data.drop(col("marital_stat")) + data = data.withColumn("label_income", (col("income_50k")==" 50000+.").cast("int")).drop(col("income_50k")) + data = data.withColumn("label_marital", (col("marital_stat")==" Never married").cast("int")).drop(col("marital_stat")) + # transform dense, sparse, label columns = data.columns - dense_features = [col for col in columns if - col not in sparse_features and col not in ['label_income', 'label_marital']] + dense_features = [col_ for col_ in columns if + col_ not in sparse_features and col_ not in ['label_income', 'label_marital']] + + data.na.fill(value=0,subset=dense_features) + data.na.fill(value='-1',subset=sparse_features) + + make_dense = udf(lambda s: float(s), FloatType()) + dense_cols = [make_dense(field).alias(field) for field in dense_features] + + make_label = udf(lambda s: float(s), FloatType()) + label_cols = [make_label(field).alias(field) for field in ["label_income", "label_marital"]] + + sparse_cols = [xxhash64(field, lit(i)).alias(field) for i, field in enumerate(sparse_features)] - # deal with na value - data[sparse_features] = data[sparse_features].fillna('-1') - data[dense_features] = data[dense_features].fillna(0) + data = data.select(dense_cols + sparse_cols + label_cols) # scale dense features assemblers = [VectorAssembler(inputCols=[col], outputCol=col + "_vec") for col in dense_features] @@ -73,21 +79,9 @@ def make_mmoe_parquet( scalerModel = pipeline.fit(data) data = scalerModel.transform(data) - dense_names = {x + "_scaled": x for x in dense_features} - data = data.select([f.col(c).alias(dense_names[c]) for c in dense_names.keys()] + sparse_features + ["label_income", "label_marital"]) - - start = time.time() - - make_label = udf(lambda s: float(s), FloatType()) - label_cols = [make_label(field).alias(field) for field in ["label_income", "label_marital"]] - - sparse_cols = [xxhash64(field, lit(i)).alias(field) for i, field in enumerate(sparse_features)] - - make_dense = udf(lambda s: float(s), FloatType()) - dense_cols = [make_dense(field).alias(field) for field in 
dense_features] - - cols = dense_cols + sparse_cols + label_cols - data = data.select(cols) + scaled_dense_names = {x + "_scaled": x for x in dense_features} + vec_to_float = udf(lambda v:float(v[0]),FloatType()) + data = data.select([vec_to_float(c).alias(scaled_dense_names[c]) for c in scaled_dense_names.keys()] + sparse_features + ["label_income", "label_marital"]) if shuffle: data = data.orderBy(rand()) @@ -127,13 +121,13 @@ def make_mmoe_parquet( # create test dataset test_output_dir = os.path.join(args.output_dir, "test") - test_count, _ = make_mmoe_parquet( + test_count = make_mmoe_parquet( spark, test_csv, test_output_dir, part_num=32 ) # create train dataset train_output_dir = os.path.join(args.output_dir, "train") - train_count, columns = make_mmoe_parquet( + train_count = make_mmoe_parquet( spark, train_csv, train_output_dir, part_num=64, shuffle=True ) From 13b0318a5573e78cb5ee2bf438d33e89e9f49521 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Thu, 30 Jun 2022 19:35:29 +0800 Subject: [PATCH 26/34] Simplify mmoe_parquet --- RecommenderSystems/mmoe/tools/mmoe_parquet.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/RecommenderSystems/mmoe/tools/mmoe_parquet.py b/RecommenderSystems/mmoe/tools/mmoe_parquet.py index d94f191e1..56c71dcb9 100644 --- a/RecommenderSystems/mmoe/tools/mmoe_parquet.py +++ b/RecommenderSystems/mmoe/tools/mmoe_parquet.py @@ -41,13 +41,7 @@ def make_mmoe_parquet( ): start = time.time() - data = spark.read.format("csv").option("header","false").load(input_files).toDF('age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour', 'hs_college', - 'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', - 'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses', 'stock_dividends', - 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', - 'instance_weight', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt', - 'num_emp', 'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship', - 'own_or_self', 'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k') + data = spark.read.format("csv").option("header","false").load(input_files).toDF(*column_names) # transform label data = data.withColumn("label_income", (col("income_50k")==" 50000+.").cast("int")).drop(col("income_50k")) From 39f84faeb322a0e22698e75a779531d6787a6b0b Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Mon, 4 Jul 2022 16:41:03 +0800 Subject: [PATCH 27/34] Update readme --- RecommenderSystems/mmoe/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/RecommenderSystems/mmoe/README.md b/RecommenderSystems/mmoe/README.md index 1572fbf61..a51777b5f 100644 --- a/RecommenderSystems/mmoe/README.md +++ b/RecommenderSystems/mmoe/README.md @@ -65,8 +65,6 @@ A hands-on guide to train a MMoe model. 
```json psutil petastorm - pandas - sklearn ``` ### Dataset From aa7d714e6c8b0b0e4559d7af6b282376c7c649a1 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Mon, 4 Jul 2022 16:58:06 +0800 Subject: [PATCH 28/34] format mmoe_train_eval.py --- RecommenderSystems/mmoe/mmoe_train_eval.py | 29 +++++++++++++++++----- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py index 2ff2c3727..b5aed22ed 100644 --- a/RecommenderSystems/mmoe/mmoe_train_eval.py +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -67,7 +67,7 @@ def str_list(x): parser.add_argument( "--train_batches", type=int, default=16000, help="the maximum number of training batches" ) - parser.add_argument("--loss_print_interval", type=int, default=100, help="") + parser.add_argument("--loss_print_interval", type=int, default=100, help="interval of printing loss") parser.add_argument( "--table_size_array", @@ -300,7 +300,9 @@ def __init__( ] if store_type == "device_mem": store_options = flow.one_embedding.make_device_mem_store_options( - persistent_path=persistent_path, capacity=vocab_size, size_factor=size_factor, + persistent_path=persistent_path, + capacity=vocab_size, + size_factor=size_factor, ) elif store_type == "cached_host_mem": assert cache_memory_budget_mb > 0 @@ -475,7 +477,13 @@ def make_mmoe_module(args): class MmoeTrainGraph(flow.nn.Graph): def __init__( - self, mmoe_module, loss, optimizer, grad_scaler=None, amp=False, lr_scheduler=None, + self, + mmoe_module, + loss, + optimizer, + grad_scaler=None, + amp=False, + lr_scheduler=None, ): super(MmoeTrainGraph, self).__init__() self.module = mmoe_module @@ -516,7 +524,9 @@ def make_lr_scheduler(args, optimizer): for i in range(math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor))) ] multistep_lr = flow.optim.lr_scheduler.MultiStepLR( - optimizer=optimizer, milestones=milestones, gamma=args.lr_factor, + optimizer=optimizer, + milestones=milestones, + gamma=args.lr_factor, ) return multistep_lr @@ -562,7 +572,10 @@ def save_model(subdir): grad_scaler = flow.amp.StaticGradScaler(1024) else: grad_scaler = flow.amp.GradScaler( - init_scale=1073741824, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, + init_scale=1073741824, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, ) eval_graph = MmoeValGraph(mmoe_module, args.amp) @@ -617,7 +630,11 @@ def save_model(subdir): if step % batches_per_epoch != 0: auc_income, auc_marital = eval( - args, eval_graph, cur_step=step, epoch=epoch, cached_eval_batches=cached_eval_batches, + args, + eval_graph, + cur_step=step, + epoch=epoch, + cached_eval_batches=cached_eval_batches, ) if args.save_model_after_each_eval: save_model(f"step_{step}_val_auc_income_{auc_income:0.5f}_marital_{auc_marital:0.5f}") From 48ab7658a9893f9cbce8d540776185fcf92f1cc0 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Mon, 4 Jul 2022 17:01:07 +0800 Subject: [PATCH 29/34] Format mmoe_parquet.py --- RecommenderSystems/mmoe/tools/mmoe_parquet.py | 150 +++++++++++++----- 1 file changed, 111 insertions(+), 39 deletions(-) diff --git a/RecommenderSystems/mmoe/tools/mmoe_parquet.py b/RecommenderSystems/mmoe/tools/mmoe_parquet.py index 56c71dcb9..8a423a84e 100644 --- a/RecommenderSystems/mmoe/tools/mmoe_parquet.py +++ b/RecommenderSystems/mmoe/tools/mmoe_parquet.py @@ -21,62 +21,138 @@ from pyspark.ml import Pipeline from pyspark.ml.feature import MinMaxScaler, VectorAssembler -column_names = ['age', 'class_worker', 'det_ind_code', 
'det_occ_code', 'education', 'wage_per_hour', 'hs_college', - 'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', - 'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses', 'stock_dividends', - 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', - 'instance_weight', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt', - 'num_emp', 'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship', - 'own_or_self', 'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k'] - -sparse_features = ['class_worker', 'det_ind_code', 'det_occ_code', 'education', 'hs_college', 'major_ind_code', - 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', 'unemp_reason', - 'full_or_part_emp', 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat', - 'det_hh_summ', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt', - 'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship', - 'vet_question'] - -def make_mmoe_parquet( - spark, input_files, output_dir, part_num=None, shuffle=False -): +column_names = [ + "age", + "class_worker", + "det_ind_code", + "det_occ_code", + "education", + "wage_per_hour", + "hs_college", + "marital_stat", + "major_ind_code", + "major_occ_code", + "race", + "hisp_origin", + "sex", + "union_member", + "unemp_reason", + "full_or_part_emp", + "capital_gains", + "capital_losses", + "stock_dividends", + "tax_filer_stat", + "region_prev_res", + "state_prev_res", + "det_hh_fam_stat", + "det_hh_summ", + "instance_weight", + "mig_chg_msa", + "mig_chg_reg", + "mig_move_reg", + "mig_same", + "mig_prev_sunbelt", + "num_emp", + "fam_under_18", + "country_father", + "country_mother", + "country_self", + "citizenship", + "own_or_self", + "vet_question", + "vet_benefits", + "weeks_worked", + "year", + "income_50k", +] + +sparse_features = [ + "class_worker", + "det_ind_code", + "det_occ_code", + "education", + "hs_college", + "major_ind_code", + "major_occ_code", + "race", + "hisp_origin", + "sex", + "union_member", + "unemp_reason", + "full_or_part_emp", + "tax_filer_stat", + "region_prev_res", + "state_prev_res", + "det_hh_fam_stat", + "det_hh_summ", + "mig_chg_msa", + "mig_chg_reg", + "mig_move_reg", + "mig_same", + "mig_prev_sunbelt", + "fam_under_18", + "country_father", + "country_mother", + "country_self", + "citizenship", + "vet_question", +] + + +def make_mmoe_parquet(spark, input_files, output_dir, part_num=None, shuffle=False): start = time.time() - data = spark.read.format("csv").option("header","false").load(input_files).toDF(*column_names) - + data = spark.read.format("csv").option("header", "false").load(input_files).toDF(*column_names) + # transform label - data = data.withColumn("label_income", (col("income_50k")==" 50000+.").cast("int")).drop(col("income_50k")) - data = data.withColumn("label_marital", (col("marital_stat")==" Never married").cast("int")).drop(col("marital_stat")) + data = data.withColumn("label_income", (col("income_50k") == " 50000+.").cast("int")).drop( + col("income_50k") + ) + data = data.withColumn( + "label_marital", (col("marital_stat") == " Never married").cast("int") + ).drop(col("marital_stat")) # transform dense, sparse, label columns = data.columns - - dense_features = [col_ for col_ in columns if - col_ not in sparse_features and col_ not in ['label_income', 'label_marital']] - data.na.fill(value=0,subset=dense_features) - 
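+    # note: DataFrame.na.fill returns a new DataFrame instead of filling in place;
+    # these fills only take effect if the result is assigned back to `data`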
data.na.fill(value='-1',subset=sparse_features) + dense_features = [ + col_ + for col_ in columns + if col_ not in sparse_features and col_ not in ["label_income", "label_marital"] + ] + + data.na.fill(value=0, subset=dense_features) + data.na.fill(value="-1", subset=sparse_features) make_dense = udf(lambda s: float(s), FloatType()) dense_cols = [make_dense(field).alias(field) for field in dense_features] make_label = udf(lambda s: float(s), FloatType()) label_cols = [make_label(field).alias(field) for field in ["label_income", "label_marital"]] - + sparse_cols = [xxhash64(field, lit(i)).alias(field) for i, field in enumerate(sparse_features)] data = data.select(dense_cols + sparse_cols + label_cols) # scale dense features - assemblers = [VectorAssembler(inputCols=[col], outputCol=col + "_vec") for col in dense_features] - scalers = [MinMaxScaler(inputCol=col + "_vec", outputCol=col + "_scaled") for col in dense_features] + assemblers = [ + VectorAssembler(inputCols=[col], outputCol=col + "_vec") for col in dense_features + ] + scalers = [ + MinMaxScaler(inputCol=col + "_vec", outputCol=col + "_scaled") for col in dense_features + ] pipeline = Pipeline(stages=assemblers + scalers) scalerModel = pipeline.fit(data) data = scalerModel.transform(data) scaled_dense_names = {x + "_scaled": x for x in dense_features} - vec_to_float = udf(lambda v:float(v[0]),FloatType()) - data = data.select([vec_to_float(c).alias(scaled_dense_names[c]) for c in scaled_dense_names.keys()] + sparse_features + ["label_income", "label_marital"]) - + vec_to_float = udf(lambda v: float(v[0]), FloatType()) + data = data.select( + [vec_to_float(c).alias(scaled_dense_names[c]) for c in scaled_dense_names.keys()] + + sparse_features + + ["label_income", "label_marital"] + ) + if shuffle: data = data.orderBy(rand()) if part_num: @@ -115,15 +191,11 @@ def make_mmoe_parquet( # create test dataset test_output_dir = os.path.join(args.output_dir, "test") - test_count = make_mmoe_parquet( - spark, test_csv, test_output_dir, part_num=32 - ) + test_count = make_mmoe_parquet(spark, test_csv, test_output_dir, part_num=32) # create train dataset train_output_dir = os.path.join(args.output_dir, "train") - train_count = make_mmoe_parquet( - spark, train_csv, train_output_dir, part_num=64, shuffle=True - ) + train_count = make_mmoe_parquet(spark, train_csv, train_output_dir, part_num=64, shuffle=True) if args.export_dataset_info: df = spark.read.parquet(train_output_dir, test_output_dir) From 6ca5f11a750d488643dd78cfc5812843bcba8f74 Mon Sep 17 00:00:00 2001 From: Liuxinman Date: Wed, 6 Jul 2022 11:13:01 +0800 Subject: [PATCH 30/34] Remove num_sparse_features and num_dense_features --- RecommenderSystems/mmoe/mmoe_train_eval.py | 127 +++++++++------------ 1 file changed, 54 insertions(+), 73 deletions(-) diff --git a/RecommenderSystems/mmoe/mmoe_train_eval.py b/RecommenderSystems/mmoe/mmoe_train_eval.py index b5aed22ed..9acc5f44f 100644 --- a/RecommenderSystems/mmoe/mmoe_train_eval.py +++ b/RecommenderSystems/mmoe/mmoe_train_eval.py @@ -67,7 +67,9 @@ def str_list(x): parser.add_argument( "--train_batches", type=int, default=16000, help="the maximum number of training batches" ) - parser.add_argument("--loss_print_interval", type=int, default=100, help="interval of printing loss") + parser.add_argument( + "--loss_print_interval", type=int, default=100, help="interval of printing loss" + ) parser.add_argument( "--table_size_array", @@ -115,8 +117,50 @@ def _print_args(args): print("-------------------- end of arguments 
---------------------", flush=True) -num_dense_fields = 11 -num_sparse_fields = 29 +sparse_features = [ + "class_worker", + "det_ind_code", + "det_occ_code", + "education", + "hs_college", + "major_ind_code", + "major_occ_code", + "race", + "hisp_origin", + "sex", + "union_member", + "unemp_reason", + "full_or_part_emp", + "tax_filer_stat", + "region_prev_res", + "state_prev_res", + "det_hh_fam_stat", + "det_hh_summ", + "mig_chg_msa", + "mig_chg_reg", + "mig_move_reg", + "mig_same", + "mig_prev_sunbelt", + "fam_under_18", + "country_father", + "country_mother", + "country_self", + "citizenship", + "vet_question", +] +dense_features = [ + "age", + "wage_per_hour", + "capital_gains", + "capital_losses", + "stock_dividends", + "instance_weight", + "num_emp", + "own_or_self", + "vet_benefits", + "weeks_worked", + "year", +] class MmoeDataReader(object): @@ -142,52 +186,6 @@ def __init__( self.shard_count = shard_count self.cur_shard = cur_shard - sparse_features = [ - "class_worker", - "det_ind_code", - "det_occ_code", - "education", - "hs_college", - "major_ind_code", - "major_occ_code", - "race", - "hisp_origin", - "sex", - "union_member", - "unemp_reason", - "full_or_part_emp", - "tax_filer_stat", - "region_prev_res", - "state_prev_res", - "det_hh_fam_stat", - "det_hh_summ", - "mig_chg_msa", - "mig_chg_reg", - "mig_move_reg", - "mig_same", - "mig_prev_sunbelt", - "fam_under_18", - "country_father", - "country_mother", - "country_self", - "citizenship", - "vet_question", - ] - - dense_features = [ - "age", - "wage_per_hour", - "capital_gains", - "capital_losses", - "stock_dividends", - "instance_weight", - "num_emp", - "own_or_self", - "vet_benefits", - "weeks_worked", - "year", - ] - self.fields = dense_features + sparse_features + ["label_income", "label_marital"] self.dense_end = len(dense_features) @@ -300,9 +298,7 @@ def __init__( ] if store_type == "device_mem": store_options = flow.one_embedding.make_device_mem_store_options( - persistent_path=persistent_path, - capacity=vocab_size, - size_factor=size_factor, + persistent_path=persistent_path, capacity=vocab_size, size_factor=size_factor, ) elif store_type == "cached_host_mem": assert cache_memory_budget_mb > 0 @@ -404,7 +400,7 @@ def __init__( self.experts = nn.ModuleList([]) for _ in range(num_experts): expert_net = DNN( - in_features=embedding_vec_size * num_sparse_fields + num_dense_fields, + in_features=embedding_vec_size * len(sparse_features) + len(dense_features), hidden_units=expert_dnn[:-1], out_features=expert_dnn[-1], skip_final_activation=True, @@ -416,7 +412,7 @@ def __init__( self.towers = nn.ModuleList([]) for _ in range(num_tasks): gate_net = DNN( - in_features=embedding_vec_size * num_sparse_fields + num_dense_fields, + in_features=embedding_vec_size * len(sparse_features) + len(dense_features), hidden_units=gate_dnn, out_features=num_experts, skip_final_activation=True, @@ -477,13 +473,7 @@ def make_mmoe_module(args): class MmoeTrainGraph(flow.nn.Graph): def __init__( - self, - mmoe_module, - loss, - optimizer, - grad_scaler=None, - amp=False, - lr_scheduler=None, + self, mmoe_module, loss, optimizer, grad_scaler=None, amp=False, lr_scheduler=None, ): super(MmoeTrainGraph, self).__init__() self.module = mmoe_module @@ -524,9 +514,7 @@ def make_lr_scheduler(args, optimizer): for i in range(math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor))) ] multistep_lr = flow.optim.lr_scheduler.MultiStepLR( - optimizer=optimizer, - milestones=milestones, - gamma=args.lr_factor, + optimizer=optimizer, 
milestones=milestones, gamma=args.lr_factor, ) return multistep_lr @@ -572,10 +560,7 @@ def save_model(subdir): grad_scaler = flow.amp.StaticGradScaler(1024) else: grad_scaler = flow.amp.GradScaler( - init_scale=1073741824, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, + init_scale=1073741824, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, ) eval_graph = MmoeValGraph(mmoe_module, args.amp) @@ -630,11 +615,7 @@ def save_model(subdir): if step % batches_per_epoch != 0: auc_income, auc_marital = eval( - args, - eval_graph, - cur_step=step, - epoch=epoch, - cached_eval_batches=cached_eval_batches, + args, eval_graph, cur_step=step, epoch=epoch, cached_eval_batches=cached_eval_batches, ) if args.save_model_after_each_eval: save_model(f"step_{step}_val_auc_income_{auc_income:0.5f}_marital_{auc_marital:0.5f}") From 4ced5dcabbdd4fa6f7e9ef2330e3420413667d62 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Tue, 12 Jul 2022 13:19:43 +0800 Subject: [PATCH 31/34] env tests --- .../dlrm/dlrm_prefetch_train.py | 3 + RecommenderSystems/dlrm/prefetch_train.py | 89 +++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 RecommenderSystems/dlrm/prefetch_train.py diff --git a/RecommenderSystems/dlrm/dlrm_prefetch_train.py b/RecommenderSystems/dlrm/dlrm_prefetch_train.py index 56b650066..7281ae85a 100644 --- a/RecommenderSystems/dlrm/dlrm_prefetch_train.py +++ b/RecommenderSystems/dlrm/dlrm_prefetch_train.py @@ -93,6 +93,7 @@ def str_list(x): parser.add_argument("--store_type", type=str, default="cached_host_mem") parser.add_argument("--cache_memory_budget_mb", type=int, default=8192) parser.add_argument("--amp", action="store_true", help="Run model with amp") + parser.add_argument("--split_allreduce", action="store_true", help="split bottom and top allreduce") parser.add_argument("--loss_scale_policy", type=str, default="static", help="static or dynamic") args = parser.parse_args() @@ -652,5 +653,7 @@ def eval(cached_eval_batches, eval_graph, cur_step=0): os.system(sys.executable + " -m oneflow --doctor") flow.boxing.nccl.enable_all_to_all(True) args = get_args() + if args.split_allreduce: + flow.boxing.nccl.set_fusion_max_ops_num(10) train(args) diff --git a/RecommenderSystems/dlrm/prefetch_train.py b/RecommenderSystems/dlrm/prefetch_train.py new file mode 100644 index 000000000..d3bdda33b --- /dev/null +++ b/RecommenderSystems/dlrm/prefetch_train.py @@ -0,0 +1,89 @@ +import os +import sys +import argparse +import datetime +from dateutil import tz + +num_gpus = 4 +persistent_path = './persistent' +table_size_array = [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532951, 2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14, 39979771, 25641295, 39664984, 585935, 12972, 108, 36] +table_size_array = ','.join([str(i) for i in table_size_array]) +num_eval_examples = 89137319 +eval_batch_size = 55296 +eval_batchs= num_eval_examples // eval_batch_size +warmup_batches = 2500 +decay_batches = 15406 +train_batch_size = num_gpus * 6912 +#train_batch_size = 69120 +num_train_samples = 4195197692 +train_batches = num_train_samples // train_batch_size + 1 +decay_start = train_batches - decay_batches + 3700 + +env = "" +env += "ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=0 " +env += "ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1 " +env += "ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE=1 " +env += "ONEFLOW_FUSE_MODEL_UPDATE_CAST=1 " +env += "ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE=1 " +env += "ONEFLOW_ONE_EMBEDDING_USE_SYSTEM_GATHER=0 " + +cfg = "" +cfg += "--eval_interval 
100000 " +cfg += "--model_save_dir ckpt " +cfg += "--one_embedding_key_type int32 " +cfg += f"--data_dir /RAID0/xiexuan/dlrm_parquet_int32 " +cfg += f"--persistent_path {persistent_path} " +cfg += f"--store_type device_mem " +cfg += f"--table_size_array {table_size_array} " +cfg += f"--train_batch_size {train_batch_size} " +#cfg += f"--train_batches {train_batches} " +cfg += f"--train_batches 10000 " +cfg += f"--eval_batches {eval_batchs} " +cfg += f"--eval_batch_size {eval_batch_size} " +cfg += f"--warmup_batches {warmup_batches} " +cfg += f"--decay_start {decay_start} " +cfg += f"--decay_batches {decay_batches} " +cfg += f"--amp " + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="flags for OneFlow DLRM") + parser.add_argument("--log_path", type=str, default="commits.log") + args = parser.parse_args() + ext_envs = [ + "ONEFLOW_GRAPH_PLACE_TRAINING_STATE_ON_ALL_RANKS", + "ONEFLOW_ONE_EMBEDDING_EMBEDDING_GRADIENT_SHUFFLE_INDEPENTENT_STREAM", + "ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM", + "ONEFLOW_ONE_EMBEDDING_FUSE_EMBEDDING_INTERACTION", + "ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", + "ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", + "ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_UNABLE_ALLREDUCE", + ] + for i in range(10): + # test baseline + cmd = env + dl + cfg + os.system(f'rm -rf {persistent_path}*') + os.system(f'echo {cmd}') + os.system(cmd + f" | tee baseline_{i}.log") + + # test split allreduce + cmd = env + dl + cfg + "--split_allreduce " + os.system(f'rm -rf {persistent_path}*') + os.system(f'echo {cmd}') + os.system(cmd + f" | tee split_allreduce_{i}.log") + + # test envs + for ext_env in ext_envs: + test_name = ext_env + dl = sys.executable + " -m oneflow.distributed.launch " + dl += f"--nproc_per_node {num_gpus} " + dl += "--nnodes 1 " + dl += "--node_rank 0 " + dl += "--master_addr 127.0.0.1 " + dl += "dlrm_prefetch_train.py " + + cmd = env + ext_env + "=1 " + dl + cfg + os.system(f'rm -rf {persistent_path}*') + os.system(f'echo {cmd}') + os.system(cmd + f" | tee {test_name}_{i}.log") + + From e5fca3121d12a7e4e16194debb14a7997e7bbc88 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Tue, 12 Jul 2022 17:07:49 +0800 Subject: [PATCH 32/34] update --- RecommenderSystems/dlrm/prefetch_train.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/RecommenderSystems/dlrm/prefetch_train.py b/RecommenderSystems/dlrm/prefetch_train.py index d3bdda33b..f74a1b4e3 100644 --- a/RecommenderSystems/dlrm/prefetch_train.py +++ b/RecommenderSystems/dlrm/prefetch_train.py @@ -45,6 +45,15 @@ cfg += f"--decay_batches {decay_batches} " cfg += f"--amp " + +dl = sys.executable + " -m oneflow.distributed.launch " +dl += f"--nproc_per_node {num_gpus} " +dl += "--nnodes 1 " +dl += "--node_rank 0 " +dl += "--master_addr 127.0.0.1 " +dl += "dlrm_prefetch_train.py " + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="flags for OneFlow DLRM") parser.add_argument("--log_path", type=str, default="commits.log") @@ -74,13 +83,6 @@ # test envs for ext_env in ext_envs: test_name = ext_env - dl = sys.executable + " -m oneflow.distributed.launch " - dl += f"--nproc_per_node {num_gpus} " - dl += "--nnodes 1 " - dl += "--node_rank 0 " - dl += "--master_addr 127.0.0.1 " - dl += "dlrm_prefetch_train.py " - cmd = env + ext_env + "=1 " + dl + cfg os.system(f'rm -rf {persistent_path}*') os.system(f'echo {cmd}') From 39f6cac025a2c9fec8df4017ada50ac1cca57f10 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Tue, 12 Jul 2022 
19:51:07 +0800 Subject: [PATCH 33/34] update --- RecommenderSystems/dlrm/dlrm_prefetch_train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/RecommenderSystems/dlrm/dlrm_prefetch_train.py b/RecommenderSystems/dlrm/dlrm_prefetch_train.py index 7281ae85a..eb0ce5fee 100644 --- a/RecommenderSystems/dlrm/dlrm_prefetch_train.py +++ b/RecommenderSystems/dlrm/dlrm_prefetch_train.py @@ -377,9 +377,9 @@ def __init__( ) def forward(self, dense_fields, sparse_fields) -> flow.Tensor: - dense_fields = flow.log(dense_fields + 1.0) if self.pad: dense_fields = flow.nn.functional.pad(dense_fields, self.pad, "constant") + dense_fields = flow.log(dense_fields + 1.0) dense_fields = self.bottom_mlp(dense_fields) embedding = self.embedding(sparse_fields) features = self.interaction(dense_fields, embedding) @@ -471,7 +471,8 @@ def __init__( def build(self, labels, dense_fields, sparse_fields): logits = self.module(dense_fields.to("cuda"), sparse_fields.to("cuda")) loss = self.loss(logits, labels.to("cuda")) - reduce_loss = flow.mean(loss) + #reduce_loss = flow.mean(loss) + reduce_loss = loss reduce_loss.backward() return reduce_loss.to("cpu") @@ -510,7 +511,7 @@ def save_model(subdir): opt = flow.optim.SGD(dlrm_module.parameters(), lr=args.learning_rate) lr_scheduler = make_lr_scheduler(args, opt) - loss = flow.nn.BCEWithLogitsLoss(reduction="none").to("cuda") + loss = flow.nn.BCEWithLogitsLoss(reduction="mean").to("cuda") if args.loss_scale_policy == "static": grad_scaler = flow.amp.StaticGradScaler(1024) From 68cd3d0f517e99b6a954fb44dc7fd56549f2dab6 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Fri, 15 Jul 2022 00:11:25 +0800 Subject: [PATCH 34/34] update --- .../dlrm/dlrm_prefetch_train.py | 13 +++++++++---- RecommenderSystems/dlrm/prefetch_train.py | 19 ++++++++++++------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/RecommenderSystems/dlrm/dlrm_prefetch_train.py b/RecommenderSystems/dlrm/dlrm_prefetch_train.py index eb0ce5fee..7c41ccc36 100644 --- a/RecommenderSystems/dlrm/dlrm_prefetch_train.py +++ b/RecommenderSystems/dlrm/dlrm_prefetch_train.py @@ -93,6 +93,7 @@ def str_list(x): parser.add_argument("--store_type", type=str, default="cached_host_mem") parser.add_argument("--cache_memory_budget_mb", type=int, default=8192) parser.add_argument("--amp", action="store_true", help="Run model with amp") + parser.add_argument("--prefetch_cuda", action="store_true", help="prefetch data to cuda") parser.add_argument("--split_allreduce", action="store_true", help="split bottom and top allreduce") parser.add_argument("--loss_scale_policy", type=str, default="static", help="static or dynamic") @@ -469,8 +470,8 @@ def __init__( self.set_grad_scaler(grad_scaler) def build(self, labels, dense_fields, sparse_fields): - logits = self.module(dense_fields.to("cuda"), sparse_fields.to("cuda")) - loss = self.loss(logits, labels.to("cuda")) + logits = self.module(dense_fields if dense_fields.is_cuda else dense_fields.to("cuda"), sparse_fields if sparse_fields.is_cuda else sparse_fields.to("cuda")) + loss = self.loss(logits, labels if labels.is_cuda else labels.to("cuda")) #reduce_loss = flow.mean(loss) reduce_loss = loss reduce_loss.backward() @@ -544,7 +545,7 @@ def save_model(subdir): with make_criteo_dataloader(f"{args.data_dir}/train", args.train_batch_size) as loader: print('start prefetch training data...') - cached_batches = [batch_to_global(*next(loader)) for _ in range(args.train_batches)] + cached_batches = [batch_to_global(*next(loader), to_cuda=args.prefetch_cuda) for 
_ in range(args.train_batches)] print('start training ..') step, last_step, last_time = 0, 0, time.time() for labels, dense_fields, sparse_fields in cached_batches: @@ -598,10 +599,14 @@ def np_to_global(np): return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) -def batch_to_global(np_label, np_dense, np_sparse, is_train=True): +def batch_to_global(np_label, np_dense, np_sparse, is_train=True, to_cuda=False): dense_fields = np_to_global(np_dense) sparse_fields = np_to_global(np_sparse) labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) + if to_cuda: + labels = labels.to("cuda") + dense_fields = dense_fields.to("cuda") + sparse_fields = sparse_fields.to("cuda") return labels, dense_fields, sparse_fields diff --git a/RecommenderSystems/dlrm/prefetch_train.py b/RecommenderSystems/dlrm/prefetch_train.py index f74a1b4e3..27989d032 100644 --- a/RecommenderSystems/dlrm/prefetch_train.py +++ b/RecommenderSystems/dlrm/prefetch_train.py @@ -20,11 +20,13 @@ decay_start = train_batches - decay_batches + 3700 env = "" -env += "ONEFLOW_ONE_EMBEDDING_ENABLE_QUANTIZED_COMM=0 " -env += "ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1 " -env += "ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE=1 " +env += "ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD=0 " +env += "ONEFLOW_ONE_EMBEDDING_FUSE_EMBEDDING_INTERACTION=1 " +env += "ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16=1 " env += "ONEFLOW_FUSE_MODEL_UPDATE_CAST=1 " env += "ONEFLOW_ENABLE_MULTI_TENSOR_MODEL_UPDATE=1 " +env += "ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH=1 " +env += "ONEFLOW_EAGER_LOCAL_TO_GLOBAL_BALANCED_OVERRIDE=1 " env += "ONEFLOW_ONE_EMBEDDING_USE_SYSTEM_GATHER=0 " cfg = "" @@ -60,12 +62,9 @@ args = parser.parse_args() ext_envs = [ "ONEFLOW_GRAPH_PLACE_TRAINING_STATE_ON_ALL_RANKS", - "ONEFLOW_ONE_EMBEDDING_EMBEDDING_GRADIENT_SHUFFLE_INDEPENTENT_STREAM", - "ONEFLOW_ONE_EMBEDDING_EMBEDDING_SHUFFLE_INDEPENTENT_STREAM", - "ONEFLOW_ONE_EMBEDDING_FUSE_EMBEDDING_INTERACTION", - "ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", "ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", "ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_UNABLE_ALLREDUCE", + "LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 " ] for i in range(10): # test baseline @@ -74,6 +73,12 @@ os.system(f'echo {cmd}') os.system(cmd + f" | tee baseline_{i}.log") + # test split allreduce + cmd = env + dl + cfg + "--prefetch_cuda" + os.system(f'rm -rf {persistent_path}*') + os.system(f'echo {cmd}') + os.system(cmd + f" | tee prefetch_cuda{i}.log") + # test split allreduce cmd = env + dl + cfg + "--split_allreduce " os.system(f'rm -rf {persistent_path}*')