Skip to content

Commit

Permalink
ref: improve flexibility for size and model configurations to facilit…
Browse files Browse the repository at this point in the history
…ate future changes
  • Loading branch information
yuukiok committed Oct 1, 2024
1 parent d8686bb commit 3b60a13
Showing 1 changed file with 72 additions and 48 deletions.
120 changes: 72 additions & 48 deletions packages/web/src/pages/GenerateImagePage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,70 @@ const modeOptions = Object.values(GENERATION_MODES).map((mode) => ({
value: mode,
label: mode,
}));
type ModelInfo = {
supportedModes: GenerationMode[];
resolutionPresets: { value: string; label: string }[];
};
const stabilityAi2024ModelPresets = [
{ value: '1:1', label: '1024 x 1024' },
{ value: '5:4', label: '1088 x 896' },
{ value: '3:2', label: '1216 x 832' },
{ value: '16:9', label: '1344 x 768' },
{ value: '21:9', label: '1536 x 640' },
];
const modelInfo: Record<string, ModelInfo> = {
[STABILITY_AI_MODELS.STABLE_DIFFUSION_XL]: {
supportedModes: [
GENERATION_MODES.TEXT_IMAGE,
GENERATION_MODES.IMAGE_VARIATION,
GENERATION_MODES.INPAINTING,
GENERATION_MODES.OUTPAINTING,
],
resolutionPresets: [
{ value: '512 x 512', label: '512 x 512' },
{ value: '1024 x 1024', label: '1024 x 1024' },
{ value: '1280 x 768', label: '1280 x 768' },
{ value: '768 x 1280', label: '768 x 1280' },
],
},
[STABILITY_AI_MODELS.SD3_LARGE]: {
supportedModes: [
GENERATION_MODES.TEXT_IMAGE,
GENERATION_MODES.IMAGE_VARIATION,
],
resolutionPresets: stabilityAi2024ModelPresets,
},
[STABILITY_AI_MODELS.STABLE_IMAGE_CORE]: {
supportedModes: [GENERATION_MODES.TEXT_IMAGE],
resolutionPresets: stabilityAi2024ModelPresets,
},
[STABILITY_AI_MODELS.STABLE_IMAGE_ULTRA]: {
supportedModes: [GENERATION_MODES.TEXT_IMAGE],
resolutionPresets: stabilityAi2024ModelPresets,
},
};

const getModeOptions = (imageGenModelId: string) => {
if (imageGenModelId === STABILITY_AI_MODELS.STABLE_DIFFUSION_XL) {
return modeOptions;
} else if (imageGenModelId === STABILITY_AI_MODELS.SD3_LARGE) {
return modeOptions.filter(
(mode) =>
mode.value === GENERATION_MODES.TEXT_IMAGE ||
mode.value === GENERATION_MODES.IMAGE_VARIATION
);
} else if (
imageGenModelId === STABILITY_AI_MODELS.STABLE_IMAGE_CORE ||
imageGenModelId === STABILITY_AI_MODELS.STABLE_IMAGE_ULTRA
) {
return modeOptions.filter(
(mode) => mode.value === GENERATION_MODES.TEXT_IMAGE
);
if (imageGenModelId in modelInfo) {
return modelInfo[imageGenModelId].supportedModes.map((mode) => ({
value: mode,
label: mode,
}));
} else {
return [
{
value: GENERATION_MODES.TEXT_IMAGE,
label: GENERATION_MODES.TEXT_IMAGE,
},
];
}
};
const getResolutionPresets = (imageGenModelId: string) => {
if (imageGenModelId in modelInfo) {
return modelInfo[imageGenModelId].resolutionPresets;
} else {
return stabilityAi2024ModelPresets;
}

return modeOptions;
};

type StateType = {
Expand Down Expand Up @@ -109,28 +153,6 @@ type StateType = {
};

const useGenerateImagePageState = create<StateType>((set, get) => {
const getResolutionPresets = (imageGenModelId: string) => {
if (
[
'stability.sd3-large-v1:0',
'stability.stable-image-core-v1:0',
'stability.stable-image-ultra-v1:0',
].includes(imageGenModelId)
) {
return [
{ value: '1:1', label: '1024 x 1024' },
{ value: '5:4', label: '1088 x 896' },
{ value: '3:2', label: '1216 x 832' },
{ value: '16:9', label: '1344 x 768' },
{ value: '21:9', label: '1536 x 640' },
];
} else {
return ['512 x 512', '1024 x 1024', '1280 x 768', '768 x 1280'].map(
(s) => ({ value: s, label: s })
);
}
};

const INIT_STATE = {
imageGenModelId: '',
prompt: '',
Expand Down Expand Up @@ -400,9 +422,9 @@ const GenerateImagePage: React.FC = () => {
(option) => option.value
);
if (!availableModes.includes(generationMode)) {
setGenerationMode(modeOptions[0].value as GenerationMode);
setGenerationMode(availableModes[0]);
}
}, [imageGenModelId, generationMode, modeOptions, setGenerationMode]);
}, [imageGenModelId, generationMode, setGenerationMode]);

useEffect(() => {
updateSystemContextByModel();
Expand Down Expand Up @@ -451,6 +473,13 @@ const GenerateImagePage: React.FC = () => {
clearImage();
setGenerating(true);

const modelConfig = modelInfo[imageGenModelId];
if (!modelConfig) {
console.error(`Unknown model: ${imageGenModelId}`);
setGenerating(false);
return;
}

const promises = new Array(imageSample).fill('').map((_, idx) => {
let _seed = seed[idx];
if (_seed < 0) {
Expand Down Expand Up @@ -497,13 +526,8 @@ const GenerateImagePage: React.FC = () => {
};
}

if (
[
'stability.sd3-large-v1:0',
'stability.stable-image-core-v1:0',
'stability.stable-image-ultra-v1:0',
].includes(imageGenModelId)
) {
// 解像度の設定
if (modelConfig.resolutionPresets[0].value.includes(':')) {
params = {
...params,
aspectRatio: resolution.value,
Expand Down

0 comments on commit 3b60a13

Please sign in to comment.