
Resnet npu #412 (Draft)
wants to merge 9 commits into main
Conversation

@ShawnXuan (Contributor) commented Sep 24, 2024

This branch pins the dataset and removes the randomness.

You need to prepare a loadable initial model at .../models/Vision/classification/image/resnet50/examples/checkpoints/init

For example, on a 910B you can do:

cd .../models/Vision/classification/image/resnet50/examples
cp -r /data1/home/xiexuan/git-repos/models/Vision/classification/image/resnet50/examples/checkpoints .

Then you can run ./npu_eager.sh or ./npu_graph.sh

Currently NPU eager is aligned with CUDA eager/graph, but NPU graph is not yet aligned: every pred value it outputs is 0.001, which needs deeper investigation.

loss
tensor(6.9073, placement=oneflow.placement(type="npu", ranks=[0]), sbp=(oneflow.sbp.partial_sum,),
       dtype=oneflow.float32)
pred
tensor([[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
        [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
        [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
        ...,
        [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
        [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
        [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],
       placement=oneflow.placement(type="npu", ranks=[0]), sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.float32)
label
tensor([582, 209, 272, 331, 768, 626, 838, 202, 333, 754, 435, 955, 853, 943,  40, 723,   3, 104,  51,  60, 118,
        762, 603, 353, 898,  69, 552, 824, 999, 217, 713, 334, 758, 818, 115,   1, 609, 238, 147, 446, 240, 455,
        442, 257, 206, 200, 911, 355, 684, 419], placement=oneflow.placement(type="npu", ranks=[0]),
       sbp=(oneflow.sbp.split(dim=0),), dtype=oneflow.int32)
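
For context, 0.0010 is exactly 1/1000, i.e. what a softmax over uniform logits gives for the 1000 ImageNet classes, so graph mode appears to be producing constant logits. A minimal sketch of a check for that, assuming a hypothetical build_resnet50() helper and one fixed batch fixed_image from the pinned dataset (EvalGraph mirrors the pattern used later in this thread):

import numpy as np
import oneflow as flow
import oneflow_npu  # registers the "npu" device


class EvalGraph(flow.nn.Graph):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def build(self, x):
        return self.model(x)


# Hypothetical helpers: any eval-mode ResNet-50 plus one fixed input batch.
model = build_resnet50().to(device="npu").eval()
image = fixed_image.to(device="npu")

eager_pred = model(image).softmax(dim=-1)
graph_pred = EvalGraph(model)(image).softmax(dim=-1)

# If graph mode is broken, graph_pred stays pinned at ~1/1000 per class
# and the max difference against eager is large.
diff = np.abs(eager_pred.cpu().numpy() - graph_pred.cpu().numpy())
print("max |eager - graph| =", diff.max())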

@@ -45,6 +46,8 @@ def make_data_loader(args, mode, is_global=False, synthetic=False):
placement=placement,
sbp=sbp,
use_gpu_decode=args.use_gpu_decode,
device="cpu",

@ShawnXuan (Contributor Author):
Use CPU decoding for now.
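
A minimal sketch of the resulting pattern (loop and variable names are hypothetical, not from this branch): the loader decodes on the host and returns CPU tensors, which are moved to the NPU explicitly each step:

data_loader = make_data_loader(args, "train")   # built with device="cpu"
image, label = data_loader()                    # hypothetical call convention: one CPU-decoded batch
image = image.to("npu")
label = label.to("npu")
logits = model(image)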

@@ -83,5 +83,30 @@ def forward(self, input, label):
# log_prob = input.softmax(dim=-1).log()
# onehot_label = flow.F.cast(onehot_label, log_prob.dtype)
# loss = flow.mul(log_prob * -1, onehot_label).sum(dim=-1).mean()
loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))
#loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))
loss = flow._C.cross_entropy(input, onehot_label.to(dtype=input.dtype), reduction='none')

@ShawnXuan (Contributor Author):

Still need to verify whether training converges.

return loss.mean()

class oldLabelSmoothLoss(flow.nn.Module):

@ShawnXuan (Contributor Author) commented Sep 24, 2024:

This is the loss from flowvision.
It requires dim_gather.

@Flowingsun007 (Contributor) commented Sep 25, 2024:

softmax_cross_entropy and dim_gather should not be hard to implement; we can put them on the NPU development plan and try again once they are done.

@Flowingsun007 (Contributor) commented Sep 25, 2024:

I checked: dim_gather is already supported on NPU (oneflow_npu/kernels/dim_gather_kernel.cpp), so only softmax_cross_entropy still needs to be implemented.

@ShawnXuan (Contributor Author):

I implemented a softmax_cross_entropy, but the backward pass may still have problems. In any case softmax_cross_entropy has no torch counterpart, and I prefer to develop torch-compatible operators, so I went with the flowvision approach instead of softmax_cross_entropy.

I will try dim_gather later.
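
For reference, a minimal sketch of a flowvision-style gather-based label-smoothing loss; the gather on the log-probabilities is the op that maps to the NPU dim_gather kernel. The smoothing value and the class name here are illustrative, not copied from flowvision:

import oneflow as flow


class LabelSmoothLossSketch(flow.nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, input, label):
        # log-probabilities via softmax().log(), as in the commented-out lines above
        logprobs = input.softmax(dim=-1).log()
        # per-sample NLL of the true class: this is the dim_gather op
        index = label.unsqueeze(-1).to(dtype=flow.int64)
        nll_loss = -flow.gather(logprobs, dim=-1, index=index).squeeze(-1)
        # uniform smoothing term over all classes
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()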

@@ -83,5 +83,30 @@ def forward(self, input, label):
# log_prob = input.softmax(dim=-1).log()
# onehot_label = flow.F.cast(onehot_label, log_prob.dtype)
# loss = flow.mul(log_prob * -1, onehot_label).sum(dim=-1).mean()
loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))
#loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))

@ShawnXuan (Contributor Author):

Currently not supported on NPU.

@ShawnXuan (Contributor Author)

Progress:
I traced it to _make_layer in Vision/classification/image/resnet50/models/resnet50.py, which ends up returning nn.Sequential(*layers), so there is no way to compare any deeper than that.

    def _forward_impl(self, x: Tensor) -> Tensor:
        if self.pad_input:
            if self.channel_last:
                # NHWC
                paddings = (0, 1)
            else:
                # NCHW
                paddings = (0, 0, 0, 0, 0, 1)
            x = flow._C.pad(x, pad=paddings, mode="constant", value=0)
        x = self.conv1(x)
        if self.fuse_bn_relu:
            x = self.bn1(x, None)
        else:
            x = self.bn1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        # In graph mode everything up to this point is aligned and produces normal output, so the individual operators are fine
        x = self.layer1(x)
        r = x  # from this point on the returned values are all 0; digging into layer1, it is ultimately an nn.Sequential, which is hard to compare further; we probably need a standalone graph unit test for nn.Sequential (see the sketch after this snippet)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = flow.flatten(x, 1)
        x = self.fc(x)
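
One way to do that unit-level comparison is to wrap each sub-module of the Sequential in its own nn.Graph and check it against eager block by block. A minimal sketch, where layer1 stands for the nn.Sequential returned by _make_layer and x for the already-aligned NPU activation captured right before it (both names are placeholders, not code from this branch):

import numpy as np
import oneflow as flow


class BlockGraph(flow.nn.Graph):
    # wrap a single sub-module so it can be compiled and run in graph mode on its own
    def __init__(self, block):
        super().__init__()
        self.block = block

    def build(self, x):
        return self.block(x)


for idx, block in enumerate(layer1.children()):
    eager_out = block(x)
    graph_out = BlockGraph(block)(x)
    diff = np.abs(eager_out.cpu().numpy() - graph_out.cpu().numpy()).max()
    print(f"layer1[{idx}]: max |eager - graph| = {diff}")
    x = eager_out  # feed the eager output forward so every block gets the same input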

@ShawnXuan (Contributor Author)

Progress as of Sep 27:
The nn.Sequential test passes; the test script is below:

import unittest
import numpy as np
from collections import OrderedDict

import oneflow as flow
import oneflow_npu
from oneflow.test_utils.test_util import GenArgList


def create_model(shape, device, dtype):
    model = flow.nn.Sequential(
        flow.nn.Linear(shape[-1], shape[-1]),
        flow.nn.ReLU(),
        flow.nn.Linear(shape[-1], shape[-1]),
        flow.nn.Softmax(dim=-1)
    ).to(device=device, dtype=dtype)
    return model


class EvalGraph(flow.nn.Graph):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def build(self, x):
        logits = self.model(x)
        return logits


class TrainGraph(flow.nn.Graph):
    def __init__(self, model):
        super().__init__()
        self.model = model
        param_group = {"params": [p for p in model.parameters() if p is not None]}
        optimizer = flow.optim.SGD([param_group], lr=0.1, momentum=0.9)
        self.add_optimizer(optimizer)

    def build(self, x):
        logits = self.model(x)
        loss = logits.sum()
        loss.backward()
        return loss, logits


def _test_sequential_forward(test_case, shape, device, dtype):
    model_cpu = create_model(shape, "cpu", flow.float)
    model_npu = create_model(shape, device, dtype)
    model_npu.load_state_dict(model_cpu.state_dict())

    arr = flow.rand(*shape)
    arr_npu = arr.to(device=device, dtype=dtype)

    out_cpu = model_cpu(arr.to(dtype=flow.float))
    out_npu = model_npu(arr_npu)

    tol = 1e-3 if dtype == flow.float16 else 1e-5

    test_case.assertTrue(np.allclose(out_cpu.numpy(), out_npu.cpu().numpy(), tol, tol))

    graph = EvalGraph(model_npu)
    graph_out = graph(arr_npu)

    test_case.assertTrue(np.allclose(graph_out.cpu().numpy(), out_cpu.numpy(), tol, tol))


def _test_sequential_grad(test_case, shape, device, dtype):
    model_cpu = create_model(shape, "cpu", flow.float)
    model_npu = create_model(shape, device, dtype)
    model_npu.load_state_dict(model_cpu.state_dict())
    graph = TrainGraph(model_npu)

    np_arr = np.random.rand(*shape).astype(np.float32)
    arr = flow.tensor(np_arr, dtype=flow.float, requires_grad=True)
    arr_npu = flow.tensor(np_arr, dtype=dtype, device=device, requires_grad=True)
    print("*"*100)
    print("arr", arr)
    print("arr_npu", arr_npu)

    out_cpu = model_cpu(arr.to(dtype=flow.float))
    out_cpu.sum().backward()

    loss, out_npu = graph(arr_npu)

    print("out_cpu", out_cpu)
    print("out_npu", out_npu)

    for param_cpu, param_npu in zip(model_cpu.parameters(), model_npu.parameters()):
        print("cpu", param_cpu)
        print("npu", param_npu)
    #tol = 1e-3 if dtype == flow.float16 else 1e-5
    #test_case.assertTrue(np.allclose(arr.grad.numpy(), arr_npu.grad.cpu().numpy(), tol, tol))


@flow.unittest.skip_unless_1n1d()
class Test_Sequential_Graph(flow.unittest.TestCase):
    def test_sequential_forward(test_case):
        arg_dict = OrderedDict()
        arg_dict["test_fun"] = [
            _test_sequential_forward,
        ]
        arg_dict["shape"] = [(2, 4), (2, 3, 4)]
        arg_dict["device"] = ["npu"]
        arg_dict["dtype"] = [flow.float32, flow.float16]
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])

    def test_sequential_grad(test_case):
        arg_dict = OrderedDict()
        arg_dict["test_fun"] = [
            _test_sequential_grad,
        ]
        arg_dict["shape"] = [(2, 4), (2, 3, 4)]
        arg_dict["device"] = ["npu"]
        arg_dict["dtype"] = [flow.float32]
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])


if __name__ == "__main__":
    unittest.main()
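
Usage note: the script only uses the standard OneFlow unittest helpers (GenArgList, skip_unless_1n1d), so on a machine with oneflow and oneflow_npu installed it can be run directly as a normal Python unittest file. The forward case asserts the NPU eager and EvalGraph outputs against the CPU reference; the grad case currently just prints outputs and parameters for manual comparison, with the tolerance check still commented out.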
