From 34eb9d424853df35b33a1bc6859be48b3234bd5f Mon Sep 17 00:00:00 2001 From: Dariusz Mroz Date: Fri, 8 Nov 2024 11:56:26 +0100 Subject: [PATCH] test_mutable_cmdlist: add tests for kernel mutation (#102) --- .../kernels/test_mutable_cmdlist.cl | 10 + .../kernels/test_mutable_cmdlist.spv | Bin 9888 -> 11140 bytes .../src/test_mutable_cmdlist.cpp | 662 ++++++++++++++++-- 3 files changed, 612 insertions(+), 60 deletions(-) diff --git a/conformance_tests/core/test_mutable_cmdlist/kernels/test_mutable_cmdlist.cl b/conformance_tests/core/test_mutable_cmdlist/kernels/test_mutable_cmdlist.cl index 18deadf3..97c79e99 100644 --- a/conformance_tests/core/test_mutable_cmdlist/kernels/test_mutable_cmdlist.cl +++ b/conformance_tests/core/test_mutable_cmdlist/kernels/test_mutable_cmdlist.cl @@ -11,11 +11,21 @@ kernel void addValue(global int *inOut, int val) { inOut[gId] += val; } +kernel void subValue(global int *inOut, int val) { + const int gId = get_global_id(0); + inOut[gId] -= val; +} + kernel void mulValue(global int *inOut, int val) { const int gId = get_global_id(0); inOut[gId] *= val; } +kernel void divValue(global int *inOut, int val) { + const int gId = get_global_id(0); + inOut[gId] /= val; +} + kernel void testGlobalSizes(global int *inOut, global int *globalSizes) { int gIdX = get_global_id(0); diff --git a/conformance_tests/core/test_mutable_cmdlist/kernels/test_mutable_cmdlist.spv b/conformance_tests/core/test_mutable_cmdlist/kernels/test_mutable_cmdlist.spv index 59b8605752932764abcddd0f1bf3f58fa5f64302..115ceb18d6459ff6ef0ab0f4ce65d0ec3e0f9b2d 100644 GIT binary patch literal 11140 zcmaKy2Yg;t8OBeW?@Ko=qpUJg3T2gO(iW&#WJxUo6_C~erh%j-G)+QA8H!e|AP5c+ zWG~r6Q2}LIL2xky1Z0CCs0auO0*d$d)a=(k0-^+R4=N;!g^L}5y)Qy}rDpW>< zH9~#Zazx0UHN(hI5u&Zewm!Dw)Q%I2>w=#4uHE-;?CWm}im#J*$Yf#|rEDSmQ_|Yj zwy?Expj|M_W!t}{_YE9UY8#>C%Lh73ZS}IXb*w1Z#$^4CldZqKuYX==_aUvF2Xq|M z-k0@BEfZBW?WE|lclV)oQMDnAmW+{%l}yzC@tr&@--0De`r7+NXE~XNQOdc2e3rHM zcC~l5wDvA-=|8%sz0qTC?C9#>W`_PZBlq2~*0<$|f!5BUfEn`4eW%p=HhR8_c{i%{ z&GQZh4CURp*0<4fUCg^lt#6)pFkmR}rnSC}UZ-N-hFae|?_hxQP6$(LeH#mFJeYo) zY@Wx5&1z%v{KY_ZGb&67(`x;*K8E(lur3rcK5SkaGt@^C$Vl+tqSikdN4~f6UUL8G zwf;l-^9>$@X6WScUkJ~fTG8I*bcWE<($~|`yP{>cfsW4p`CWO${H_(+_Wd2*UGsHh z<(HqEQeTcw<}*(9*W}AaSV*UQW3zKq$hWAwciGb3?tz||GV7PK_FjGzM=s7{!Drur z<%8+^WX^JN77Kpq49TY&CdtR=GLNh1ehZc;3)liAlr@9tV*eLaDWwxfD_yO+ye4eLv9?d@$nT4e517dC`zYs>6Ygnc}7 zdnIfnkk+%Ybl0BV~9FHSXN^)w0lIuOk2X;z!oGgG?jnI)Xov%U1R zo*mGW`Por=TF)Hm#d>xUryA6EeV8j;xIVI)b|%l@9;${9Nnbkb{Zk3Mh#RV{KI|&J zv$dOmo{V*V zv@@ry>t4~$e4_6i?aU|oKGDvcqA!T{*-`&kv@^%p_m6hw82x}~XO7VijCST2ePOgW zX~8*npG05Yli5C%<0jYk>8R&@SHr>5hp$sLv?N@Uis1aUrko0)pF_mu`>PTT&2f`! z;y#qEnLUUvcVB*1D`82Fn>~L^rKfx6FzoTHH>+^QczBG-=d?a7lg|DbdNt&JTqx=O zSuVXZ`D{Y2E^+Dp>BgsI|MV33=IvEWeK?|qpOf*uMSf;;#`o3mv$HXAPp12DAlC=? z;gQnQeaOX^?!#lE-Sha_X!lwl8|~hQpO1F$!+yDi1sEG$h@CudtB3#Y*&HUPqtkJqJP776^Q;#+f^X?DYj>Ov=UCWJ=>p^ zaGLEZ5dYJoolS{;MzpgL(a*F!tFJDcg&xn***R|Z+?^A3&&|2HJ~Q9%m2h6vJ@@BF zJ)hfZxInu1xS;y4h6@vJW;TXuxG3eaanyzHiu1XfRsUldisJd4GIA*Jr!dYoqS{`vd9e zJ@7;9@gA6+tw|+Z7i038M}-@t54}gy?-cdn$Fk+`j-2n4&N}}@ZF>H1lwNlJZz}TH znpDEg;?r|`ORmRrTOV#M#%Y6bZ`@YH^FFGEpAO|~&gZfkR_B~@X3Le;eY?2yKDYxt zc`w}=?ViuOqTTEL%V_s{|0>$^^{xxQL67_Wo*c8(S>N~NI!l@9_oM6co6q(G=&FSm z4P$yxy4OxHi+KniVpIU@@Gv?tDu9?rq&uc5p7}=|qe_T*%rTlMVjfRp;`%&^&+?wZ zchS?i&Z{KL@odz+ub)F#UCmj1H9Rlfbu??l^!oxns+9xI`|L%0;<@^hZJKy||AMZ3 zy8fkC!%O(aeLORc{bk2wW3Pl)@M&wC>6%`(J&yY|baKXVzmC6i#&Q3{F|p=1@X`76 zd3!V3XT>_-l0NjVtP5{T_vgQN3%36xwvi)-$JI%9+^B+WbYdGn%vKE(q!&Nity%C{ zOZxElcYRoQFmC8v&(rtLs%qR?Qu^Jc@a0M=!aS`bG3#cS?|S@OOvt`B%!~iN@t*nO zmoMk3Px64tv#Er-MyV~yQ~I6H`G|2{VDe3o%+q(@(P~Tb4StdH_d@5w-?_l#+Cah> z$Ef{(x%lqrJoq{fm^>RwsA;U)k~}rvFP#g2=K_;!eF-&zOL7hS4(dGkIu96`Dw(HQ zpQM&NIs*&+Z)Q?u%fCAgteN{UulS|kAMLk=fc>UNKfQZv@^QxOr{k67OTTZ4Qop0K z?>+Q2^v+_n`Y>m8@tHc$Q^6d#;lHkJ!@#yUHvHM9sBg-Z0htCks&1VSX zOA>r(f-g((wo~kfeDr0Vno^xRDb=}GQhiK8=ZvRz&M7+IsHl~*%0ultXoS?Lyl{GDi)P{iJ~>OR(`Aq45yU9Rc4f9iR0gx2Iz0i@S&Z;K@%6 zcO*Ri%>7FF!gJ?gqo$q3;R9#CcQvvdB_1313r{+>jifV9@@*_(Y@4X{*m!r8jjcgE zHXe^{s&IqEW7|v^&SRSkm+vdXXuK7!GY+H!KMoo+jA2|CJ$2MIyo^)(mO83~d z5@u|iTaRrU;j*!9D;^t<$2LP4-ecQN7|vsB6sA9TkL`WJ@UD4P92<8uHfmyQ_`un( zI5zHao^)*bNH;h(eU!__HcPF?He0xCY}<>+#^bT+qn!50(AahqhV$6w2-6?D$MykX zc-MS<9NSLfuu&6Z!w1fO#j(wmjfZ=|=WXW%e@K{h+eNL%_F>_&vF#}y8;{2}PZ-{@ z9}zB#oi85S2#I6&5{7r|M}^B`_ZE+hF*|= z2{T{zMc++gU$A{YEzJDb7kzh$eZk&)2Me zc#r8YVa@_P>q)K-3B2odzQd!9c*cbfoc)UDZkcR6>ACBa?lZVtm^JQF>%Fu>c$9>k z4HEC+BZaZ?cn=>X4DZ;Zh09`(5sz(z#Ic_hhIj1egv(-&6_1U29Q%1;c*lN0c(kNg z^Gb2pcpUph;gu4{eo442_RHe2r8TdT4)55n2$#hkCmvf`^YPN*9s5<`G0~?%t>^g! zVQf6E`9xuO$9_$?EcPVv*wUI$mJaXOuM3yOenULAwB~P0hj;9^gvUmo2DPsF6k%*U zuK849c*lNQxGeTG@z~OuPnQnw*zX9J#hxJ^TUzs((&2p%oh8gZT_oXtL+_RhNZ9*e z`<^Y#zO*lTkHo%U`<^4rKD00T5fb}??R%~;YiM8eUWt9d_B~ISHM1{zpTxdk`<^e% z8rc`UUt(Xd@0kmPIqwHaxI54}3(Q?bV!sQ8@%xMfKXmqe(eEO0_$`*;ht3|iA9t+l zp!MDPU14fwjm=#w49}i1cZo3b0O$MfQek|tnY&CFp80j`<-+ic&)gNl@T(H;N?~~F zw%=94tZV9bwRHG2-!;o8JpR5Qg_U{7|@40?#>^ zp!P=+c+au(T^DV{vljTk*{}FsxIs3a>m_~{{8*Ut@DsH@cRv%ZO32e7@p->p7#okz z`yImYj{UiCS?rzSv5k;8_AX&~$NoaNEcTb;u~Cm>eGw~Z?YWcN<^NxO_Y z!siqnA97VB?&EjDWqmv(9$PxU4@-x4ACCyrhtD`VKE*zMFRrYQN5x}H*Woeg@b2Sr zVft7sVNast<36|_Ja@F7=O=|(zb7QjGy1KP)e_bZY~QDaS-+v zS-)o__@e(*vRZ;K*mL^_;j(c*CmtJ*$N9W4<7B;v=bZjg0`IZWpUVC-z&m!gM@xxxsIf4b&ljh`i66fSM5^`45I%idQx?=FdUd+kgX7Z#t`F+JX`E7-q zW7In5Sm7qc;D^1KlfUWY`Jlw}$?r18{3b)rHPkxiMB%ye#}9ikCw~LVljh`i9OvXW t9CEI$);ZS^UZxoQuorXkx1v1POPq(_eaN$(TK?jKwO(KN2DR9n<9{O`+z$W% literal 9888 zcmaKx37nnN8OBe{y^~$+)Y49cBZOFEHzsD1j7%nzSwfUb?fVuxCHAG1 z+7(6Zl%f=+))p;Vv=ptj&;R?*GjsCI)i1v{=XsxZIqx~=J>R`Isaj%FRE1n=`GcZTHs0Te`AXa*ojKjGsiG$#V|2i)IaB zkYuoAh-8HR#`kkcU%n~Rr+2k;g?;RsBMD?B_z&pspR6N4t9+K+ ze~JG7efjf-i9w^)k^Bkat%W6>HC-D*V`EoGYv;Vit$SMAx+k^g5tG{Isg1f@=d@2! zE6XqM!-~FaZRRsn_4mt{t+0?z`G#bDS;)87oX%M@I_LCs#FW{*^zOI$Q5?DG?SjwF zJ+ph$b;+FNqPGiv=?=-K9F~%g`#g^;>E4IX+CHU+OH$SjE~9Z?Q=4!pEGyj7-rYGr zvz5Z~8g&EPG-UP_(C4(zv%aE0Yx7|popWZ(UJffsZ|dx9nlCcY5Amh24taX_ zs2tXnK4Y=>RVj=T*H>G0SWkLeQ+so5+sv-o?#}F`$Q-%~i}Q73vpLom-#WdvduB^} zrq||rOMBJ_rLckYbggyh$y&!mJ3KXyjrRJeZy4?PVc#g)@k1XU?c<`pNwgD(eY0pM z4t?`z-#F@9Mmu@1Pe50Fbt;^`+6H}gM`qhL$91l4yQt@LmqUZ}#m}i6CM8^*ilD!C zOgR<8Jd?%c=UECn<+#o@?VRiN>?OWaH0Hfn3cKXE-ut(!^t5+&!ybFNUWK#9Jz`A0 zr`2Ik>GV(Et1tJPO0J}T_LAO~e3p@GZ*gh=?1N8*{@J(4*Kep=s>6Q$`0*LPe~}+k zpYaFu|Prp7bH_RP5eI-k;cg)-$5reK<4P-G_%nyZdlf zw7U;yN4xv5J=)!e9ntPHpBwFYAC^K_wAZOX_PyKoc&77gSAp0MwOs|G&$nF#q91O% z3PeA`_N+%s;Yi!F{w#%~Y*&H!FNk)U68-3CrxDSQu|2D=DjbU*@6mBNuJ_&@A9e4| z3AsKd@ApzTG3ws?lcJvQZ8@AQ-90X-`OD#ygd3Bsp&U+4xojO(;WTl+ccUmF8g)@oIJ_AbOtQ?PZf1DkCUE{*2dyVHr-TiTH)ZHKFMcsRFe$?F`7f9FKW8EJY z=K46-dQsH#T1(+#>FImm672CkFfR7Yr774Tw z8q@oKrSwYo|EeM%`|4`(>Ak%s*W%wXQLoN?W+f6D}qC6T-YR?EjQ7cYj5B5l1e@ z8tKF_7y7cY42%4d8p*+0M+%M5hz;J{S_ucgoZ`3EI8s?JOKh9QdC13H)~PAgd7q>@ z?}StzT+q3nshvB9&aWD3<(~0S`$k$JHR@_*m^CaT`BF76t8sJf5PPtkggV)y8p*EG zz{^Xp@yykFi06L8cS^@+#mG%m40G|`U_N;A6T|xr9)I?Jfqdb4Uty!BEyUpiXTNW> zvW+EP8}A05bZx6jXPx9*O~Tr~tI=!YycVhd+Zw{m2k*6g zUl`ss9~;-UrZ{ZW#M~!Huu`|SDqaMf36oz+f ztMDL6vF1a>VdHV^EMa)Zwh32?oh=?)T64Q}c*o8Wt`yrL9$Q-TT@oCmvf`^P$q=9ebE?rP%r6v86R1E*;*nKNKDkeQGqi<|Bl$ z@wnzAh2b6hBjHN1M~TOl*1SMEykmbXTq*Ww@z~OukC6`V_t3Gz(f5JXHeE zJs7U>ClYw?vGZLTZNzgH_`un(__?r1HlE8Q{#>|RnEP;rM&G-i3YR71sgd};-zbcY z$M^jvVR*;>Ot@0)&El~QkT~`hVR*;>T)0x~t>Up!k7I8WhWDMgLzr4_mvG_@ZAeu`k&5|3Y}Eq_}VQio?d^`tK8l_qpFM z%vt!JqT@rZlEicTQn=C_4~WN>?(c)r;XTJg!pz}2j{cyeILEKVRhr{r@!0V9Ivx>* z_Z*K3Gl$;~==gXJ-Vfe8M(^|E!kph@680JWTFDI(&M(LLch@Jyaehxo@I}8)V&5F+ zzE6wi{GO8Fi+;VtzF_a|uZ1hE`x)`rc)ZSMg<0oB3Gu@<{zd}twfc@f7j4{IY|l&J z$m@N4A@bZ~a-chp>-(*^O7*=c9@_wk>w8I<`Z#OX_d8*D=XHHAi^JynUJ-^Puj_kN z7@qpbf$lu6@Au*=)%TisZ1}sr*M;3DuI~@x;GNg?y&(>p>w8lej=Zk#En#@-BL}+k z&~rX-#^QVNZE-v!Bz`Zx6EWxdu7n=^qsH+%>tP!IBw0#=ANFF-_r&p}Ip3E~esX>w zA?KepI_F=6SCQa{y_oYuaXf2DT<1p#{#cltf7R%me-mC?f*( + lzt::allocate_host_memory(bufferSize * sizeof(int32_t))); + for (size_t i = 0; i < bufferSize; i++) { + inOutBuffer[i] = initBufferVal; + } + uint32_t groupSizeX = 0; + uint32_t groupSizeY = 0; + uint32_t groupSizeZ = 0; + + ze_kernel_handle_t addKernel = lzt::create_function(module, "addValue"); + ze_kernel_handle_t mulKernel = lzt::create_function(module, "mulValue"); + std::vector kernels{addKernel, mulKernel}; + + lzt::suggest_group_size(addKernel, bufferSize, 1, 1, groupSizeX, groupSizeY, + groupSizeZ); + lzt::set_group_size(addKernel, groupSizeX, groupSizeY, groupSizeZ); + lzt::set_argument_value(addKernel, 0, sizeof(void *), &inOutBuffer); + lzt::set_argument_value(addKernel, 1, sizeof(scalarVal), &scalarVal); + + uint64_t mutableKernelCommandId = 0; + commandIdDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE | + ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION; + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListGetNextCommandIdWithKernelsExp( + mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(), + &mutableKernelCommandId)); + + ze_group_count_t groupCount{bufferSize / groupSizeX, 1, 1}; + lzt::append_launch_function(mutableCmdList, addKernel, &groupCount, nullptr, + 0, nullptr); + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + + for (size_t i = 0; i < bufferSize; i++) { + EXPECT_EQ(inOutBuffer[i], initBufferVal + scalarVal); + } + + // Mutate kernel + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandKernelsExp( + mutableCmdList, 1, &mutableKernelCommandId, &mulKernel)); + + // Mutate all invalidated data + ze_mutable_group_count_exp_desc_t mutateGroupCount{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC}; + mutateGroupCount.commandId = mutableKernelCommandId; + mutateGroupCount.pGroupCount = &groupCount; + ze_mutable_group_size_exp_desc_t mutateGroupSize{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; + mutateGroupSize.commandId = mutableKernelCommandId; + mutateGroupSize.groupSizeX = groupSizeX; + mutateGroupSize.groupSizeY = groupSizeY; + mutateGroupSize.groupSizeZ = groupSizeZ; + mutateGroupSize.pNext = &mutateGroupCount; + ze_mutable_kernel_argument_exp_desc_t mutateBufferKernelArg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateBufferKernelArg.commandId = mutableKernelCommandId; + mutateBufferKernelArg.argIndex = 0; + mutateBufferKernelArg.argSize = sizeof(void *); + mutateBufferKernelArg.pArgValue = &inOutBuffer; + mutateBufferKernelArg.pNext = &mutateGroupSize; + ze_mutable_kernel_argument_exp_desc_t mutateScalarKernelArg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateScalarKernelArg.commandId = mutableKernelCommandId; + mutateScalarKernelArg.argIndex = 1; + mutateScalarKernelArg.argSize = sizeof(scalarVal); + mutateScalarKernelArg.pArgValue = &scalarVal; + mutateScalarKernelArg.pNext = &mutateBufferKernelArg; + mutableCmdDesc.pNext = &mutateScalarKernelArg; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListUpdateMutableCommandsExp( + mutableCmdList, &mutableCmdDesc)); + + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + + for (size_t i = 0; i < bufferSize; i++) { + EXPECT_EQ(inOutBuffer[i], (initBufferVal + scalarVal) * scalarVal); + } + lzt::free_memory(inOutBuffer); + lzt::destroy_function(addKernel); + lzt::destroy_function(mulKernel); +} + +TEST_F( + zeMutableCommandListTests, + GivenMutationOfKernelInstructionWithSignalEventWhenCommandListIsClosedThenKernelIsReplacedAndEventRemainsUnchanged) { + if (!CheckExtensionSupport(ZE_MUTABLE_COMMAND_LIST_EXP_VERSION_1_1) || + !kernelInstructionSupport) { + GTEST_SKIP() << "ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION not " + "supported"; + } + const int32_t bufferSize = 16384; + const int32_t initBufferVal = 1111; + const int32_t scalarVal = 3333; + int32_t *inOutBuffer = reinterpret_cast( + lzt::allocate_host_memory(bufferSize * sizeof(int32_t))); + for (size_t i = 0; i < bufferSize; i++) { + inOutBuffer[i] = initBufferVal; + } + + lzt::zeEventPool eventPool; + const uint32_t eventsNumber = 1; + std::vector events(eventsNumber, nullptr); + eventPool.InitEventPool(context, eventsNumber, + ZE_EVENT_POOL_FLAG_HOST_VISIBLE); + eventPool.create_events(events, eventsNumber); + + uint32_t groupSizeX = 0; + uint32_t groupSizeY = 0; + uint32_t groupSizeZ = 0; + + ze_kernel_handle_t addKernel = lzt::create_function(module, "addValue"); + ze_kernel_handle_t mulKernel = lzt::create_function(module, "mulValue"); + std::vector kernels{addKernel, mulKernel}; + + lzt::suggest_group_size(addKernel, bufferSize, 1, 1, groupSizeX, groupSizeY, + groupSizeZ); + lzt::set_group_size(addKernel, groupSizeX, groupSizeY, groupSizeZ); + lzt::set_argument_value(addKernel, 0, sizeof(void *), &inOutBuffer); + lzt::set_argument_value(addKernel, 1, sizeof(scalarVal), &scalarVal); + + uint64_t mutableKernelCommandId = 0; + commandIdDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE | + ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION; + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListGetNextCommandIdWithKernelsExp( + mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(), + &mutableKernelCommandId)); + + ze_group_count_t groupCount{bufferSize / groupSizeX, 1, 1}; + lzt::append_launch_function(mutableCmdList, addKernel, &groupCount, events[0], + 0, nullptr); + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[0])); + for (size_t i = 0; i < bufferSize; i++) { + EXPECT_EQ(inOutBuffer[i], initBufferVal + scalarVal); + } + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventHostReset(events[0])); + EXPECT_EQ(ZE_RESULT_NOT_READY, zeEventQueryStatus(events[0])); + + // Mutate kernel + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandKernelsExp( + mutableCmdList, 1, &mutableKernelCommandId, &mulKernel)); + + // Mutate all invalidated data + ze_mutable_group_count_exp_desc_t mutateGroupCount{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC}; + mutateGroupCount.commandId = mutableKernelCommandId; + mutateGroupCount.pGroupCount = &groupCount; + ze_mutable_group_size_exp_desc_t mutateGroupSize{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; + mutateGroupSize.commandId = mutableKernelCommandId; + mutateGroupSize.groupSizeX = groupSizeX; + mutateGroupSize.groupSizeY = groupSizeY; + mutateGroupSize.groupSizeZ = groupSizeZ; + mutateGroupSize.pNext = &mutateGroupCount; + ze_mutable_kernel_argument_exp_desc_t mutateBufferKernelArg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateBufferKernelArg.commandId = mutableKernelCommandId; + mutateBufferKernelArg.argIndex = 0; + mutateBufferKernelArg.argSize = sizeof(void *); + mutateBufferKernelArg.pArgValue = &inOutBuffer; + mutateBufferKernelArg.pNext = &mutateGroupSize; + ze_mutable_kernel_argument_exp_desc_t mutateScalarKernelArg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateScalarKernelArg.commandId = mutableKernelCommandId; + mutateScalarKernelArg.argIndex = 1; + mutateScalarKernelArg.argSize = sizeof(scalarVal); + mutateScalarKernelArg.pArgValue = &scalarVal; + mutateScalarKernelArg.pNext = &mutateBufferKernelArg; + mutableCmdDesc.pNext = &mutateScalarKernelArg; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListUpdateMutableCommandsExp( + mutableCmdList, &mutableCmdDesc)); + + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[0])); + for (size_t i = 0; i < bufferSize; i++) { + EXPECT_EQ(inOutBuffer[i], (initBufferVal + scalarVal) * scalarVal); + } + eventPool.destroy_events(events); + lzt::free_memory(inOutBuffer); + lzt::destroy_function(addKernel); + lzt::destroy_function(mulKernel); +} + +TEST_F( + zeMutableCommandListTests, + GivenMutationOfKernelInstructionWithWaitListWhenCommandListIsClosedThenKernelIsReplacedAndWaitListRemainsUnchanged) { + if (!CheckExtensionSupport(ZE_MUTABLE_COMMAND_LIST_EXP_VERSION_1_1) || + !kernelInstructionSupport) { + GTEST_SKIP() << "ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION not " + "supported"; + } + const int32_t bufferSize = 16384; + const int32_t initBufferVal = 123; + const int32_t addVal = 456; + const int32_t mulVal = 789; + const int32_t subVal = 321; + + int32_t *inOutBuffer = reinterpret_cast( + lzt::allocate_host_memory(bufferSize * sizeof(int32_t))); + for (size_t i = 0; i < bufferSize; i++) { + inOutBuffer[i] = initBufferVal; + } + + lzt::zeEventPool eventPool; + const uint32_t eventsNumber = 2; + std::vector events(eventsNumber, nullptr); + eventPool.InitEventPool(context, eventsNumber, + ZE_EVENT_POOL_FLAG_HOST_VISIBLE); + eventPool.create_events(events, eventsNumber); + + uint32_t groupSizeX = 0; + uint32_t groupSizeY = 0; + uint32_t groupSizeZ = 0; + + ze_kernel_handle_t addKernel = lzt::create_function(module, "addValue"); + ze_kernel_handle_t mulKernel = lzt::create_function(module, "mulValue"); + ze_kernel_handle_t subKernel = lzt::create_function(module, "subValue"); + std::vector kernels{addKernel, mulKernel, subKernel}; + + lzt::suggest_group_size(addKernel, bufferSize, 1, 1, groupSizeX, groupSizeY, + groupSizeZ); + lzt::set_group_size(addKernel, groupSizeX, groupSizeY, groupSizeZ); + lzt::set_argument_value(addKernel, 0, sizeof(void *), &inOutBuffer); + lzt::set_argument_value(addKernel, 1, sizeof(addVal), &addVal); + ze_group_count_t groupCount{bufferSize / groupSizeX, 1, 1}; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventHostSignal(events[0])); + lzt::append_launch_function(mutableCmdList, addKernel, &groupCount, events[1], + 0, nullptr); + + lzt::set_group_size(mulKernel, groupSizeX, groupSizeY, groupSizeZ); + lzt::set_argument_value(mulKernel, 0, sizeof(void *), &inOutBuffer); + lzt::set_argument_value(mulKernel, 1, sizeof(mulVal), &mulVal); + + uint64_t mutableKernelCommandId = 0; + commandIdDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE | + ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION; + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListGetNextCommandIdWithKernelsExp( + mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(), + &mutableKernelCommandId)); + + lzt::append_launch_function(mutableCmdList, mulKernel, &groupCount, nullptr, + 2, events.data()); + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[0])); + const uint32_t firstResult = (initBufferVal + addVal) * mulVal; + for (size_t i = 0; i < bufferSize; i++) { + EXPECT_EQ(inOutBuffer[i], firstResult); + } + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventHostReset(events[1])); + EXPECT_EQ(ZE_RESULT_NOT_READY, zeEventQueryStatus(events[1])); + + // Mutate kernel + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandKernelsExp( + mutableCmdList, 1, &mutableKernelCommandId, &subKernel)); + + // Mutate all invalidated data + ze_mutable_group_count_exp_desc_t mutateGroupCount{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC}; + mutateGroupCount.commandId = mutableKernelCommandId; + mutateGroupCount.pGroupCount = &groupCount; + ze_mutable_group_size_exp_desc_t mutateGroupSize{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; + mutateGroupSize.commandId = mutableKernelCommandId; + mutateGroupSize.groupSizeX = groupSizeX; + mutateGroupSize.groupSizeY = groupSizeY; + mutateGroupSize.groupSizeZ = groupSizeZ; + mutateGroupSize.pNext = &mutateGroupCount; + ze_mutable_kernel_argument_exp_desc_t mutateBufferKernelArg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateBufferKernelArg.commandId = mutableKernelCommandId; + mutateBufferKernelArg.argIndex = 0; + mutateBufferKernelArg.argSize = sizeof(void *); + mutateBufferKernelArg.pArgValue = &inOutBuffer; + mutateBufferKernelArg.pNext = &mutateGroupSize; + ze_mutable_kernel_argument_exp_desc_t mutateScalarKernelArg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateScalarKernelArg.commandId = mutableKernelCommandId; + mutateScalarKernelArg.argIndex = 1; + mutateScalarKernelArg.argSize = sizeof(subVal); + mutateScalarKernelArg.pArgValue = &subVal; + mutateScalarKernelArg.pNext = &mutateBufferKernelArg; + mutableCmdDesc.pNext = &mutateScalarKernelArg; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListUpdateMutableCommandsExp( + mutableCmdList, &mutableCmdDesc)); + + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[1])); + const uint32_t secondResult = firstResult + addVal - subVal; + for (size_t i = 0; i < bufferSize; i++) { + EXPECT_EQ(inOutBuffer[i], secondResult); + } + eventPool.destroy_events(events); + lzt::free_memory(inOutBuffer); + lzt::destroy_function(addKernel); + lzt::destroy_function(mulKernel); + lzt::destroy_function(subKernel); +} + +TEST_F( + zeMutableCommandListTests, + GivenMutationOfMultipleKernelInstructionsAndEventsWhenCommandListIsClosedThenEverythingIsUpdatedCorrectly) { + if (!CheckExtensionSupport(ZE_MUTABLE_COMMAND_LIST_EXP_VERSION_1_1) || + !signalEventSupport || !waitEventsSupport || !kernelInstructionSupport) { + GTEST_SKIP() << "ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION not " + "supported"; + } + const int32_t bufferSize = 16384; + const int32_t initBufferVal = 100; + const int32_t addVal = 20; + const int32_t mulVal = 30; + const int32_t subVal = 40; + const int32_t divVal = 4; + + lzt::zeEventPool eventPool; + const uint32_t eventsNumber = 4; + std::vector events(eventsNumber, nullptr); + eventPool.InitEventPool(context, eventsNumber, + ZE_EVENT_POOL_FLAG_HOST_VISIBLE); + eventPool.create_events(events, eventsNumber); + + int32_t *inOutBuffer1 = reinterpret_cast( + lzt::allocate_host_memory(bufferSize * sizeof(int32_t))); + int32_t *inOutBuffer2 = reinterpret_cast( + lzt::allocate_host_memory(bufferSize * sizeof(int32_t))); + for (size_t i = 0; i < bufferSize; i++) { + inOutBuffer1[i] = initBufferVal; + inOutBuffer2[i] = initBufferVal; + } + + uint32_t groupSizeX = 0; + uint32_t groupSizeY = 0; + uint32_t groupSizeZ = 0; + + ze_kernel_handle_t addKernel = lzt::create_function(module, "addValue"); + ze_kernel_handle_t mulKernel = lzt::create_function(module, "mulValue"); + ze_kernel_handle_t subKernel = lzt::create_function(module, "subValue"); + ze_kernel_handle_t divKernel = lzt::create_function(module, "divValue"); + + uint64_t kernelCommandId1 = 0; + uint64_t kernelCommandId2 = 0; + uint64_t kernelCommandId3 = 0; + std::vector kernels{addKernel, mulKernel, subKernel, + divKernel}; + commandIdDesc.flags = ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT | + ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE | + ZE_MUTABLE_COMMAND_EXP_FLAG_SIGNAL_EVENT | + ZE_MUTABLE_COMMAND_EXP_FLAG_WAIT_EVENTS | + ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION; + + lzt::suggest_group_size(addKernel, bufferSize, 1, 1, groupSizeX, groupSizeY, + groupSizeZ); + ze_group_count_t groupCount{bufferSize / groupSizeX, 1, 1}; + + // 1 addKernel + lzt::set_group_size(addKernel, groupSizeX, groupSizeY, groupSizeZ); + lzt::set_argument_value(addKernel, 0, sizeof(void *), &inOutBuffer1); + lzt::set_argument_value(addKernel, 1, sizeof(addVal), &addVal); + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListGetNextCommandIdWithKernelsExp( + mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(), + &kernelCommandId1)); + lzt::append_launch_function(mutableCmdList, addKernel, &groupCount, events[0], + 0, nullptr); + // 2 mulKernel + lzt::set_group_size(mulKernel, groupSizeX, groupSizeY, groupSizeZ); + lzt::set_argument_value(mulKernel, 0, sizeof(void *), &inOutBuffer1); + lzt::set_argument_value(mulKernel, 1, sizeof(mulVal), &mulVal); + + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListGetNextCommandIdWithKernelsExp( + mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(), + &kernelCommandId2)); + lzt::append_launch_function(mutableCmdList, mulKernel, &groupCount, events[1], + 1, &events[0]); + // 3 subKernel + lzt::set_group_size(subKernel, groupSizeX, groupSizeY, groupSizeZ); + lzt::set_argument_value(subKernel, 0, sizeof(void *), &inOutBuffer1); + lzt::set_argument_value(subKernel, 1, sizeof(subVal), &subVal); + + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListGetNextCommandIdWithKernelsExp( + mutableCmdList, &commandIdDesc, kernels.size(), kernels.data(), + &kernelCommandId3)); + lzt::append_launch_function(mutableCmdList, subKernel, &groupCount, nullptr, + 2, &events[0]); + + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[0])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[1])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventHostReset(events[0])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventHostReset(events[1])); + const uint32_t firstResult = ((initBufferVal + addVal) * mulVal) - subVal; + for (size_t i = 0; i < bufferSize; i++) { + EXPECT_EQ(inOutBuffer1[i], firstResult); + } + + // Update events + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandSignalEventExp( + mutableCmdList, kernelCommandId1, events[2])); + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandSignalEventExp( + mutableCmdList, kernelCommandId2, events[3])); + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandWaitEventsExp( + mutableCmdList, kernelCommandId2, 1, &events[2])); + EXPECT_EQ(ZE_RESULT_SUCCESS, + zeCommandListUpdateMutableCommandWaitEventsExp( + mutableCmdList, kernelCommandId3, 2, &events[2])); + + // Change kernels sequence from add, mul, sub to mul, sub, div + std::vector commandIds{kernelCommandId1, kernelCommandId2, + kernelCommandId3}; + std::vector newSequenceOfKenrels{mulKernel, subKernel, + divKernel}; + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListUpdateMutableCommandKernelsExp( + mutableCmdList, 3, commandIds.data(), + newSequenceOfKenrels.data())); + + // Mutate invalidated data for kernel 1 + ze_mutable_group_count_exp_desc_t mutateGroupCount{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC}; + mutateGroupCount.commandId = kernelCommandId1; + mutateGroupCount.pGroupCount = &groupCount; + ze_mutable_group_size_exp_desc_t mutateGroupSize{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; + mutateGroupSize.commandId = kernelCommandId1; + mutateGroupSize.groupSizeX = groupSizeX; + mutateGroupSize.groupSizeY = groupSizeY; + mutateGroupSize.groupSizeZ = groupSizeZ; + mutateGroupSize.pNext = &mutateGroupCount; + ze_mutable_kernel_argument_exp_desc_t mutateBufferKernelArg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateBufferKernelArg.commandId = kernelCommandId1; + mutateBufferKernelArg.argIndex = 0; + mutateBufferKernelArg.argSize = sizeof(void *); + mutateBufferKernelArg.pArgValue = &inOutBuffer2; + mutateBufferKernelArg.pNext = &mutateGroupSize; + ze_mutable_kernel_argument_exp_desc_t mutateScalarKernelArg{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateScalarKernelArg.commandId = kernelCommandId1; + mutateScalarKernelArg.argIndex = 1; + mutateScalarKernelArg.argSize = sizeof(mulVal); + mutateScalarKernelArg.pArgValue = &mulVal; + mutateScalarKernelArg.pNext = &mutateBufferKernelArg; + + // Mutate invalidated data for kernel 2 + ze_mutable_group_count_exp_desc_t mutateGroupCount2{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC}; + mutateGroupCount2.commandId = kernelCommandId2; + mutateGroupCount2.pGroupCount = &groupCount; + mutateGroupCount2.pNext = &mutateScalarKernelArg; + ze_mutable_group_size_exp_desc_t mutateGroupSize2{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; + mutateGroupSize2.commandId = kernelCommandId2; + mutateGroupSize2.groupSizeX = groupSizeX; + mutateGroupSize2.groupSizeY = groupSizeY; + mutateGroupSize2.groupSizeZ = groupSizeZ; + mutateGroupSize2.pNext = &mutateGroupCount2; + ze_mutable_kernel_argument_exp_desc_t mutateBufferKernelArg2{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateBufferKernelArg2.commandId = kernelCommandId2; + mutateBufferKernelArg2.argIndex = 0; + mutateBufferKernelArg2.argSize = sizeof(void *); + mutateBufferKernelArg2.pArgValue = &inOutBuffer2; + mutateBufferKernelArg2.pNext = &mutateGroupSize2; + ze_mutable_kernel_argument_exp_desc_t mutateScalarKernelArg2{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateScalarKernelArg2.commandId = kernelCommandId2; + mutateScalarKernelArg2.argIndex = 1; + mutateScalarKernelArg2.argSize = sizeof(subVal); + mutateScalarKernelArg2.pArgValue = &subVal; + mutateScalarKernelArg2.pNext = &mutateBufferKernelArg2; + + // Mutate invalidated data for kernel 3 + ze_mutable_group_count_exp_desc_t mutateGroupCount3{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_COUNT_EXP_DESC}; + mutateGroupCount3.commandId = kernelCommandId3; + mutateGroupCount3.pGroupCount = &groupCount; + mutateGroupCount3.pNext = &mutateScalarKernelArg2; + ze_mutable_group_size_exp_desc_t mutateGroupSize3{ + ZE_STRUCTURE_TYPE_MUTABLE_GROUP_SIZE_EXP_DESC}; + mutateGroupSize3.commandId = kernelCommandId3; + mutateGroupSize3.groupSizeX = groupSizeX; + mutateGroupSize3.groupSizeY = groupSizeY; + mutateGroupSize3.groupSizeZ = groupSizeZ; + mutateGroupSize3.pNext = &mutateGroupCount3; + ze_mutable_kernel_argument_exp_desc_t mutateBufferKernelArg3{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateBufferKernelArg3.commandId = kernelCommandId3; + mutateBufferKernelArg3.argIndex = 0; + mutateBufferKernelArg3.argSize = sizeof(void *); + mutateBufferKernelArg3.pArgValue = &inOutBuffer2; + mutateBufferKernelArg3.pNext = &mutateGroupSize3; + ze_mutable_kernel_argument_exp_desc_t mutateScalarKernelArg3{ + ZE_STRUCTURE_TYPE_MUTABLE_KERNEL_ARGUMENT_EXP_DESC}; + mutateScalarKernelArg3.commandId = kernelCommandId3; + mutateScalarKernelArg3.argIndex = 1; + mutateScalarKernelArg3.argSize = sizeof(divVal); + mutateScalarKernelArg3.pArgValue = &divVal; + mutateScalarKernelArg3.pNext = &mutateBufferKernelArg3; + + mutableCmdDesc.pNext = &mutateScalarKernelArg3; + EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListUpdateMutableCommandsExp( + mutableCmdList, &mutableCmdDesc)); + + lzt::close_command_list(mutableCmdList); + lzt::execute_command_lists(queue, 1, &mutableCmdList, nullptr); + lzt::synchronize(queue, std::numeric_limits::max()); + + EXPECT_EQ(ZE_RESULT_NOT_READY, zeEventQueryStatus(events[0])); + EXPECT_EQ(ZE_RESULT_NOT_READY, zeEventQueryStatus(events[1])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[2])); + EXPECT_EQ(ZE_RESULT_SUCCESS, zeEventQueryStatus(events[3])); + const uint32_t secondResult = ((initBufferVal * mulVal) - subVal) / divVal; + for (size_t i = 0; i < bufferSize; i++) { + EXPECT_EQ(inOutBuffer2[i], secondResult); + } + + eventPool.destroy_events(events); + lzt::free_memory(inOutBuffer1); + lzt::free_memory(inOutBuffer2); + lzt::destroy_function(addKernel); + lzt::destroy_function(mulKernel); + lzt::destroy_function(subKernel); + lzt::destroy_function(divKernel); +} + class zeMutableCommandListTestsEvents : public zeMutableCommandListTests, public ::testing::WithParamInterface {