Skip to content

Commit

Permalink
feat(win): better gpu detection (containers#1141)
Browse files Browse the repository at this point in the history
* feat(win): fast gpu detection

Signed-off-by: axel7083 <[email protected]>

* fix(WinGPUDetector): registry value can be binary

Signed-off-by: axel7083 <[email protected]>

* fix: hex encoding

Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 authored Jun 4, 2024
1 parent 5a14c18 commit 9672180
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 267 deletions.
3 changes: 2 additions & 1 deletion packages/backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,21 @@
"watch": "vite --mode development build -w"
},
"dependencies": {
"fast-xml-parser": "^4.4.0",
"isomorphic-git": "^1.25.10",
"mustache": "^4.2.0",
"openai": "^4.47.3",
"postman-code-generators": "^1.10.1",
"postman-collection": "^4.4.0",
"semver": "^7.6.2",
"winreg": "^1.2.5",
"xml-js": "^1.6.11"
},
"devDependencies": {
"@podman-desktop/api": "0.0.202404101645-5d46ba5",
"@types/js-yaml": "^4.0.9",
"@types/node": "^20",
"@types/postman-collection": "^3.5.10",
"@types/winreg": "^1.2.36",
"vitest": "^1.6.0",
"@types/mustache": "^4.2.5"
}
Expand Down
79 changes: 3 additions & 76 deletions packages/backend/src/managers/GPUManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
import { expect, test, vi, beforeEach } from 'vitest';
import { containerEngine, env } from '@podman-desktop/api';
import type { ContainerInspectInfo, ContainerProviderConnection, ImageInfo, Webview } from '@podman-desktop/api';
import { env } from '@podman-desktop/api';
import type { Webview } from '@podman-desktop/api';
import { GPUManager } from './GPUManager';
import { getImageInfo, getProviderContainerConnection } from '../utils/inferenceUtils';
import { XMLParser } from 'fast-xml-parser';

