Does 2.0 support a vertical (hetero) LSTM? #5707

Open
FancyXun opened this issue Sep 6, 2024 · 4 comments

Comments


FancyXun commented Sep 6, 2024

I see that FATE wraps some of torch's nn modules, including the Sequential class, and I also found the LSTM model, but how is it supposed to be used? The LSTM's output is a tuple, so it can't be added directly into a Sequential, right?

FancyXun (Author) commented Sep 6, 2024

bottom_model=Sequential(
    nn.Linear(10, 10),
    nn.LSTM(input_size=10, hidden_size=10, batch_first=True),
    nn.Linear(10, 10),
),

A definition like this can't be right, can it? The LSTM's output is a tuple.
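
In plain PyTorch (setting FATE's wrappers aside), the usual fix is to wrap the LSTM in a small module that unpacks the (output, (h_n, c_n)) tuple and returns a tensor, so the next layer in the Sequential receives what it expects. A minimal sketch, assuming plain torch.nn rather than FATE's wrapped Sequential; LastStepLSTM is a hypothetical helper name:

import torch
from torch import nn

class LastStepLSTM(nn.Module):
    """Run an LSTM and return only the last time step as a plain tensor."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            batch_first=True)

    def forward(self, x):
        out, _ = self.lstm(x)   # out: (batch, seq_len, hidden_size)
        return out[:, -1, :]    # (batch, hidden_size) -- a plain tensor

model = nn.Sequential(
    nn.Linear(10, 10),
    LastStepLSTM(10, 10),
    nn.Linear(10, 10),
)
x = torch.randn(4, 5, 10)       # (batch, seq_len, features)
print(model(x).shape)           # torch.Size([4, 10])

Whether FATE's wrapped Sequential can serialize such a custom module on the server side is a separate question, discussed below.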

talkingwallace (Contributor) commented

May I ask which tutorial you were following?

FancyXun (Author) commented Sep 9, 2024

Thanks for the reply. I'm using the nn examples bundled with FATE 2.0; the script below is one I wrote based on those examples. Could you take a look and see what's wrong?

#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import argparse
from fate_client.pipeline.utils import test_utils
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate.nn.torch import nn, optim
from fate_client.pipeline.components.fate.nn.torch.base import Sequential
from fate_client.pipeline.components.fate.hetero_nn import HeteroNN, get_config_of_default_runner
from fate_client.pipeline.components.fate.psi import PSI
from fate_client.pipeline.components.fate.reader import Reader
from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments
from fate_client.pipeline.components.fate import Evaluation
from fate_client.pipeline.components.fate.nn.algo_params import FedPassArgument



class LSTMModel(Sequential):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set the initial hidden and cell states (optional; they default to zeros)
        # h0 = Variable(torch.zeros(num_layers, batch_size, hidden_size))
        # c0 = Variable(torch.zeros(num_layers, batch_size, hidden_size))

        # Forward pass: keep only the last time step's output
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])
        return out


def main(config="../../config.yaml", namespace=""):
    # obtain config
    if isinstance(config, str):
        config = test_utils.load_job_config(config)
    parties = config.parties
    guest = parties.guest[0]
    host = parties.host[0]
    arbiter = parties.arbiter[0]

    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)

    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
    reader_0.guest.task_parameters(
        namespace=f"experiment{namespace}",
        name="breast_hetero_guest"
    )
    reader_0.hosts[0].task_parameters(
        namespace=f"experiment{namespace}",
        name="breast_hetero_host"
    )
    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])

    training_args = TrainingArguments(
            num_train_epochs=1,
            per_device_train_batch_size=16,
            logging_strategy='epoch'
        )

    guest_conf = get_config_of_default_runner(
        bottom_model= LSTMModel(10, 10, 3, 10),
        top_model=Sequential(
            nn.Linear(10, 1),
            nn.Sigmoid()
        ),
        training_args=training_args,
        optimizer=optim.Adam(lr=0.01),
        loss=nn.BCELoss()
    )

    host_conf = get_config_of_default_runner(
        bottom_model=nn.Linear(20, 20),
        optimizer=optim.Adam(lr=0.01),
        training_args=training_args,
        agglayer_arg=FedPassArgument(
            layer_type='linear',
            in_channels_or_features=20,
            hidden_features=20,
            out_channels_or_features=10,
            passport_mode='single',
            passport_distribute='gaussian'
        )
    )

    hetero_nn_0 = HeteroNN(
        'hetero_nn_0',
        train_data=psi_0.outputs['output_data']
    )

    hetero_nn_0.guest.task_parameters(runner_conf=guest_conf)
    hetero_nn_0.hosts[0].task_parameters(runner_conf=host_conf)

    hetero_nn_1 = HeteroNN(
        'hetero_nn_1',
        test_data=psi_0.outputs['output_data'],
        predict_model_input=hetero_nn_0.outputs['train_model_output']
    )

    evaluation_0 = Evaluation(
        'eval_0',
        runtime_parties=dict(guest=guest),
        metrics=['auc'],
        input_data=[hetero_nn_1.outputs['predict_data_output'], hetero_nn_0.outputs['train_data_output']]
    )

    pipeline.add_tasks([reader_0, psi_0, hetero_nn_0, hetero_nn_1, evaluation_0])
    pipeline.compile()
    pipeline.fit()

    result_summary = pipeline.get_task_info("eval_0").get_output_metric()[0]["data"]
    print(f"result_summary: {result_summary}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser("PIPELINE DEMO")
    parser.add_argument("--config", type=str, default="../config.yaml",
                        help="config file")
    parser.add_argument("--namespace", type=str, default="",
                        help="namespace for data stored in FATE")
    args = parser.parse_args()
    main(config=args.config, namespace=args.namespace)

fate_client compiles and submits the pipeline successfully, but the FATE server side reports an error:
[Screenshot: server-side error traceback, Sep 9, 2024 15:37:35]

The LSTM's output is a tuple, not a tensor, and I even handled that explicitly in my custom model's forward; the model also predicts results fine under plain torch. I wonder whether FATE has changed some of torch's intermediate interfaces, so that when the model is sent to the server and reloaded, some information is lost. Alternatively, is there an example of using a vertical (hetero) LSTM for classification? The LSTM could serve as either the bottom model or the top model. Following the official nn examples keeps hitting bugs, and I'm not sure whether I'm using them incorrectly. Any guidance would be appreciated. @talkingwallace
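
For reference, the forward logic itself does run under plain PyTorch, which points at the server-side re-instantiation rather than the model logic. A minimal local sanity check, assuming plain torch.nn and a hypothetical 3-D input shape (note that the breast_hetero datasets are 2-D tabular data, which this forward would not accept as-is):

import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)          # out: (batch, seq_len, hidden_size)
        return self.fc(out[:, -1, :])  # last time step -> (batch, num_classes)

model = LSTMModel(10, 10, 3, 10)
x = torch.randn(16, 5, 10)             # (batch, seq_len, input_size)
print(model(x).shape)                  # torch.Size([16, 10])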

talkingwallace (Contributor) commented

That is a pipeline-based example; you could refer instead to the examples that run the ml module directly.
