Skip to content

Commit

Permalink
Merge branch 'main' of github.com:homanp/superagent
Browse files Browse the repository at this point in the history
  • Loading branch information
homanp committed Mar 7, 2024
2 parents b48080a + b2ded8b commit b25b248
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from app.api.tools import (
delete as api_delete_tool,
)
from app.api.tools import (
update as api_update_tool,
)
from app.api.workflow_configs.api.base import (
BaseApiAgentManager,
BaseApiDatasourceManager,
Expand All @@ -23,6 +26,9 @@
from app.models.request import (
Tool as ToolRequest,
)
from app.models.request import (
ToolUpdate as ToolUpdateRequest,
)
from prisma.enums import ToolType
from services.superrag import SuperRagService

Expand Down Expand Up @@ -61,7 +67,6 @@ async def _create_tool(self, assistant: dict, data: dict):

async def _add_tool(self, assistant: dict, data: dict):
new_tool = await self._create_tool(assistant, data)

assistant = await self.agent_manager.get_assistant(assistant)

try:
Expand All @@ -76,6 +81,21 @@ async def _add_tool(self, assistant: dict, data: dict):
except Exception as err:
logger.error(f"Error adding tool: {new_tool} - {assistant} - Error: {err}")

async def _update_tool(self, assistant: dict, data: dict):
tool = await self.agent_manager.get_tool(assistant, data)

try:
await api_update_tool(
tool_id=tool.id,
body=ToolUpdateRequest.parse_obj(data),
api_user=self.api_user,
)
logger.info(f"Updated tool: {tool.name} - {assistant.get('name')}")
except Exception as err:
logger.error(
f"Error updating tool: {tool} - {data} - {assistant} - Error: {err}"
)

async def _add_superrag_tool(self, assistant: dict, data: dict):
new_tool = {
**data,
Expand Down Expand Up @@ -111,6 +131,13 @@ async def _get_unique_index_name(self, datasource: dict, assistant: dict):

return unique_name

async def update_datasource(self, assistant: dict, data: dict):
"""
This method only updates the superrag tool, not the datasource in SuperRag.
To achive that, first delete the datasource and then add it again.
"""
await self._update_tool(assistant, data)

async def add_datasource(self, assistant: dict, data: dict):
data["index_name"] = await self._get_unique_index_name(data, assistant)

Expand Down
11 changes: 7 additions & 4 deletions libs/superagent/app/api/workflow_configs/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,12 @@ async def transform_superrags(self):

await self._set_superrag_files(datasource)
await self._set_database_provider(datasource)
encoder = datasource.get("encoder") or DEFAULT_ENCODER_OPTIONS
rename_and_remove_keys(encoder, {"type": "provider"})
rename_and_remove_keys(encoder, {"name": "model_name"})

datasource["document_processor"] = {
"encoder": datasource.get("encoder") or DEFAULT_ENCODER_OPTIONS,
"encoder": encoder,
"unstructured": {
"hi_res_model_name": "detectron2_onnx",
"partition_strategy": "auto",
Expand All @@ -148,7 +151,7 @@ async def transform_superrags(self):
"splitter": {
"max_tokens": 400,
"min_tokens": 30,
"name": "semantic",
"name": "by_title",
"prefix_summary": True,
"prefix_title": True,
"rolling_window_size": 1,
Expand Down Expand Up @@ -177,8 +180,8 @@ async def _set_database_provider(self, datasource: dict):
}
else:
raise MissingVectorDatabaseProvider(
f"Vector database provider not found ({database_provider})."
f"Please configure it by going to the integrations page"
"Vector database provider not found."
"Please configure it by going to the integrations page"
)
remove_key_if_present(datasource, "database_provider")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def process_assistant(

new_superrag_processor = Processor(
self.api_user, self.api_manager
).get_superrag_processor(old_assistant)
).get_superrag_processor(new_assistant)

if old_assistant_type and new_assistant_type:
if old_assistant_type != new_assistant_type:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,15 @@ async def process(self, old_data, new_data):

if old_datasource_name and new_datasource_name:
is_changed = compare_dicts(old_datasource, new_datasource)

if is_changed:
if (
is_changed.get("description") is not None
and len(is_changed.items()) == 1
):
await datasource_manager.update_datasource(
self.assistant,
new_datasource,
)
else:
await datasource_manager.delete_datasource(
self.assistant,
old_datasource,
Expand Down
8 changes: 6 additions & 2 deletions libs/superagent/app/api/workflow_configs/saml_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ class SuperragEncoderType(str, Enum):


class SuperragEncoder(BaseModel):
type: SuperragEncoderType
name: str
type: SuperragEncoderType = Field(
description="The provider of encoder to use for the index. e.g. `openai`"
)
name: str = Field(
description="The model name to use for the encoder. e.g. `text-embedding-3-small` for OpenAI's model"
)
dimensions: int


Expand Down

0 comments on commit b25b248

Please sign in to comment.