vi.mock('../utils/inferenceUtils', () => ({
getProviderContainerConnection: vi.fn(),
Expand All @@ -29,12 +27,6 @@ vi.mock('../utils/inferenceUtils', () => ({

vi.mock('@podman-desktop/api', async () => {
return {
containerEngine: {
createContainer: vi.fn(),
logsContainer: vi.fn(),
deleteContainer: vi.fn(),
inspectContainer: vi.fn(),
},
env: {
isWindows: false,
},
Expand All @@ -52,47 +44,6 @@ const webviewMock = {
beforeEach(() => {
vi.resetAllMocks();
vi.mocked(webviewMock.postMessage).mockResolvedValue(true);

vi.mocked(getProviderContainerConnection).mockReturnValue({
providerId: 'dummyProviderId',
connection: {} as unknown as ContainerProviderConnection,
});
vi.mocked(getImageInfo).mockResolvedValue({
engineId: 'dummyEngineId',
Id: 'dummyImageId',
} as unknown as ImageInfo);

vi.mocked(containerEngine.createContainer).mockResolvedValue({
id: 'dummyContainerId',
});

vi.mocked(containerEngine.logsContainer).mockImplementation(async (_engineId, _containerId, callback) => {
callback('', '</nvidia_smi_log>');
});

vi.mocked(XMLParser).mockReturnValue({
parse: vi.fn().mockReturnValue({
nvidia_smi_log: {
attached_gpus: 1,
cuda_version: 2,
driver_version: 3,
timestamp: 4,
gpu: {
uuid: 'dummyUUID',
product_name: 'dummyProductName',
},
},
}),
} as unknown as XMLParser);

vi.mocked(containerEngine.inspectContainer).mockImplementation(async (_engineId, _id) => {
return {
State: {
Running: false,
ExitCode: 0,
},
} as unknown as ContainerInspectInfo;
});
});

test('post constructor should have no items', () => {
Expand All @@ -106,29 +57,5 @@ test('non-windows host should throw error', async () => {
const manager = new GPUManager(webviewMock);
await expect(() => {
return manager.collectGPUs();
}).rejects.toThrowError('Cannot collect GPUs information on this machine.');
});

test('windows host should start then delete container with proper configuration', async () => {
vi.mocked(env).isWindows = true;

const manager = new GPUManager(webviewMock);
const gpus = await manager.collectGPUs({
providerId: 'dummyProviderId',
});

expect(gpus.length).toBe(1);
expect(gpus[0].uuid).toBe('dummyUUID');
expect(gpus[0].product_name).toBe('dummyProductName');

expect(getProviderContainerConnection).toHaveBeenCalledWith('dummyProviderId');

expect(containerEngine.createContainer).toHaveBeenCalledWith('dummyEngineId', {
Image: 'dummyImageId',
Cmd: expect.anything(),
Detach: false,
Entrypoint: '/usr/bin/sh',
HostConfig: expect.anything(),
});
expect(containerEngine.deleteContainer).toHaveBeenCalledWith('dummyEngineId', 'dummyContainerId');
}).rejects.toThrowError();
});
155 changes: 18 additions & 137 deletions packages/backend/src/managers/GPUManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,159 +15,40 @@
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
import {
containerEngine,
type Disposable,
type Webview,
type ImageInfo,
type PullEvent,
type ContainerCreateOptions,
env,
} from '@podman-desktop/api';
import { getImageInfo, getProviderContainerConnection } from '../utils/inferenceUtils';
import { XMLParser } from 'fast-xml-parser';
import { type Disposable, type Webview } from '@podman-desktop/api';
import type { IGPUInfo } from '@shared/src/models/IGPUInfo';
import { Publisher } from '../utils/Publisher';
import { Messages } from '@shared/Messages';

export const CUDA_UBI8_IMAGE = 'nvcr.io/nvidia/cuda:12.3.2-devel-ubi8';
import type { IWorker } from '../workers/IWorker';
import { WinGPUDetector } from '../workers/gpu/WinGPUDetector';
import { platform } from 'node:os';

/**
* @experimental
*/
export class GPUManager extends Publisher<IGPUInfo[]> implements Disposable {
// Map uuid -> info
#gpus: Map<string, IGPUInfo>;
#gpus: IGPUInfo[];

#workers: IWorker<void, IGPUInfo[]>[];

constructor(webview: Webview) {
super(webview, Messages.MSG_GPUS_UPDATE, () => this.getAll());
this.#gpus = new Map();
}
dispose(): void {
this.#gpus.clear();
}

getAll(): IGPUInfo[] {
return Array.from(this.#gpus.values());
// init properties
this.#gpus = [];
this.#workers = [new WinGPUDetector()];
}

async collectGPUs(options?: { providerId: string }): Promise<IGPUInfo[]> {
if (!env.isWindows) {
throw new Error('Cannot collect GPUs information on this machine.');
}

const provider = getProviderContainerConnection(options?.providerId);
const imageInfo: ImageInfo = await getImageInfo(provider.connection, CUDA_UBI8_IMAGE, (_event: PullEvent) => {});

const result = await containerEngine.createContainer(
imageInfo.engineId,
this.getWindowsContainerCreateOptions(imageInfo),
);

const exitCode = await this.waitForExit(imageInfo.engineId, result.id);
if (exitCode !== 0) throw new Error(`nvidia CUDA Container exited with code ${exitCode}.`);

try {
const logs = await this.getLogs(imageInfo.engineId, result.id);
const parsed: {
nvidia_smi_log: {
attached_gpus: number;
cuda_version: number;
driver_version: number;
timestamp: string;
gpu: IGPUInfo;
};
} = new XMLParser().parse(logs);
dispose(): void {}

if (parsed.nvidia_smi_log.attached_gpus > 1) throw new Error('machine with more than one GPU are not supported.');

this.#gpus.set(parsed.nvidia_smi_log.gpu.uuid, parsed.nvidia_smi_log.gpu);
this.notify();
return this.getAll();
} finally {
await containerEngine.deleteContainer(imageInfo.engineId, result.id);
}
}

private getWindowsContainerCreateOptions(imageInfo: ImageInfo): ContainerCreateOptions {
return {
Image: imageInfo.Id,
Detach: false,
HostConfig: {
AutoRemove: false,
Mounts: [
{
Target: '/usr/lib/wsl',
Source: '/usr/lib/wsl',
Type: 'bind',
},
],
DeviceRequests: [
{
Capabilities: [['gpu']],
Count: -1, // -1: all
},
],
Devices: [
{
PathOnHost: '/dev/dxg',
PathInContainer: '/dev/dxg',
CgroupPermissions: 'r',
},
],
},
Entrypoint: '/usr/bin/sh',
Cmd: [
'-c',
'/usr/bin/ln -s /usr/lib/wsl/lib/* /usr/lib64/ && PATH="${PATH}:/usr/lib/wsl/lib/" && nvidia-smi -x -q',
],
};
}

private waitForExit(engineId: string, containerId: string): Promise<number> {
return new Promise<number>((resolve, reject) => {
let retry = 0;
const interval = setInterval(() => {
if (retry === 3) {
reject(new Error('timeout: container never exited.'));
return;
}

retry++;

containerEngine
.inspectContainer(engineId, containerId)
.then(inspectInfo => {
if (inspectInfo.State.Running) return;

clearInterval(interval);
resolve(inspectInfo.State.ExitCode);
})
.catch((err: unknown) => {
console.error('Something went wrong while trying to inspect container', err);
clearInterval(interval);
reject(new Error(`Failed to inspect container ${containerId}.`));
});
}, 2000);
});
getAll(): IGPUInfo[] {
return this.#gpus;
}

private getLogs(engineId: string, containerId: string): Promise<string> {
return new Promise<string>((resolve, reject) => {
const interval = setTimeout(() => {
reject(new Error('timeout'));
}, 10000);
async collectGPUs(): Promise<IGPUInfo[]> {
const worker = this.#workers.find(worker => worker.enabled());
if (worker === undefined) throw new Error(`no worker enable to collect GPU on platform ${platform}`);

let logs = '';
containerEngine
.logsContainer(engineId, containerId, (name, data) => {
logs += data;
if (data.includes('</nvidia_smi_log>')) {
clearTimeout(interval);
resolve(logs);
}
})
.catch(reject);
});
this.#gpus = await worker.perform();
return this.getAll();
}
}
Loading

0 comments on commit 9672180

Please sign in to comment.