diff --git a/packages/frontend/src/App.svelte b/packages/frontend/src/App.svelte index e53167bb6..654e6a7e9 100644 --- a/packages/frontend/src/App.svelte +++ b/packages/frontend/src/App.svelte @@ -13,6 +13,7 @@ import Recipe from '/@/pages/Recipe.svelte'; import Model from './pages/Model.svelte'; import { onMount } from 'svelte'; import { getRouterState } from '/@/utils/client'; +import CreateService from '/@/pages/CreateService.svelte'; import Services from '/@/pages/InferenceServers.svelte'; import ServiceDetails from '/@/pages/InferenceServerDetails.svelte'; @@ -60,6 +61,7 @@ onMount(() => { + @@ -69,7 +71,11 @@ onMount(() => { - + {#if meta.params.id === 'create'} + + {:else} + + {/if} diff --git a/packages/frontend/src/lib/table/model/ModelColumnAction.spec.ts b/packages/frontend/src/lib/table/model/ModelColumnAction.spec.ts index b98d239f1..ec3fcd769 100644 --- a/packages/frontend/src/lib/table/model/ModelColumnAction.spec.ts +++ b/packages/frontend/src/lib/table/model/ModelColumnAction.spec.ts @@ -21,6 +21,7 @@ import { test, expect, vi, beforeEach } from 'vitest'; import { fireEvent, render, screen, waitFor } from '@testing-library/svelte'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; import ModelColumnActions from '/@/lib/table/model/ModelColumnActions.svelte'; +import { router } from 'tinro'; const mocks = vi.hoisted(() => ({ requestRemoveLocalModel: vi.fn(), @@ -71,6 +72,9 @@ test('Expect folder and delete button in document', async () => { const deleteBtn = screen.getByTitle('Delete Model'); expect(deleteBtn).toBeInTheDocument(); + const rocketBtn = screen.getByTitle('Create Model Service'); + expect(rocketBtn).toBeInTheDocument(); + const downloadBtn = screen.queryByTitle('Download Model'); expect(downloadBtn).toBeNull(); }); @@ -94,6 +98,9 @@ test('Expect download button in document', async () => { const deleteBtn = screen.queryByTitle('Delete Model'); expect(deleteBtn).toBeNull(); + const rocketBtn = screen.queryByTitle('Create Model Service'); + expect(rocketBtn).toBeNull(); + const downloadBtn = screen.getByTitle('Download Model'); expect(downloadBtn).toBeInTheDocument(); }); @@ -119,3 +126,33 @@ test('Expect downloadModel to be call on click', async () => { expect(mocks.downloadModel).toHaveBeenCalledWith('my-model'); }); }); + +test('Expect router to be called when rocket icon clicked', async () => { + const gotoMock = vi.spyOn(router, 'goto'); + const replaceMock = vi.spyOn(router.location.query, 'replace'); + + const object: ModelInfo = { + id: 'my-model', + description: '', + hw: '', + license: '', + name: '', + registry: '', + url: '', + file: { + file: 'file', + creation: new Date(), + size: 1000, + path: 'path', + }, + }; + render(ModelColumnActions, { object }); + + const rocketBtn = screen.getByTitle('Create Model Service'); + + await fireEvent.click(rocketBtn); + await waitFor(() => { + expect(gotoMock).toHaveBeenCalledWith('/service/create'); + expect(replaceMock).toHaveBeenCalledWith({ 'model-id': 'my-model' }); + }); +}); diff --git a/packages/frontend/src/lib/table/model/ModelColumnActions.svelte b/packages/frontend/src/lib/table/model/ModelColumnActions.svelte index 8c98723b1..be46c46c0 100644 --- a/packages/frontend/src/lib/table/model/ModelColumnActions.svelte +++ b/packages/frontend/src/lib/table/model/ModelColumnActions.svelte @@ -1,9 +1,10 @@ {#if object.file !== undefined} + { + return { + modelsInfoSubscribeMock: vi.fn(), + modelsInfoQueriesMock: { + subscribe: (f: (msg: any) => void) => { + f(mocks.modelsInfoSubscribeMock()); + return () => {}; + }, + }, + }; +}); + +vi.mock('../stores/modelsInfo', async () => { + return { + modelsInfo: mocks.modelsInfoQueriesMock, + }; +}); + +vi.mock('../utils/client', async () => ({ + studioClient: { + createInferenceServer: vi.fn(), + getHostFreePort: vi.fn(), + }, +})); + +beforeEach(() => { + vi.resetAllMocks(); + mocks.modelsInfoSubscribeMock.mockReturnValue([]); + vi.mocked(studioClient.createInferenceServer).mockResolvedValue(undefined); + vi.mocked(studioClient.getHostFreePort).mockResolvedValue(8888); +}); + +test('create button should be disabled when no model id provided', async () => { + render(CreateService); + + await vi.waitFor(() => { + const createBtn = screen.getByTitle('Create service'); + expect(createBtn).toBeDefined(); + expect(createBtn.attributes.getNamedItem('disabled')).toBeTruthy(); + }); +}); + +test('expect error message to be displayed when no model locally', async () => { + render(CreateService); + + await vi.waitFor(() => { + const alert = screen.getByRole('alert'); + expect(alert).toBeDefined(); + }); +}); + +test('expect error message to be hidden when models locally', () => { + mocks.modelsInfoSubscribeMock.mockReturnValue([{ id: 'random', file: true }]); + render(CreateService); + + const alert = screen.queryByRole('alert'); + expect(alert).toBeNull(); +}); + +test('button click should call createInferenceServer', async () => { + mocks.modelsInfoSubscribeMock.mockReturnValue([{ id: 'random', file: true }]); + render(CreateService); + + let createBtn: HTMLElement | undefined = undefined; + await vi.waitFor(() => { + createBtn = screen.getByTitle('Create service'); + expect(createBtn).toBeDefined(); + }); + + if (createBtn === undefined) throw new Error('createBtn undefined'); + + await fireEvent.click(createBtn); + expect(vi.mocked(studioClient.createInferenceServer)).toHaveBeenCalledWith({ + modelsInfo: [{ id: 'random', file: true }], + port: 8888, + }); +}); diff --git a/packages/frontend/src/pages/CreateService.svelte b/packages/frontend/src/pages/CreateService.svelte new file mode 100644 index 000000000..e60fca182 --- /dev/null +++ b/packages/frontend/src/pages/CreateService.svelte @@ -0,0 +1,116 @@ + + + + +
+
+ + + + {#if localModels.length === 0} +
+ + +
+ {/if} + + + +
+
+
+ +
+
+
+
+
diff --git a/packages/frontend/src/pages/Models.svelte b/packages/frontend/src/pages/Models.svelte index 2e54b2f3b..f4f774f77 100644 --- a/packages/frontend/src/pages/Models.svelte +++ b/packages/frontend/src/pages/Models.svelte @@ -26,7 +26,7 @@ const columns: Column[] = [ new Column('HW Compat', { width: '1fr', renderer: ModelColumnHw }), new Column('Registry', { width: '2fr', renderer: ModelColumnRegistry }), new Column('License', { width: '2fr', renderer: ModelColumnLicense }), - new Column('Actions', { align: 'right', width: '80px', renderer: ModelColumnActions }), + new Column('Actions', { align: 'right', width: '120px', renderer: ModelColumnActions }), ]; const row = new Row({});