Skip to content

Commit

Permalink
add dalle3 command
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Apr 18, 2024
1 parent 956a7ae commit 34fa4a2
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
26 changes: 26 additions & 0 deletions discord_tron_master/classes/openai/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,29 @@ async def turbo_completion(self, role, prompt, **kwargs):
return choice.text

return response.choices[0].message.content

async def retrieve_image(self, url: str):
import requests
response = requests.get(url)
return response.content

async def dalle_image_generate(self, prompt, user_config: dict):
resolution = f"{user_config.get('width', 1024)}x{user_config.get('height', 1024)}"
try:
response = openai.images.generate(
model="dall-e-3",
prompt=f"I NEED to test how the tool works with extremely simple prompts. DO NOT add any detail, just use it AS-IS: {prompt}",
size=resolution,
quality="standard",
n=1,
)
url = response.data[0].url
# retrieve URL, return Image
image_obj = await self.retrieve_image(url)
if not hasattr(image_obj, "size"):
logging.error(f"Image object does not have a size attribute. Returning None.")
logging.debug(f"Response from OpenAI: {response}")
return None
except Exception as e:
logging.error(f"Error generating image: {e}")
return None
9 changes: 9 additions & 0 deletions discord_tron_master/cogs/image/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ async def generate_range(self, ctx, count, *, prompt):
for i in range(0, int(count)):
await self.generate(ctx, prompt=prompt)

@commands.command(name="dalle", help="Generates an image based on the given prompt using DALL-E.")
async def generate_dalle(self, ctx, *, prompt):
if guild_config.is_channel_banned(ctx.guild.id, ctx.channel.id):
return
from discord_tron_master.classes.openai.text import GPT
gpt = GPT()
image_output = await gpt.dalle_image_generate(prompt=prompt, user_config=self.config.get_user_config(user_id=ctx.author.id))
await ctx.channel.send(file=discord_lib.File(BytesIO(image_output), "image.png"))

@commands.command(name="sd3", help="Generates an image based on the given prompt using Stable Diffusion 3.")
async def generate_sd3(self, ctx, *, prompt):
if guild_config.is_channel_banned(ctx.guild.id, ctx.channel.id):
Expand Down

0 comments on commit 34fa4a2

Please sign in to comment.