-
Notifications
You must be signed in to change notification settings - Fork 24
/
test_hybrid_attn.py
232 lines (194 loc) · 6.71 KB
/
test_hybrid_attn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from yunchang import (
AsyncLongContextAttention,
LongContextAttention,
set_seq_parallel_pg,
EXTRACT_FUNC_DICT
)
import torch
import torch.distributed as dist
from flash_attn import flash_attn_func
from yunchang.kernels import FlashAttentionImpl
from test_utils import attention_ref
def log(msg, a, rank0_only=False):
    """Print the max and mean absolute value of tensor ``a``, tagged by rank.

    Args:
        msg: Label printed alongside the statistics.
        a: Tensor whose ``abs()`` max/mean are reported.
        rank0_only: When true, only rank 0 prints a single
            ``[Rank#0] {msg}: max ..., mean ...`` line and no barrier is
            issued. When false, every rank prints in rank order, one rank
            per turn, with a ``dist.barrier()`` after each turn. Rank 0
            prints the ``{msg}:`` header line first.
    """
    n_ranks = dist.get_world_size()
    me = dist.get_rank()

    def _stats():
        # abs() is taken once and reused for both reductions.
        magnitude = a.abs()
        return f"max {magnitude.max().item()}, mean {magnitude.mean().item()}"

    if rank0_only:
        if me == 0:
            print(f"[Rank#0] {msg}: " + _stats(), flush=True)
        return

    # Every rank walks all turns so the barriers line up; only the rank whose
    # turn it is prints, which serializes the output by rank.
    for turn in range(n_ranks):
        if turn == me:
            if me == 0:
                print(f"{msg}:")
            print(f"[Rank#{me}] " + _stats(), flush=True)
        dist.barrier()
# Run this test with:
#   torchrun --nproc_per_node=4 test/test_hybrid_attn.py
if __name__ == "__main__":
    import os

    torch.random.manual_seed(0)
    use_bwd = True  # also run and compare the backward pass

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    assert world_size == 4, "torchrun --nproc_per_node=4 test/test_hybrid_attn.py"

    # Pin this process to its own GPU. torchrun exports LOCAL_RANK; fall back to
    # the global rank when it is absent (single-node launch). Setting the current
    # device keeps NCCL collectives that carry no tensor (e.g. barrier()) on this
    # rank's GPU instead of a guessed one.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Inference mainly uses fp16, but this test runs in bf16. On ROCm, flash
    # attention's bf16 error is slightly larger (noted upstream as to be fixed).
    dtype = torch.bfloat16
    batch_size = 2
    seqlen = 1024
    nheads = 4
    d = 128
    dropout_p = 0
    causal = True
    deterministic = False

    assert seqlen % world_size == 0
    assert d % 8 == 0

    ring_impl_type = "basic"  # e.g. "basic" or "zigzag"; must be a key of EXTRACT_FUNC_DICT

    # Prepare full-sequence inputs. Rank 0's values are broadcast so every rank
    # starts from identical tensors.
    shape = (batch_size, seqlen, nheads, d)
    q, k, v = (
        torch.randn(*shape, device=device, dtype=dtype, requires_grad=True)
        for _ in range(3)
    )
    dout = torch.randn(*shape, device=device, dtype=dtype)
    for full_tensor in (q, k, v, dout):
        dist.broadcast(full_tensor, src=0)

    # Prepare the process groups for hybrid (Ulysses x ring) sequence parallelism.
    sp_ulysses_degree = 2
    sp_ring_degree = world_size // sp_ulysses_degree
    print(
        f"rank {rank}, sp_ulysses_degree: {sp_ulysses_degree}, sp_ring_degree: {sp_ring_degree}"
    )
    set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

    def extract_local(tensor):
        """Return this rank's shard of a full-sequence tensor via EXTRACT_FUNC_DICT."""
        return EXTRACT_FUNC_DICT[ring_impl_type](
            tensor, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
        )

    def extract_local_leaf(tensor):
        """Return a detached copy of this rank's shard as a new grad-requiring leaf."""
        shard = extract_local(tensor).detach().clone()
        shard.requires_grad = True
        return shard

    local_q, local_k, local_v = (extract_local_leaf(t) for t in (q, k, v))

    usp_attn = LongContextAttention(
        ring_impl_type=ring_impl_type, attn_type=FlashAttentionImpl.FA
    )

    if rank == 0:
        print("#" * 30)
        print("# ds-ulysses forward:")
        print("#" * 30)

    # Attention parameters shared by the distributed run and the references.
    window_size = (-1, -1)
    alibi_slopes, attn_bias = None, None
    dropout_mask = None

    print(f"before usp attn forward: {local_q.shape} {local_k.shape} {local_v.shape}")

    # USP attention forward on this rank's shards.
    local_out = usp_attn(
        local_q,
        local_k,
        local_v,
        dropout_p=dropout_p,
        causal=causal,
        window_size=window_size,
        softcap=0.0,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    # This rank's shard of the upstream gradient.
    local_dout = extract_local(dout).detach().clone()

    if rank == 0:
        print("#" * 30)
        print("# ds-ulysses backward:")
        print("#" * 30)

    if use_bwd:
        local_out.backward(local_dout)
    dist.barrier()

    if rank == 0:
        print("#" * 30)
        print("# local forward:")
        print("#" * 30)

    # Reference: non-distributed flash attention over the full sequence.
    out_ref, _, _ = flash_attn_func(
        q,
        k,
        v,
        dropout_p=dropout_p,
        causal=causal,
        window_size=window_size,
        softcap=0.0,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    # Optional second reference: the pure-PyTorch attention_ref. Off by default,
    # matching the comparison that used to be commented out.
    check_pt_ref = False
    if check_pt_ref:
        out_pt_ref, _ = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )

    if rank == 0:
        print("#" * 30)
        print("# local backward:")
        print("#" * 30)

    if use_bwd:
        out_ref.backward(dout)
    dist.barrier()

    # Check correctness by comparing this rank's output with the matching shard
    # of the full-sequence reference.
    local_out_ref = extract_local(out_ref)
    log("local (rank) out", local_out, rank0_only=True)
    log("out (distributed) - out_ref (non-distributed) diff", local_out_ref - local_out)
    torch.testing.assert_close(local_out, local_out_ref, atol=1e-2, rtol=0)

    if check_pt_ref:
        local_out_pt_ref = extract_local(out_pt_ref)
        log("out_ref (non-distributed) - out_pt_ref (gpu) diff", local_out_ref - local_out_pt_ref)
        torch.testing.assert_close(out_ref, out_pt_ref, atol=1e-2, rtol=0)

    if use_bwd:
        # log() contains barriers, so every rank must run the same sequence of calls.
        for grad_name, local_leaf, full_leaf in (
            ("dq", local_q, q),
            ("dk", local_k, k),
            ("dv", local_v, v),
        ):
            local_grad_ref = extract_local(full_leaf.grad)
            log(f"local_{grad_name}", local_leaf.grad)
            log(f"{grad_name} diff", local_grad_ref - local_leaf.grad)

    if dist.is_initialized():
        dist.destroy_process_group()