Skip to content

Commit

Permalink
Merge pull request #59 from FederatedAI/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
mgqa34 authored Jun 8, 2020
2 parents 1717021 + b75e538 commit 47a67a3
Show file tree
Hide file tree
Showing 13 changed files with 267 additions and 80 deletions.
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Release 1.3.0
## Major Features and Improvements
* Hetero SecureBoost communication optimization: the number of communication rounds is reduced to 1 by having the host send a pre-computed host node route to the guest, which the guest then uses for inference.

# Release 1.2.0
## Major Features and Improvements
* Replace serving-router with a brand new service called serving-proxy, which supports authentication and inference requests over HTTP or gRPC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ public class Dict {
public static final String INPUT_DATA_HIT_RATE = "inputDataHitRate";
public static final String GUEST_MODEL_WEIGHT_HIT_RATE = "guestModelWeightHitRate";
public static final String GUEST_INPUT_DATA_HIT_RATE = "guestInputDataHitRate";
public static final String TAG_INPUT_FORMAT = "tag";
public static final String SPARSE_INPUT_FORMAT = "sparse";
public static final String MIN_MAX_SCALE = "min_max_scale";
public static final String STANDARD_SCALE = "standard_scale";
public static final String DSL_COMPONENTS = "components";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@
import com.webank.ai.fate.core.mlmodel.buffer.DataIOMetaProto.DataIOMeta;
import com.webank.ai.fate.core.mlmodel.buffer.DataIOParamProto.DataIOParam;
import com.webank.ai.fate.serving.core.bean.Context;
import com.webank.ai.fate.serving.core.bean.Dict;
import com.webank.ai.fate.serving.core.bean.FederatedParams;
import com.webank.ai.fate.serving.core.bean.StatusCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class DataIO extends BaseModel {
private static final Logger logger = LoggerFactory.getLogger(DataIO.class);
private DataIOMeta dataIOMeta;
private DataIOParam dataIOParam;
private List<String> header;
private String inputformat;
private Imputer imputer;
private Outlier outlier;
private boolean isImputer;
Expand All @@ -56,9 +60,12 @@ public int initModel(byte[] protoMeta, byte[] protoParam) {
this.outlier = new Outlier(this.dataIOMeta.getOutlierMeta().getOutlierValueList(),
this.dataIOParam.getOutlierParam().getOutlierReplaceValue());
}

this.header = this.dataIOParam.getHeaderList();
this.inputformat = this.dataIOMeta.getInputFormat();
} catch (Exception ex) {
ex.printStackTrace();
logger.error("init DataIo error",ex);
logger.error("init DataIo error", ex);
return StatusCode.ILLEGALDATA;
}
logger.info("Finish init DataIO class");
Expand All @@ -67,16 +74,43 @@ public int initModel(byte[] protoMeta, byte[] protoParam) {

/**
 * Prepares the first record of {@code inputData} for the downstream components.
 *
 * <p>When the model's input format is {@code Dict.TAG_INPUT_FORMAT} or
 * {@code Dict.SPARSE_INPUT_FORMAT}, a new map is built with one entry per header column,
 * using {@code 0} for every column missing from the record. For any other format
 * (including an unset format) the record is passed through as-is. The imputer and the
 * outlier transforms are then applied, in that order, when they are enabled.
 *
 * @param context       serving context (not used by this component)
 * @param inputData     input records; only the first element is read
 * @param predictParams federated parameters (not used by this component)
 * @return the transformed record; for dense input with no transform enabled this is the
 *         same map instance as {@code inputData.get(0)}
 */
@Override
public Map<String, Object> handlePredict(Context context, List<Map<String, Object>> inputData, FederatedParams predictParams) {
    Map<String, Object> data = inputData.get(0);
    Map<String, Object> outputData;

    logger.debug("input-data, not filling, {}", data);

    // Constant-first equals: an uninitialised input format falls through to the dense
    // path instead of raising a NullPointerException.
    boolean sparseInput = Dict.TAG_INPUT_FORMAT.equals(this.inputformat)
            || Dict.SPARSE_INPUT_FORMAT.equals(this.inputformat);

    if (sparseInput) {
        logger.debug("Sparse Data Filling Zeros");
        outputData = new HashMap<>();
        for (String col : this.header) {
            outputData.put(col, data.getOrDefault(col, 0));
        }
    } else {
        outputData = data;
        logger.debug("Dense input-format, not filling, {}", outputData);
    }

    if (this.isImputer) {
        outputData = this.imputer.transform(outputData);
    }

    if (this.isOutlier) {
        outputData = this.outlier.transform(outputData);
    }

    return outputData;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Collections;
import java.lang.Math;


public class HeteroFeatureBinning extends BaseModel {
Expand Down Expand Up @@ -56,38 +58,50 @@ public int initModel(byte[] protoMeta, byte[] protoParam) {
@Override
public Map<String, Object> handlePredict(Context context, List<Map<String, Object>> inputData, FederatedParams predictParams) {
HashMap<String, Object> outputData = new HashMap<>(8);
HashMap<String, Long> headerMap = new HashMap<>();
Map<String, Object> firstData = inputData.get(0);
if (!this.needRun) {
return firstData;
}

for (int i = 0; i < this.header.size(); i++) {
headerMap.put(this.header.get(i), (long) i);
}

for (String colName : firstData.keySet()) {
try{
if (! this.splitPoints.containsKey(colName)) {
try {
if (!this.splitPoints.containsKey(colName)) {
outputData.put(colName, firstData.get(colName));
continue;
continue;
}
Long thisColIndex = (long) this.header.indexOf(colName);
if (! this.transformCols.contains(thisColIndex)) {
// Long thisColIndex = (long) this.header.indexOf(colName);
Long thisColIndex = headerMap.get(colName);
if (!this.transformCols.contains(thisColIndex)) {
outputData.put(colName, firstData.get(colName));
continue;
}
List<Double> splitPoint = this.splitPoints.get(colName);
Double colValue = Double.valueOf(firstData.get(colName).toString());
int colIndex = 0;
for (colIndex = 0; colIndex < splitPoint.size(); colIndex ++) {
if (colValue <= splitPoint.get(colIndex)) {
break;
}
}
outputData.put(colName, colIndex);
}catch(Throwable e){
logger.error("HeteroFeatureBinning error" ,e);
int colIndex = Collections.binarySearch(splitPoint, colValue);
if (colIndex < 0) {
colIndex = Math.min((- colIndex - 1), splitPoint.size() - 1);
}
// for (colIndex = 0; colIndex < splitPoint.size(); colIndex ++) {
//
//
// if (colValue <= splitPoint.get(colIndex)) {
// break;
// }
// }
outputData.put(colName, colIndex);
} catch (Throwable e) {
logger.error("HeteroFeatureBinning error", e);
}
}
if(logger.isDebugEnabled()) {
logger.debug("HeteroFeatureBinning output {}", outputData);
if (logger.isDebugEnabled()) {
logger.debug("DEBUG: HeteroFeatureBinning output {}", outputData);
}

return outputData;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public abstract class HeteroSecureBoost extends BaseModel {
protected List<String> classes;
protected int treeDim;
protected double learningRate;
protected boolean fastMode = true;

@Override
public int initModel(byte[] protoMeta, byte[] protoParam) {
Expand Down Expand Up @@ -93,13 +94,15 @@ protected int gotoNextLevel(int treeId, int treeNodeId, Map<String, Object> inpu
int fid = this.trees.get(treeId).getTree(treeNodeId).getFid();
double splitValue = this.trees.get(treeId).getSplitMaskdict().get(treeNodeId);
String fidStr = String.valueOf(fid);

if (input.containsKey(fidStr)) {
if (Double.parseDouble(input.get(fidStr).toString()) <= splitValue + 1e-20) {
nextTreeNodeId = this.trees.get(treeId).getTree(treeNodeId).getLeftNodeid();
} else {
nextTreeNodeId = this.trees.get(treeId).getTree(treeNodeId).getRightNodeid();
}
} else {
logger.info("go missing dir");
if (this.trees.get(treeId).getMissingDirMaskdict().containsKey(treeNodeId)) {
int missingDir = this.trees.get(treeId).getMissingDirMaskdict().get(treeNodeId);
if (missingDir == 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ private double sigmoid(double x) {
return 1. / (1. + Math.exp(-x));
}

private Map<String, Object> softmax(double[] weights) {
private boolean fastMode = true;

private Map<String, Object> softmax(double weights[]) {
int n = weights.length;
double max = weights[0];
int maxIndex = 0;
Expand Down Expand Up @@ -99,13 +101,36 @@ private double getTreeLeafWeight(int treeId, int treeNodeId) {
}

/**
 * Walks down tree {@code treeId} from {@code treeNodeId} for as long as the current node
 * is not a leaf and belongs to this party's site.
 *
 * @param treeId     index of the tree to walk
 * @param treeNodeId node to start from
 * @param input      feature values passed to {@code gotoNextLevel} at each step
 * @return the id of the first node that is a leaf or is owned by another site
 */
private int traverseTree(int treeId, int treeNodeId, Map<String, Object> input) {
    int current = treeNodeId;
    for (;;) {
        if (this.isLocateInLeaf(treeId, current)) {
            break;
        }
        if (!this.getSite(treeId, current).equals(this.site)) {
            break;
        }
        current = this.gotoNextLevel(treeId, current, input);
    }
    return current;
}

/**
 * Walks tree {@code treeId} from {@code treeNodeId} all the way to a leaf.
 *
 * <p>Nodes owned by this party's site are resolved with {@code gotoNextLevel}. Every other
 * node is resolved from {@code lookUpTable}: the entry under the tree id (as a string) is
 * read as a map from node id (as a string) to a Boolean, where {@code true} selects the
 * left child and {@code false} the right child.
 *
 * @param treeId      index of the tree to walk
 * @param treeNodeId  node to start from
 * @param input       feature values used for nodes owned by this site
 * @param lookUpTable per-tree routing decisions for nodes not owned by this site
 * @return the id of the leaf node that was reached
 * @throws IllegalStateException if the look-up table has no decision for a node that is
 *                               not owned by this site
 */
private int fastTraverseTree(int treeId, int treeNodeId, Map<String, Object> input, Map<String, Object> lookUpTable) {
    while (!this.isLocateInLeaf(treeId, treeNodeId)) {
        if (this.getSite(treeId, treeNodeId).equals(this.site)) {
            treeNodeId = this.gotoNextLevel(treeId, treeNodeId, input);
        } else {
            @SuppressWarnings("unchecked") // the look-up table stores one Map<String, Boolean> per tree
            Map<String, Boolean> lookUp = (Map<String, Boolean>) lookUpTable.get(String.valueOf(treeId));
            if (lookUp == null) {
                throw new IllegalStateException("no node route for tree " + treeId);
            }
            Boolean goLeft = lookUp.get(String.valueOf(treeNodeId));
            if (goLeft == null) {
                // Fail with a descriptive error instead of an unboxing NullPointerException.
                throw new IllegalStateException(
                        "no node route for tree " + treeId + ", node " + treeNodeId);
            }
            if (goLeft) {
                treeNodeId = this.trees.get(treeId).getTree(treeNodeId).getLeftNodeid();
            } else {
                treeNodeId = this.trees.get(treeId).getTree(treeNodeId).getRightNodeid();
            }
        }
        // Logged at debug level: the original guarded an info call with isDebugEnabled().
        logger.debug("tree id is {}, tree node is {}", treeId, treeNodeId);
    }

    return treeNodeId;
}

private Map<String, Object> getFinalPredict(double[] weights) {
Map<String, Object> ret = new HashMap<String, Object>(8);
Expand All @@ -121,9 +146,8 @@ private Map<String, Object> getFinalPredict(double[] weights) {
sumWeights[i % this.treeDim] += weights[i] * this.learningRate;
}

for (int i = 0; i < this.treeDim; i++) {
for (int i = 0; i < this.treeDim; i++)
sumWeights[i] += this.initScore.get(i);
}

ret = softmax(sumWeights);
} else {
Expand All @@ -139,14 +163,17 @@ private Map<String, Object> getFinalPredict(double[] weights) {

@Override
public Map<String, Object> handlePredict(Context context, List<Map<String, Object>> inputData, FederatedParams predictParams) {
if(logger.isDebugEnabled()) {
logger.debug("HeteroSecureBoostingTreeGuest FederatedParams {}", predictParams);
}

logger.info("HeteroSecureBoostingTreeGuest FederatedParams {}", predictParams);

Map<String, Object> input = inputData.get(0);
HashMap<String, Object> fidValueMapping = new HashMap<String, Object>(8);

ReturnResult returnResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE, false);
if(!this.fastMode){
// ask host to prepare data, if fast mode is not enabled
ReturnResult returnResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE, false);
}


int featureHit = 0;
for (String key : input.keySet()) {
Expand All @@ -155,12 +182,13 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec
++featureHit;
}
}
if(logger.isDebugEnabled()) {
logger.debug("feature hit rate : {}", 1.0 * featureHit / this.featureNameFidMapping.size());
}

logger.info("feature hit rate : {}", 1.0 * featureHit / this.featureNameFidMapping.size());
int[] treeNodeIds = new int[this.treeNum];
double[] weights = new double[this.treeNum];
int communicationRound = 0;

// start local inference
while (true) {
HashMap<String, Object> treeLocation = new HashMap<String, Object>(8);
for (int i = 0; i < this.treeNum; ++i) {
Expand All @@ -185,39 +213,70 @@ public Map<String, Object> handlePredict(Context context, List<Map<String, Objec

predictParams.getData().put(Dict.TREE_LOCATION, treeLocation);

try {
if(logger.isDebugEnabled()) {
logger.info("fast mode is {}", this.fastMode);
}

ReturnResult tempResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE_FOR_TREE, false);
try {
logger.info("begin to federated");

Map<String, Object> afterLocation = tempResult.getData();
if(logger.isDebugEnabled()) {
logger.debug("after loccation is {}", afterLocation);
boolean getNodeRoute = false;
ReturnResult tempResult;
if(this.fastMode){
getNodeRoute = true;
tempResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE, false);
}
for (String location : afterLocation.keySet()) {
treeNodeIds[new Integer(location)] = ((Number) afterLocation.get(location)).intValue();
else{
tempResult = this.getFederatedPredict(context, predictParams, Dict.FEDERATED_INFERENCE_FOR_TREE, false);
}

if (afterLocation == null) {
logger.error("receive predict result of host is null");
throw new Exception("Null Data");

Map<String, Object> returnData = tempResult.getData();

if(this.fastMode && getNodeRoute){

if(logger.isDebugEnabled()){
logger.info("running fast mode, look up table is {}",returnData);
}

for(String treeIdx: treeLocation.keySet()){
int idx = Integer.valueOf(treeIdx);
int curNodeId = (Integer)treeLocation.get(treeIdx);
int final_node_id = this.fastTraverseTree(idx, curNodeId, fidValueMapping, returnData);
treeNodeIds[idx] = final_node_id;
}
}
else{
Map<String, Object> afterLocation = tempResult.getData();

if(logger.isDebugEnabled()){
logger.info("after location is {}", afterLocation);
}

for (String location : afterLocation.keySet()) {
treeNodeIds[new Integer(location)] = ((Number) afterLocation.get(location)).intValue();
}
if (afterLocation == null) {
logger.info("receive predict result of host is null");
throw new Exception("Null Data");
}
}

} catch (Exception ex) {
ex.printStackTrace();
logger.error("HeteroSecureBoostingTreeGuest handle error",ex);
return null;
}
}

for (int i = 0; i < this.treeNum; ++i) {
weights[i] = getTreeLeafWeight(i, treeNodeIds[i]);
}

if(logger.isDebugEnabled()){
logger.debug("tree leaf ids is {}", treeNodeIds);
logger.debug("weights is {}", weights);
logger.info("tree leaf ids is {}", treeNodeIds);
logger.info("weights is {}", weights);
}


return getFinalPredict(weights);
}
}
}
Loading

0 comments on commit 47a67a3

Please sign in to comment.