);
};
-
diff --git a/gui/pages/Content/Marketplace/Market.module.css b/gui/pages/Content/Marketplace/Market.module.css
index e92e0bdb5..a7057d36a 100644
--- a/gui/pages/Content/Marketplace/Market.module.css
+++ b/gui/pages/Content/Marketplace/Market.module.css
@@ -378,6 +378,7 @@
width: 33% !important;
background-color: rgb(39, 35, 53);
border-radius: 8px;
+ flex-direction: column;
}
.tool_description {
@@ -503,12 +504,20 @@
.markdown_style{
color:white;
}
+
.markdown_style img {
max-width: 100%;
height: auto;
}
+
.markdown_container{
height:68vh;
overflow-y:scroll;
overflow-x:hidden;
+}
+
+.tools_icon{
+ border-radius: 25px;
+ background: black;
+ position: relative;
}
\ No newline at end of file
diff --git a/gui/pages/Content/Marketplace/MarketKnowledge.js b/gui/pages/Content/Marketplace/MarketKnowledge.js
new file mode 100644
index 000000000..31ba5cc51
--- /dev/null
+++ b/gui/pages/Content/Marketplace/MarketKnowledge.js
@@ -0,0 +1,91 @@
+import React, {useEffect, useState} from "react";
+import Image from "next/image";
+import styles from './Market.module.css';
+import styles1 from '../Knowledge/Knowledge.module.css';
+import {EventBus} from "@/utils/eventBus";
+import {loadingTextEffect} from "@/utils/utils";
+import axios from 'axios';
+import {fetchKnowledgeTemplateList} from "@/pages/api/DashboardService";
+
+export default function MarketKnowledge() {
+ const [knowledgeTemplates, setKnowledgeTemplates] = useState([])
+ const [showMarketplace, setShowMarketplace] = useState(false);
+ const [isLoading, setIsLoading] = useState(true)
+ const [loadingText, setLoadingText] = useState("Loading Knowledge Templates");
+
+ useEffect(() => {
+ loadingTextEffect('Loading Knowledge Templates', setLoadingText, 500);
+
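+ // On the public marketplace the hosted API is queried directly; inside the app the authenticated dashboard service is used.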
+ if (window.location.href.toLowerCase().includes('marketplace')) {
+ setShowMarketplace(true);
+ axios.get(`https://app.superagi.com/api/knowledge/get/list?page=0`)
+ .then((response) => {
+ const data = response.data || [];
+ setKnowledgeTemplates(data);
+ setIsLoading(false);
+ })
+ .catch((error) => {
+ console.error('Error fetching knowledge templates:', error);
+ });
+ } else {
+ fetchKnowledgeTemplateList()
+ .then((response) => {
+ const data = response.data || [];
+ setKnowledgeTemplates(data);
+ setIsLoading(false);
+ })
+ .catch((error) => {
+ console.error('Error fetching knowledge templates:', error);
+ });
+ }
+ }, []);
+
+ function handleTemplateClick(item) {
+ const contentType = 'knowledge_template';
+ EventBus.emit('openTemplateDetails', {item, contentType});
+ }
+
+ return (
+
+ )
+}
\ No newline at end of file
diff --git a/gui/pages/Content/Marketplace/MarketTools.js b/gui/pages/Content/Marketplace/MarketTools.js
index 1f76c4a32..87af9f919 100644
--- a/gui/pages/Content/Marketplace/MarketTools.js
+++ b/gui/pages/Content/Marketplace/MarketTools.js
@@ -3,7 +3,7 @@ import Image from "next/image";
import styles from './Market.module.css';
import {fetchToolTemplateList} from "@/pages/api/DashboardService";
import {EventBus} from "@/utils/eventBus";
-import {loadingTextEffect, excludedToolkits} from "@/utils/utils";
+import {loadingTextEffect, excludedToolkits, returnToolkitIcon} from "@/utils/utils";
import axios from 'axios';
export default function MarketTools() {
@@ -51,25 +51,17 @@ export default function MarketTools() {
))}
@@ -337,18 +414,30 @@ export default function Content({env, selectedView, selectedProjectId, organisat
))}
diff --git a/gui/pages/Dashboard/Settings/AddDatabase.js b/gui/pages/Dashboard/Settings/AddDatabase.js
new file mode 100644
index 000000000..0383f972a
--- /dev/null
+++ b/gui/pages/Dashboard/Settings/AddDatabase.js
@@ -0,0 +1,338 @@
+import React, {useState, useEffect} from 'react';
+import {ToastContainer, toast} from 'react-toastify';
+import 'react-toastify/dist/ReactToastify.css';
+import agentStyles from "@/pages/Content/Agents/Agents.module.css";
+import {
+ createInternalId,
+ removeTab,
+ returnDatabaseIcon,
+ setLocalStorageArray,
+ setLocalStorageValue
+} from "@/utils/utils";
+import knowledgeStyles from "@/pages/Content/Knowledge/Knowledge.module.css";
+import styles from "@/pages/Content/Marketplace/Market.module.css";
+import Image from "next/image";
+import styles1 from "@/pages/Content/Agents/Agents.module.css";
+import {connectPinecone, connectQdrant, fetchVectorDBList} from "@/pages/api/DashboardService";
+
+export default function AddDatabase({internalId, sendDatabaseDetailsData}) {
+ const [activeView, setActiveView] = useState('select_database');
+ const [vectorDatabases, setVectorDatabases] = useState(null);
+ const [selectedDB, setSelectedDB] = useState('');
+ const [databaseName, setDatabaseName] = useState('database name');
+ const [collections, setCollections] = useState(['']);
+
+ const [pineconeApiKey, setPineconeApiKey] = useState('');
+ const [pineconeEnvironment, setPineconeEnvironment] = useState('');
+
+ const [qdrantApiKey, setQdrantApiKey] = useState('');
+ const [qdrantURL, setQdrantURL] = useState('');
+ const [qdrantPort, setQdrantPort] = useState(8001);
+ const [connectText, setConnectText] = useState('Connect');
+
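+ // Restore any previously entered form values for this tab; fields are persisted per internalId in localStorage.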
+ useEffect(() => {
+ const active_view = localStorage.getItem('add_database_tab_' + String(internalId));
+ if (active_view) {
+ setActiveView(active_view);
+ }
+
+ const db_name = localStorage.getItem('db_name_' + String(internalId));
+ if (db_name) {
+ setDatabaseName(db_name);
+ }
+
+ const db_collections = localStorage.getItem('db_collections_' + String(internalId));
+ if (db_collections) {
+ setCollections(JSON.parse(db_collections));
+ }
+
+ const pinecone_api = localStorage.getItem('pinecone_api_' + String(internalId));
+ if (pinecone_api) {
+ setPineconeApiKey(pinecone_api);
+ }
+
+ const pinecone_env = localStorage.getItem('pinecone_env_' + String(internalId));
+ if (pinecone_env) {
+ setPineconeEnvironment(pinecone_env);
+ }
+
+ const qdrant_api = localStorage.getItem('qdrant_api_' + String(internalId));
+ if (qdrant_api) {
+ setQdrantApiKey(qdrant_api);
+ }
+
+ const qdrant_url = localStorage.getItem('qdrant_url_' + String(internalId));
+ if (qdrant_url) {
+ setQdrantURL(qdrant_url);
+ }
+
+ const qdrant_port = localStorage.getItem('qdrant_port_' + String(internalId));
+ if (qdrant_port) {
+ setQdrantPort(Number(qdrant_port));
+ }
+ }, [internalId]);
+
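+ // Load the supported vector databases and preselect the previously chosen one, falling back to the first entry.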
+ useEffect(() => {
+ fetchVectorDBList()
+ .then((response) => {
+ const data = response.data || [];
+ setVectorDatabases(data);
+ const selected_db = localStorage.getItem('selected_db_' + String(internalId));
+ setSelectedDB(selected_db ? selected_db : (data[0]?.name || ''));
+ })
+ .catch((error) => {
+ console.error('Error fetching vector databases:', error);
+ });
+ }, [internalId]);
+
+ const handleNameChange = (event) => {
+ setLocalStorageValue('db_name_' + String(internalId), event.target.value, setDatabaseName);
+ }
+
+ const handlePineconeAPIKeyChange = (event) => {
+ setLocalStorageValue('pinecone_api_' + String(internalId), event.target.value, setPineconeApiKey);
+ }
+
+ const handlePineconeEnvironmentChange = (event) => {
+ setLocalStorageValue('pinecone_env_' + String(internalId), event.target.value, setPineconeEnvironment);
+ }
+
+ const handleQdrantAPIKeyChange = (event) => {
+ setLocalStorageValue('qdrant_api_' + String(internalId), event.target.value, setQdrantApiKey);
+ }
+
+ const handleQdrantURLChange = (event) => {
+ setLocalStorageValue('qdrant_url_' + String(internalId), event.target.value, setQdrantURL);
+ }
+
+ const handleQdrantPortChange = (event) => {
+ setLocalStorageValue('qdrant_port_' + String(internalId), event.target.value, setQdrantPort);
+ }
+
+ const addCollection = () => {
+ setLocalStorageArray("db_collections_" + String(internalId), [...collections, 'collection name'], setCollections);
+ };
+
+ const handleCollectionChange = (index, newValue) => {
+ const updatedCollections = [...collections];
+ updatedCollections[index] = newValue;
+ setLocalStorageArray("db_collections_" + String(internalId), updatedCollections, setCollections);
+ };
+
+ const handleCollectionDelete = (index) => {
+ const updatedCollections = [...collections];
+ updatedCollections.splice(index, 1);
+ setLocalStorageArray("db_collections_" + String(internalId), updatedCollections, setCollections);
+ };
+
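+ // On a successful connect, pass the new database's details up so the parent can open its tab.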
+ const connectResponse = (data) => {
+ if (!data) {
+ return;
+ }
+ toast.success("Database connected successfully", {autoClose: 1800});
+ setConnectText("Connected");
+ sendDatabaseDetailsData({id: data.id, name: data.name, contentType: "Database", internalId: createInternalId()});
+ }
+
+ const connectDatabase = () => {
+ if (databaseName.replace(/\s/g, '') === '') {
+ toast.error("Database name can't be blank", {autoClose: 1800});
+ return;
+ }
+
+ if (selectedDB === 'Pinecone') {
+ if (pineconeApiKey.replace(/\s/g, '') === '') {
+ toast.error("Pinecone API key is empty", {autoClose: 1800});
+ return;
+ }
+
+ if (pineconeEnvironment.replace(/\s/g, '') === '') {
+ toast.error("Pinecone environment is empty", {autoClose: 1800});
+ return;
+ }
+
+ setConnectText("Connecting...");
+
+ const pineconeData = {
+ "name": databaseName,
+ "collections": collections,
+ "api_key": pineconeApiKey,
+ "environment": pineconeEnvironment,
+ }
+
+ connectPinecone(pineconeData)
+ .then((response) => {
+ connectResponse(response.data);
+ })
+ .catch((error) => {
+ toast.error("Unable to connect database", {autoClose: 1800});
+ console.error('Error connecting to Pinecone:', error);
+ setConnectText("Connect");
+ });
+ }
+
+ if (selectedDB === 'Qdrant') {
+ if (qdrantApiKey.replace(/\s/g, '') === '') {
+ toast.error("Qdrant API key is empty", {autoClose: 1800});
+ return;
+ }
+
+ if (qdrantURL.replace(/\s/g, '') === '') {
+ toast.error("Qdrant URL is empty", {autoClose: 1800});
+ return;
+ }
+
+ if (String(qdrantPort).replace(/\s/g, '') === '') {
+ toast.error("Qdrant port can't be blank", {autoClose: 1800});
+ return;
+ }
+
+ setConnectText("Connecting...");
+
+ const qdrantData = {
+ "name": databaseName,
+ "collections": collections,
+ "api_key": qdrantApiKey,
+ "url": qdrantURL,
+ "port": qdrantPort
+ }
+
+ connectQdrant(qdrantData)
+ .then((response) => {
+ connectResponse(response.data);
+ })
+ .catch((error) => {
+ toast.error("Unable to connect database", {autoClose: 1800});
+ console.error('Error connecting to Qdrant:', error);
+ setConnectText("Connect");
+ });
+ }
+ }
+
+ const proceedAddDatabase = () => {
+ if (!selectedDB) {
+ toast.error("Please select a database", {autoClose: 1800});
+ return;
+ }
+
+ setLocalStorageValue('add_database_tab_' + String(internalId), 'form_database', setActiveView)
+ }
+
+ return (<>
+
+ </>)
+}
\ No newline at end of file
diff --git a/gui/pages/Dashboard/Settings/Database.js b/gui/pages/Dashboard/Settings/Database.js
new file mode 100644
index 000000000..e3e4cae3f
--- /dev/null
+++ b/gui/pages/Dashboard/Settings/Database.js
@@ -0,0 +1,182 @@
+import React, {useState, useEffect} from 'react';
+import {ToastContainer, toast} from 'react-toastify';
+import 'react-toastify/dist/ReactToastify.css';
+import agentStyles from "@/pages/Content/Agents/Agents.module.css";
+import {createInternalId, formatTimeDifference, loadingTextEffect, preventDefault} from "@/utils/utils";
+import styles from "@/pages/Content/Marketplace/Market.module.css";
+import knowledgeStyles from "@/pages/Content/Knowledge/Knowledge.module.css";
+import Image from "next/image";
+import {deleteVectorDB, getVectorDatabases} from "@/pages/api/DashboardService";
+
+export default function Database({sendDatabaseData}) {
+ const [vectorDB, setVectorDB] = useState([]);
+ const [isLoading, setIsLoading] = useState(true)
+ const [loadingText, setLoadingText] = useState("Loading Databases");
+ const [dropdown, setDropdown] = useState([]);
+ const [deleteModal, setDeleteModal] = useState(false);
+ const [selectedDatabase, setSelectedDatabase] = useState(null);
+
+ function fetchDatabases() {
+ setIsLoading(true);
+
+ getVectorDatabases()
+ .then((response) => {
+ const data = response.data || [];
+ setVectorDB(data);
+ setIsLoading(false);
+ })
+ .catch((error) => {
+ console.error('Error fetching vector databases:', error);
+ });
+ }
+
+ useEffect(() => {
+ loadingTextEffect('Loading Databases', setLoadingText, 500);
+ fetchDatabases();
+ }, []);
+
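+ // Keep one dropdown-open flag per database row.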
+ useEffect(() => {
+ setDropdown(Array(vectorDB.length).fill(false));
+ }, [vectorDB]);
+
+ const setDropdownWithIndex = (index, state) => {
+ setDropdown((prevDropdown) => {
+ const newDropdown = [...prevDropdown];
+ newDropdown[index] = state;
+ return newDropdown;
+ });
+ }
+
+ const openDeleteModal = (e, index) => {
+ e.stopPropagation();
+ setDeleteModal(true);
+ setSelectedDatabase(vectorDB[index]);
+ setDropdownWithIndex(index, false);
+ }
+
+ const deleteDatabase = (databaseId) => {
+ setDeleteModal(false);
+
+ deleteVectorDB(databaseId)
+ .then((response) => {
+ toast.success("Database deleted successfully", {autoClose: 1800});
+ fetchDatabases();
+ })
+ .catch((error) => {
+ toast.error("Unable to delete database", {autoClose: 1800});
+ console.error('Error deleting database:', error);
+ });
+ }
+
+ return (<>
+
+ </>)
+}
\ No newline at end of file
diff --git a/gui/pages/Dashboard/Settings/DatabaseDetails.js b/gui/pages/Dashboard/Settings/DatabaseDetails.js
new file mode 100644
index 000000000..0a7996ad4
--- /dev/null
+++ b/gui/pages/Dashboard/Settings/DatabaseDetails.js
@@ -0,0 +1,207 @@
+import React, {useState, useEffect} from 'react';
+import {ToastContainer, toast} from 'react-toastify';
+import 'react-toastify/dist/ReactToastify.css';
+import agentStyles from "@/pages/Content/Agents/Agents.module.css";
+import {removeTab, returnDatabaseIcon, setLocalStorageArray, preventDefault} from "@/utils/utils";
+import knowledgeStyles from "@/pages/Content/Knowledge/Knowledge.module.css";
+import styles from "@/pages/Content/Marketplace/Market.module.css";
+import Image from "next/image";
+import {deleteVectorDB, getVectorDBDetails, updateVectorDB} from "@/pages/api/DashboardService";
+
+export default function DatabaseDetails({internalId, databaseId}) {
+ const [dropdown, setDropdown] = useState(false);
+ const [deleteModal, setDeleteModal] = useState(false);
+ const [collections, setCollections] = useState([]);
+ const [initialCollections, setInitialCollections] = useState([]);
+ const [hasChanges, setHasChanges] = useState(false);
+ const [databaseDetails, setDatabaseDetails] = useState([]);
+
+ useEffect(() => {
+ if (databaseId) {
+ getVectorDBDetails(databaseId)
+ .then((response) => {
+ const data = response.data || [];
+ setDatabaseDetails(data);
+ if (data) {
+ const localIndices = localStorage.getItem("db_details_collections_" + String(internalId));
+ const indices = data.indices || [];
+ setCollections(localIndices ? JSON.parse(localIndices) : indices);
+ setInitialCollections(indices);
+ }
+ })
+ .catch((error) => {
+ console.error('Error fetching database details:', error);
+ });
+ }
+ }, [internalId, databaseId]);
+
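+ // Flag unsaved edits whenever the collection list diverges from the last server state.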
+ useEffect(() => {
+ if (JSON.stringify(collections) !== JSON.stringify(initialCollections)) {
+ setHasChanges(true);
+ } else {
+ setHasChanges(false);
+ }
+ }, [collections]);
+
+ const addCollection = () => {
+ setLocalStorageArray("db_details_collections_" + String(internalId), [...collections, 'collection name'], setCollections);
+ };
+
+ const handleCollectionChange = (index, newValue) => {
+ const updatedCollections = [...collections];
+ updatedCollections[index] = newValue;
+ setLocalStorageArray("db_details_collections_" + String(internalId), updatedCollections, setCollections);
+ };
+
+ const handleCollectionDelete = (index) => {
+ const updatedCollections = [...collections];
+ updatedCollections.splice(index, 1);
+ setLocalStorageArray("db_details_collections_" + String(internalId), updatedCollections, setCollections);
+ };
+
+ const deleteDatabase = () => {
+ setDeleteModal(false);
+
+ deleteVectorDB(databaseId)
+ .then((response) => {
+ toast.success("Database deleted successfully", {autoClose: 1800});
+ removeTab(databaseId, databaseDetails?.name, "Database", internalId);
+ })
+ .catch((error) => {
+ toast.error("Unable to delete database", {autoClose: 1800});
+ console.error('Error deleting database:', error);
+ });
+ }
+
+ const revertChanges = () => {
+ setCollections(initialCollections);
+ setHasChanges(false);
+ };
+
+ const updateChanges = () => {
+ updateVectorDB(databaseId, collections)
+ .then((response) => {
+ if (response.data.success) {
+ toast.success("Database updated successfully", {autoClose: 1800});
+ setInitialCollections(collections);
+ setHasChanges(false);
+ } else {
+ toast.error("Unable to update database", {autoClose: 1800});
+ }
+ })
+ .catch((error) => {
+ toast.error("Unable to update database", {autoClose: 1800});
+ console.error('Error updating database:', error);
+ });
+ };
+
+ return (<>
+
+ </>)
+}
\ No newline at end of file
diff --git a/gui/pages/Dashboard/Settings/Model.js b/gui/pages/Dashboard/Settings/Model.js
new file mode 100644
index 000000000..c348ca00e
--- /dev/null
+++ b/gui/pages/Dashboard/Settings/Model.js
@@ -0,0 +1,157 @@
+import React, {useState, useEffect, useRef} from 'react';
+import {ToastContainer, toast} from 'react-toastify';
+import 'react-toastify/dist/ReactToastify.css';
+import agentStyles from "@/pages/Content/Agents/Agents.module.css";
+import {getOrganisationConfig, updateOrganisationConfig, validateLLMApiKey} from "@/pages/api/DashboardService";
+import {EventBus} from "@/utils/eventBus";
+import {removeTab} from "@/utils/utils";
+import Image from "next/image";
+
+export default function Model({organisationId}) {
+ const [modelApiKey, setKey] = useState('');
+ const [temperature, setTemperature] = useState(0.5);
+ const [sourceDropdown, setSourceDropdown] = useState(false);
+ const sources = ['OpenAi', 'Google Palm'];
+ const [source, setSource] = useState(sources[0]);
+ const sourceRef = useRef(null);
+
+ function getKey(key) {
+ getOrganisationConfig(organisationId, key)
+ .then((response) => {
+ setKey(response.data.value);
+ })
+ .catch((error) => {
+ console.error('Error fetching project:', error);
+ });
+ }
+
+ function getSource(key) {
+ getOrganisationConfig(organisationId, key)
+ .then((response) => {
+ setSource(response.data.value);
+ })
+ .catch((error) => {
+ console.error('Error fetching project:', error);
+ });
+ }
+
+ useEffect(() => {
+ getKey("model_api_key");
+ getSource("model_source");
+ }, [organisationId]);
+
+ useEffect(() => {
+ function handleClickOutside(event) {
+ if (sourceRef.current && !sourceRef.current.contains(event.target)) {
+ setSourceDropdown(false)
+ }
+ }
+
+ document.addEventListener('mousedown', handleClickOutside);
+ return () => {
+ document.removeEventListener('mousedown', handleClickOutside);
+ };
+ }, []);
+
+ function updateKey(key, value) {
+ const configData = {"key": key, "value": value};
+ updateOrganisationConfig(organisationId, configData)
+ .then((response) => {
+ getKey("model_api_key");
+ EventBus.emit("keySet", {});
+ toast.success("Settings updated", {autoClose: 1800});
+ })
+ .catch((error) => {
+ console.error('Error updating settings:', error);
+ });
+ }
+
+ const handleModelApiKey = (event) => {
+ setKey(event.target.value);
+ };
+
+ const handleSourceSelect = (index) => {
+ setSource(sources[index]);
+ setSourceDropdown(false);
+ };
+
+ const saveSettings = () => {
+ if (modelApiKey === null || modelApiKey.replace(/\s/g, '') === '') {
+ toast.error("API key is empty", {autoClose: 1800});
+ return
+ }
+ validateLLMApiKey(source, modelApiKey)
+ .then((response) => {
+ if (response.data.status === "success") {
+ updateKey("model_api_key", modelApiKey);
+ updateKey("model_source", source);
+ } else {
+ toast.error("Invalid API key", {autoClose: 1800});
+ }
+ })
+ .catch((error) => {
+ console.error('Error validating API key:', error);
+ toast.error("Failed to validate API key", {autoClose: 1800});
+ });
+ };
+
+ const handleTemperatureChange = (event) => {
+ setTemperature(event.target.value);
+ };
+
+ return (<>
+
+ </>)
+}
\ No newline at end of file
diff --git a/gui/pages/Dashboard/Settings/Settings.js b/gui/pages/Dashboard/Settings/Settings.js
index 1fe3a88c0..297e6a3d5 100644
--- a/gui/pages/Dashboard/Settings/Settings.js
+++ b/gui/pages/Dashboard/Settings/Settings.js
@@ -1,163 +1,56 @@
-import React, {useState, useEffect, useRef} from 'react';
-import {ToastContainer, toast} from 'react-toastify';
+import React, {useState, useEffect} from 'react';
import 'react-toastify/dist/ReactToastify.css';
-import agentStyles from "@/pages/Content/Agents/Agents.module.css";
-import {getOrganisationConfig, updateOrganisationConfig,validateLLMApiKey} from "@/pages/api/DashboardService";
-import {EventBus} from "@/utils/eventBus";
-import {removeTab, setLocalStorageValue} from "@/utils/utils";
+import styles from "@/pages/Content/Marketplace/Market.module.css";
import Image from "next/image";
+import Model from "@/pages/Dashboard/Settings/Model";
+import Database from "@/pages/Dashboard/Settings/Database";
-export default function Settings({organisationId}) {
- const [modelApiKey, setKey] = useState('');
- const [temperature, setTemperature] = useState(0.5);
- const [sourceDropdown, setSourceDropdown] = useState(false);
- const [source, setSource] = useState('OpenAi');
- const sourceRef = useRef(null);
- const sources = ['OpenAi', 'Google Palm']
-
- function getKey(key) {
- getOrganisationConfig(organisationId, key)
- .then((response) => {
- setKey(response.data.value);
- })
- .catch((error) => {
- console.error('Error fetching project:', error);
- });
- }
-
- function getSource(key) {
- getOrganisationConfig(organisationId, key)
- .then((response) => {
- setSource(response.data.value);
- })
- .catch((error) => {
- console.error('Error fetching project:', error);
- });
- }
+export default function Settings({organisationId, sendDatabaseData}) {
+ const [activeTab, setActiveTab] = useState('model');
useEffect(() => {
- getKey("model_api_key");
- getSource("model_source");
-
- function handleClickOutside(event) {
- if (sourceRef.current && !sourceRef.current.contains(event.target)) {
- setSourceDropdown(false);
- }
- }
-
- document.addEventListener('mousedown', handleClickOutside);
- return () => {
- document.removeEventListener('mousedown', handleClickOutside);
- };
- }, [organisationId]);
-
- function updateKey(key, value) {
- const configData = { "key": key, "value": value };
- return updateOrganisationConfig(organisationId, configData)
- .then((response) => {
- return response.data;
- })
- .catch((error) => {
- console.error('Error updating settings:', error);
- throw new Error('Failed to update settings');
- });
- }
-
- const handleModelApiKey = (event) => {
- setKey(event.target.value);
- };
-
- const handleSourceSelect = (index) => {
- setSource(sources[index]);
- setSourceDropdown(false);
- };
-
- const saveSettings = () => {
- if (modelApiKey === null || modelApiKey.replace(/\s/g, '') === '') {
- toast.error("API key is empty", { autoClose: 1800 });
- return;
+ const settings_tab = localStorage.getItem('settings_tab');
+ if (settings_tab) {
+ setActiveTab(settings_tab);
}
+ }, []);
- validateLLMApiKey(source, modelApiKey)
- .then((response) => {
- if (response.data.status === "success") {
- Promise.all([
- updateKey("model_api_key", modelApiKey),
- updateKey("model_source", source)
- ])
- .then(() => {
- toast.success("Settings updated", { autoClose: 1800 });
- })
- .catch((error) => {
- console.error('Error updating settings:', error);
- toast.error("Failed to update settings", { autoClose: 1800 });
- });
- } else {
- toast.error("Invalid API key", { autoClose: 1800 });
- }
- })
- .catch((error) => {
- console.error('Error validating API key:', error);
- toast.error("Failed to validate API key", { autoClose: 1800 });
- });
- };
-
- const handleTemperatureChange = (event) => {
- setTemperature(event.target.value);
+ const switchTab = (tab) => {
+ setActiveTab(tab);
+ localStorage.setItem('settings_tab', tab);
};
return (<>
-
-
-
-
-
-
Model Source
-
-
setSourceDropdown(!sourceDropdown)}
- style={{width: '100%'}}>
- {source}
+
+
+
+
+
+ switchTab('model')} className={styles.tab_button} style={activeTab === 'model' ? {
+ background: '#454254',
+ paddingRight: '15px'
+ } : {background: 'transparent', paddingRight: '15px'}}>
+ Model
+
- {sourceDropdown &&
- {sources.map((source, index) => (
-
handleSourceSelect(index)}
- style={{padding: '12px 14px', maxWidth: '100%'}}>
- {source}
-
))}
-
}
+
switchTab('database')} className={styles.tab_button}
+ style={activeTab === 'database' ? {
+ background: '#454254',
+ paddingRight: '15px'
+ } : {background: 'transparent', paddingRight: '15px'}}>
+ Database
+
-
- Open-AI/Palm API Key
-
-
- {/*
*/}
-
- removeTab(-3, "Settings", "Settings", 0)} className="secondary_button"
- style={{marginRight: '10px'}}>
- Cancel
-
-
- Update Changes
-
+ {activeTab === 'model' && }
+ {activeTab === 'database' && }
-
-
</>)
}
\ No newline at end of file
diff --git a/gui/pages/Dashboard/SideBar.js b/gui/pages/Dashboard/SideBar.js
index f30912757..55a8413a8 100644
--- a/gui/pages/Dashboard/SideBar.js
+++ b/gui/pages/Dashboard/SideBar.js
@@ -1,7 +1,7 @@
import React, {useState} from 'react';
import Image from 'next/image';
import styles from './Dashboard.module.css';
-import {refreshUrl, openNewTab} from "@/utils/utils";
+import {openNewTab} from "@/utils/utils";
export default function SideBar({onSelectEvent}) {
const [sectionSelected, setSelection] = useState('');
@@ -10,7 +10,7 @@ export default function SideBar({onSelectEvent}) {
setSelection(value);
onSelectEvent(value);
if (value === 'apm') {
- openNewTab(-9, "APM", "APM");
+ openNewTab(-9, "APM", "APM", false);
}
};
@@ -36,11 +36,19 @@ export default function SideBar({onSelectEvent}) {
handleClick(sectionSelected !== 'apm' ? 'apm' : '')}
- className={`${styles.section} ${sectionSelected === 'apm' ? styles.selected : ''}`}>
+ className={styles.section}>
APM
+
+
handleClick(sectionSelected !== 'knowledge' ? 'knowledge' : '')}
+ className={`${styles.section} ${sectionSelected === 'knowledge' ? styles.selected : ''}`}>
+
+
Knowledge
+
+
);
-}
+}
\ No newline at end of file
diff --git a/gui/pages/_app.css b/gui/pages/_app.css
index 96f8f250b..33905c51d 100644
--- a/gui/pages/_app.css
+++ b/gui/pages/_app.css
@@ -399,7 +399,7 @@ input[type="range"]::-moz-range-track {
align-self: stretch;
margin-top: 2px;
max-height: 200px;
- overflow-y: scroll;
+ overflow-y: auto;
box-shadow: 0 2px 7px rgba(0,0,0,.4), 0 0 2px rgba(0,0,0,.22);
}
@@ -432,6 +432,11 @@ input[type="range"]::-moz-range-track {
border-radius: 8px;
}
+.custom_select_option:hover {
+ background: #3B3B49;
+ border-radius: 8px;
+}
+
.custom_select_option:active, .create_agent_dropdown_options:active {
background: #3B3B49;
border-radius: 8px;
@@ -532,6 +537,25 @@ p {
background-color: transparent;
}
+.three_dots_vertical {
+ width: 32px;
+ border: none;
+ font-size: 14px;
+ color: white;
+ height: 32px;
+ background: transparent;
+ padding: 15px;
+ display: -webkit-flex;
+ flex-direction: row;
+ align-items: center;
+ justify-content: center;
+ gap: 6px;
+ -webkit-box-orient: vertical;
+ -webkit-line-clamp: 1;
+ overflow: hidden;
+ text-overflow: ellipsis;
+}
+
.sideBarStyle {
height: 100vh;
width: 6.5vw;
@@ -647,6 +671,27 @@ p {
background: rgba(255, 255, 255, 0.10);
}
+.title_wrapper {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+}
+
+.database_box {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ border-radius: 8px;
+ background-color: #282437;
+}
+
+.index_options {
+ padding: 12px 14px;
+ max-width: 100%;
+ display: flex;
+ justify-content: space-between;
+}
+
.mt_6{margin-top: 6px;}
.mt_8{margin-top: 8px;}
.mt_10{margin-top: 10px;}
@@ -674,6 +719,7 @@ p {
.ml_6{margin-left: 6px;}
.ml_8{margin-left: 8px;}
.ml_10{margin-left: 10px;}
+.ml_12{margin-left: 12px;}
.ml_16{margin-left: 16px;}
.ml_20{margin-left: 20px;}
@@ -765,6 +811,14 @@ p {
flex-direction: column;
}
+.center_container{
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ justify-content: center;
+ width: 100%;
+}
+
.w_100{width: 100%}
.w_fit_content{width:fit-content}
.h_100{height: 100%}
@@ -998,4 +1052,81 @@ tr{
.tools_used_tooltip{
position: relative;
cursor: pointer;
-}
\ No newline at end of file
+}
+
+.image_class{
+ background: #FFFFFF80;
+ border-radius: 20px;
+}
+
+.loader {
+ font-size: 2px;
+ width: 1em;
+ height: 1em;
+ border-radius: 50%;
+ position: relative;
+ text-indent: -9999em;
+ -webkit-animation: load5 1.1s infinite ease;
+ animation: load5 1.1s infinite ease;
+ -webkit-transform: translateZ(0);
+ -ms-transform: translateZ(0);
+ transform: translateZ(0);
+}
+
+@-webkit-keyframes load5 {
+ 0%,
+ 100% {
+ box-shadow: 0em -2.6em 0em 0em #231f1f, 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.5), -1.8em -1.8em 0 0em rgba(35,31,31, 0.7);
+ }
+ 12.5% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.7), 1.8em -1.8em 0 0em #231f1f, 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.5);
+ }
+ 25% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.5), 1.8em -1.8em 0 0em rgba(35,31,31, 0.7), 2.5em 0em 0 0em #231f1f, 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 37.5% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.5), 2.5em 0em 0 0em rgba(35,31,31, 0.7), 1.75em 1.75em 0 0em #231f1f, 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 50% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.5), 1.75em 1.75em 0 0em rgba(35,31,31, 0.7), 0em 2.5em 0 0em #231f1f, -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 62.5% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.5), 0em 2.5em 0 0em rgba(35,31,31, 0.7), -1.8em 1.8em 0 0em #231f1f, -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 75% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.5), -1.8em 1.8em 0 0em rgba(35,31,31, 0.7), -2.6em 0em 0 0em #231f1f, -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 87.5% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.5), -2.6em 0em 0 0em rgba(35,31,31, 0.7), -1.8em -1.8em 0 0em #231f1f;
+ }
+}
+@keyframes load5 {
+ 0%,
+ 100% {
+ box-shadow: 0em -2.6em 0em 0em #231f1f, 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.5), -1.8em -1.8em 0 0em rgba(35,31,31, 0.7);
+ }
+ 12.5% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.7), 1.8em -1.8em 0 0em #231f1f, 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.5);
+ }
+ 25% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.5), 1.8em -1.8em 0 0em rgba(35,31,31, 0.7), 2.5em 0em 0 0em #231f1f, 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 37.5% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.5), 2.5em 0em 0 0em rgba(35,31,31, 0.7), 1.75em 1.75em 0 0em #231f1f, 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 50% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.5), 1.75em 1.75em 0 0em rgba(35,31,31, 0.7), 0em 2.5em 0 0em #231f1f, -1.8em 1.8em 0 0em rgba(35,31,31, 0.2), -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 62.5% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.5), 0em 2.5em 0 0em rgba(35,31,31, 0.7), -1.8em 1.8em 0 0em #231f1f, -2.6em 0em 0 0em rgba(35,31,31, 0.2), -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 75% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.5), -1.8em 1.8em 0 0em rgba(35,31,31, 0.7), -2.6em 0em 0 0em #231f1f, -1.8em -1.8em 0 0em rgba(35,31,31, 0.2);
+ }
+ 87.5% {
+ box-shadow: 0em -2.6em 0em 0em rgba(35,31,31, 0.2), 1.8em -1.8em 0 0em rgba(35,31,31, 0.2), 2.5em 0em 0 0em rgba(35,31,31, 0.2), 1.75em 1.75em 0 0em rgba(35,31,31, 0.2), 0em 2.5em 0 0em rgba(35,31,31, 0.2), -1.8em 1.8em 0 0em rgba(35,31,31, 0.5), -2.6em 0em 0 0em rgba(35,31,31, 0.7), -1.8em -1.8em 0 0em #231f1f;
+ }
+}
+
+.justify_space_between{justify-content: space-between}
+.cursor_pointer{cursor: pointer;}
diff --git a/gui/pages/_app.js b/gui/pages/_app.js
index 0b199d7b8..d10d5adf9 100644
--- a/gui/pages/_app.js
+++ b/gui/pages/_app.js
@@ -14,7 +14,7 @@ import {
validateAccessToken,
checkEnvironment,
addUser,
- installToolkitTemplate, installAgentTemplate
+ installToolkitTemplate, installAgentTemplate, installKnowledgeTemplate
} from "@/pages/api/DashboardService";
import {githubClientId} from "@/pages/api/apiConfig";
import {useRouter} from 'next/router';
@@ -33,7 +33,18 @@ export default function App() {
const [loadingText, setLoadingText] = useState("Initializing SuperAGI");
const router = useRouter();
const [showMarketplace, setShowMarketplace] = useState(false);
- const excludedKeys = ['repo_starred', 'popup_closed_time', 'twitter_toolkit_id', 'accessToken', 'agent_to_install', 'toolkit_to_install', 'google_calendar_toolkit_id', 'myLayoutKey'];
+ const excludedKeys = [
+ 'repo_starred',
+ 'popup_closed_time',
+ 'twitter_toolkit_id',
+ 'accessToken',
+ 'agent_to_install',
+ 'toolkit_to_install',
+ 'google_calendar_toolkit_id',
+ 'knowledge_to_install',
+ 'knowledge_index_to_install',
+ 'myLayoutKey'
+ ];
function fetchOrganisation(userId) {
getOrganisation(userId)
@@ -48,6 +59,20 @@ export default function App() {
const installFromMarketplace = () => {
const toolkitName = localStorage.getItem('toolkit_to_install') || null;
const agentTemplateId = localStorage.getItem('agent_to_install') || null;
+ const knowledgeTemplateName = localStorage.getItem('knowledge_to_install') || null;
+ const knowledgeIndexId = localStorage.getItem('knowledge_index_to_install') || null;
+
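+ // Complete any knowledge-template install queued from the marketplace, then clear its localStorage markers.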
+ if (knowledgeTemplateName !== null && knowledgeIndexId !== null) {
+ installKnowledgeTemplate(knowledgeTemplateName, knowledgeIndexId)
+ .then((response) => {
+ toast.success("Template installed", {autoClose: 1800});
+ })
+ .catch((error) => {
+ console.error('Error installing template:', error);
+ });
+ localStorage.removeItem('knowledge_to_install');
+ localStorage.removeItem('knowledge_index_to_install');
+ }
if (toolkitName !== null) {
installToolkitTemplate(toolkitName)
@@ -221,4 +246,4 @@ export default function App() {
) : true}
);
-}
+}
\ No newline at end of file
diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js
index b787d1933..32f994ff3 100644
--- a/gui/pages/api/DashboardService.js
+++ b/gui/pages/api/DashboardService.js
@@ -56,6 +56,10 @@ export const updateExecution = (executionId, executionData) => {
return api.put(`/agentexecutions/update/${executionId}`, executionData);
};
+export const editAgentTemplate = (agentTemplateId, agentTemplateData) => {
+ return api.put(`/agent_templates/update_agent_template/${agentTemplateId}`, agentTemplateData)
+};
+
export const addExecution = (executionData) => {
return api.post(`/agentexecutions/add`, executionData);
};
@@ -70,143 +74,212 @@ export const getLastActiveAgent = (projectId) => {
export const uploadFile = (agentId, formData) => {
return api.post(`/resources/add/${agentId}`, formData);
-}
+};
export const validateAccessToken = () => {
return api.get(`/validate-access-token`);
-}
+};
export const validateLLMApiKey = (model_source, model_api_key) => {
- return api.post(`/validate-llm-api-key`,{model_source, model_api_key});
-}
+ return api.post(`/validate-llm-api-key`, {model_source, model_api_key});
+};
+
export const checkEnvironment = () => {
return api.get(`/configs/get/env`);
-}
+};
export const getOrganisationConfig = (organisationId, key) => {
return api.get(`/configs/get/organisation/${organisationId}/key/${key}`);
-}
+};
export const updateOrganisationConfig = (organisationId, configData) => {
return api.post(`/configs/add/organisation/${organisationId}`, configData);
-}
+};
export const fetchAgentTemplateList = () => {
return api.get('/agent_templates/list?template_source=marketplace');
-}
+};
export const fetchAgentTemplateDetails = (templateId) => {
return api.get(`/agent_templates/get/${templateId}`);
-}
+};
export const getToolConfig = (toolKitName) => {
return api.get(`/tool_configs/get/toolkit/${toolKitName}`);
-}
+};
export const updateToolConfig = (toolKitName, configData) => {
return api.post(`/tool_configs/add/${toolKitName}`, configData);
-}
+};
export const fetchAgentTemplateListLocal = () => {
return api.get('/agent_templates/list?template_source=local');
-}
+};
export const saveAgentAsTemplate = (agentId) => {
return api.post(`/agent_templates/save_agent_as_template/${agentId}`);
-}
+};
export const fetchAgentTemplateConfig = (templateId) => {
return api.get(`/agent_templates/get/${templateId}?template_source=marketplace`);
-}
+};
export const installAgentTemplate = (templateId) => {
return api.post(`/agent_templates/download?agent_template_id=${templateId}`);
-}
+};
export const fetchAgentTemplateConfigLocal = (templateId) => {
return api.get(`/agent_templates/agent_config?agent_template_id=${templateId}`);
-}
+};
export const updatePermissions = (permissionId, data) => {
return api.put(`/agentexecutionpermissions/update/status/${permissionId}`, data)
-}
+};
export const deleteAgent = (agentId) => {
return api.put(`/agents/delete/${agentId}`)
-}
+};
export const authenticateGoogleCred = (toolKitId) => {
return api.get(`/google/get_google_creds/toolkit_id/${toolKitId}`);
-}
+};
export const authenticateTwitterCred = (toolKitId) => {
return api.get(`/twitter/get_twitter_creds/toolkit_id/${toolKitId}`);
-}
+};
export const sendTwitterCreds = (twitter_creds) => {
return api.post(`/twitter/send_twitter_creds/${twitter_creds}`);
-}
+};
export const sendGoogleCreds = (google_creds, toolkit_id) => {
return api.post(`/google/send_google_creds/toolkit_id/${toolkit_id}`, google_creds);
-}
+};
export const fetchToolTemplateList = () => {
return api.get(`/toolkits/get/list?page=0`);
-}
+};
+
+export const fetchKnowledgeTemplateList = () => {
+ return api.get(`/knowledges/get/list?page=0`);
+};
export const fetchToolTemplateOverview = (toolTemplateName) => {
return api.get(`/toolkits/marketplace/readme/${toolTemplateName}`);
-}
+};
export const installToolkitTemplate = (templateName) => {
return api.get(`/toolkits/get/install/${templateName}`);
-}
+};
export const getExecutionDetails = (executionId, agentId) => {
return api.get(`/agent_executions_configs/details/agent/${agentId}/agent_execution/${executionId}`);
-}
+};
export const stopSchedule = (agentId) => {
return api.post(`/agents/stop/schedule?agent_id=${agentId}`);
-}
+};
export const createAndScheduleRun = (requestData) => {
return api.post(`/agentexecutions/schedule`, requestData);
-}
+};
export const agentScheduleComponent = (agentId) => {
return api.get(`/agents/get/schedule_data/${agentId}`);
-}
+};
export const updateSchedule = (requestData) => {
return api.put(`/agents/edit/schedule`, requestData);
-}
+};
export const getDateTime = (agentId) => {
return api.get(`/agents/get/schedule_data/${agentId}`);
-}
+};
export const getMetrics = () => {
return api.get(`/analytics/metrics`)
-}
+};
export const getAllAgents = () => {
return api.get(`/analytics/agents/all`)
-}
+};
export const getAgentRuns = (agent_id) => {
return api.get(`analytics/agents/${agent_id}`);
-}
+};
export const getActiveRuns = () => {
return api.get(`analytics/runs/active`);
-}
+};
export const getToolsUsage = () => {
return api.get(`analytics/tools/used`);
-}
+};
export const getLlmModels = () => {
return api.get(`organisations/llm_models`);
-}
\ No newline at end of file
+};
+
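+// Vector database endpoints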
+export const fetchVectorDBList = () => {
+ return api.get(`/vector_dbs/get/list`);
+};
+
+export const getVectorDatabases = () => {
+ return api.get(`/vector_dbs/user/list`);
+};
+
+export const getVectorDBDetails = (vectorDBId) => {
+ return api.get(`/vector_dbs/db/details/${vectorDBId}`);
+};
+
+export const deleteVectorDB = (vectorDBId) => {
+ return api.post(`/vector_dbs/delete/${vectorDBId}`);
+};
+
+export const updateVectorDB = (vectorDBId, newIndices) => {
+ return api.put(`/vector_dbs/update/vector_db/${vectorDBId}`, newIndices);
+};
+
+export const connectPinecone = (pineconeData) => {
+ return api.post(`/vector_dbs/connect/pinecone`, pineconeData);
+};
+
+export const connectQdrant = (qdrantData) => {
+ return api.post(`/vector_dbs/connect/qdrant`, qdrantData);
+};
+
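+// Knowledge endpoints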
+export const getKnowledge = () => {
+ return api.get(`/knowledges/user/list`);
+};
+
+export const getKnowledgeDetails = (knowledgeId) => {
+ return api.get(`/knowledges/user/get/details/${knowledgeId}`);
+};
+
+export const deleteCustomKnowledge = (knowledgeId) => {
+ return api.post(`/knowledges/delete/${knowledgeId}`);
+};
+
+export const deleteMarketplaceKnowledge = (knowledgeName) => {
+ return api.post(`/knowledges/uninstall/${knowledgeName}`);
+};
+
+export const addUpdateKnowledge = (knowledgeData) => {
+ return api.post(`/knowledges/add_or_update/data`, knowledgeData);
+};
+
+export const getValidIndices = () => {
+ return api.get(`/vector_db_indices/user/valid_indices`);
+};
+
+export const getValidMarketplaceIndices = (knowledgeName) => {
+ return api.get(`/vector_db_indices/marketplace/valid_indices/${knowledgeName}`);
+};
+
+export const fetchKnowledgeTemplateOverview = (knowledgeName) => {
+ return api.get(`/knowledges/marketplace/get/details/${knowledgeName}`);
+};
+
+export const installKnowledgeTemplate = (knowledgeName, indexId) => {
+ return api.get(`/knowledges/install/${knowledgeName}/index/${indexId}`);
+};
diff --git a/gui/public/images/books.svg b/gui/public/images/books.svg
new file mode 100644
index 000000000..6e58566d9
--- /dev/null
+++ b/gui/public/images/books.svg
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
diff --git a/gui/public/images/database.svg b/gui/public/images/database.svg
new file mode 100644
index 000000000..20b434477
--- /dev/null
+++ b/gui/public/images/database.svg
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
diff --git a/gui/public/images/instagram.png b/gui/public/images/instagram.png
new file mode 100644
index 000000000..c1a8ecb59
Binary files /dev/null and b/gui/public/images/instagram.png differ
diff --git a/gui/public/images/knowledeg_logo.png b/gui/public/images/knowledeg_logo.png
new file mode 100644
index 000000000..d2845fe42
Binary files /dev/null and b/gui/public/images/knowledeg_logo.png differ
diff --git a/gui/public/images/knowledge.svg b/gui/public/images/knowledge.svg
new file mode 100644
index 000000000..9c5987888
--- /dev/null
+++ b/gui/public/images/knowledge.svg
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
diff --git a/gui/public/images/model_light.svg b/gui/public/images/model_light.svg
new file mode 100644
index 000000000..ac6eec58f
--- /dev/null
+++ b/gui/public/images/model_light.svg
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
diff --git a/gui/public/images/pinecone.svg b/gui/public/images/pinecone.svg
new file mode 100644
index 000000000..c04cd6660
--- /dev/null
+++ b/gui/public/images/pinecone.svg
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
diff --git a/gui/public/images/plus_symbol.svg b/gui/public/images/plus_symbol.svg
new file mode 100644
index 000000000..ff8ce19e7
--- /dev/null
+++ b/gui/public/images/plus_symbol.svg
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
diff --git a/gui/public/images/qdrant.svg b/gui/public/images/qdrant.svg
new file mode 100644
index 000000000..16d38464a
--- /dev/null
+++ b/gui/public/images/qdrant.svg
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/gui/public/images/stack.svg b/gui/public/images/stack.svg
new file mode 100644
index 000000000..91f0085f5
--- /dev/null
+++ b/gui/public/images/stack.svg
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
diff --git a/gui/public/images/three_dots_vertical.svg b/gui/public/images/three_dots_vertical.svg
new file mode 100644
index 000000000..a0ba011ed
--- /dev/null
+++ b/gui/public/images/three_dots_vertical.svg
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
diff --git a/gui/utils/utils.js b/gui/utils/utils.js
index 4cca143fd..c18d40bc4 100644
--- a/gui/utils/utils.js
+++ b/gui/utils/utils.js
@@ -5,6 +5,26 @@ import {EventBus} from "@/utils/eventBus";
import JSZip from "jszip";
import moment from 'moment';
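+// Maps toolkit display names to icon paths; returnToolkitIcon falls back to custom_tool.svg for unknown toolkits.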
+const toolkitData = {
+ 'Jira Toolkit': '/images/jira_icon.svg',
+ 'Email Toolkit': '/images/gmail_icon.svg',
+ 'Google Calendar Toolkit': '/images/google_calender_icon.svg',
+ 'GitHub Toolkit': '/images/github_icon.svg',
+ 'Google Search Toolkit': '/images/google_search_icon.svg',
+ 'Searx Toolkit': '/images/searx_icon.svg',
+ 'Slack Toolkit': '/images/slack_icon.svg',
+ 'Web Scrapper Toolkit': '/images/webscraper_icon.svg',
+ 'Twitter Toolkit': '/images/twitter_icon.svg',
+ 'Google SERP Toolkit': '/images/google_serp_icon.svg',
+ 'File Toolkit': '/images/filemanager_icon.svg',
+ 'CodingToolkit': '/images/app-logo-light.png',
+ 'Thinking Toolkit': '/images/app-logo-light.png',
+ 'Image Generation Toolkit': '/images/app-logo-light.png',
+ 'DuckDuckGo Search Toolkit': '/images/duckduckgo_icon.png',
+ 'Instagram Toolkit': '/images/instagram.png',
+ 'Knowledge Search Toolkit': '/images/knowledeg_logo.png'
+};
+
export const getUserTimezone = () => {
return Intl.DateTimeFormat().resolvedOptions().timeZone;
}
@@ -86,7 +106,7 @@ export const formatBytes = (bytes, decimals = 2) => {
const formattedValue = parseFloat((bytes / Math.pow(k, i)).toFixed(decimals));
return `${formattedValue} ${sizes[i]}`;
-}
+};
export const downloadFile = (fileId, fileName = null) => {
const authToken = localStorage.getItem('accessToken');
@@ -181,7 +201,8 @@ export const refreshUrl = () => {
return;
}
- const urlWithoutToken = window.location.origin + window.location.pathname;
+ const {origin, pathname} = window.location;
+ const urlWithoutToken = origin + pathname;
window.history.replaceState({}, document.title, urlWithoutToken);
};
@@ -195,35 +216,35 @@ export const loadingTextEffect = (loadingText, setLoadingText, timer) => {
}, timer);
return () => clearInterval(interval)
-}
+};
-export const openNewTab = (id, name, contentType, hasInternalId) => {
+export const openNewTab = (id, name, contentType, hasInternalId = false) => {
EventBus.emit('openNewTab', {
element: {id: id, name: name, contentType: contentType, internalId: hasInternalId ? createInternalId() : 0}
});
-}
+};
export const removeTab = (id, name, contentType, internalId) => {
EventBus.emit('removeTab', {
element: {id: id, name: name, contentType: contentType, internalId: internalId}
});
-}
+};
export const setLocalStorageValue = (key, value, stateFunction) => {
stateFunction(value);
localStorage.setItem(key, value);
-}
+};
export const setLocalStorageArray = (key, value, stateFunction) => {
stateFunction(value);
const arrayString = JSON.stringify(value);
localStorage.setItem(key, arrayString);
-}
+};
const getInternalIds = () => {
const internal_ids = localStorage.getItem("agi_internal_ids");
return internal_ids ? JSON.parse(internal_ids) : [];
-}
+};
const removeAgentInternalId = (internalId) => {
let idsArray = getInternalIds();
@@ -259,8 +280,12 @@ const removeAgentInternalId = (internalId) => {
localStorage.removeItem("agent_time_unit_" + String(internalId));
localStorage.removeItem("agent_time_value_" + String(internalId));
localStorage.removeItem("agent_is_recurring_" + String(internalId));
+ localStorage.removeItem("is_agent_template_" + String(internalId));
+ localStorage.removeItem("agent_template_id_" + String(internalId));
+ localStorage.removeItem("agent_knowledge_" + String(internalId));
+ localStorage.removeItem("agent_knowledge_id_" + String(internalId));
}
-}
+};
const removeAddToolkitInternalId = (internalId) => {
let idsArray = getInternalIds();
@@ -271,7 +296,7 @@ const removeAddToolkitInternalId = (internalId) => {
localStorage.setItem('agi_internal_ids', JSON.stringify(idsArray));
localStorage.removeItem('tool_github_' + String(internalId));
}
-}
+};
const removeToolkitsInternalId = (internalId) => {
let idsArray = getInternalIds();
@@ -283,6 +308,49 @@ const removeToolkitsInternalId = (internalId) => {
localStorage.removeItem('toolkit_tab_' + String(internalId));
localStorage.removeItem('api_configs_' + String(internalId));
}
+};
+
+const removeKnowledgeInternalId = (internalId) => {
+ let idsArray = getInternalIds();
+ const internalIdIndex = idsArray.indexOf(internalId);
+
+ if (internalIdIndex !== -1) {
+ idsArray.splice(internalIdIndex, 1);
+ localStorage.setItem('agi_internal_ids', JSON.stringify(idsArray));
+ localStorage.removeItem('knowledge_name_' + String(internalId));
+ localStorage.removeItem('knowledge_description_' + String(internalId));
+ localStorage.removeItem('knowledge_index_' + String(internalId));
+ }
+};
+
+const removeAddDatabaseInternalId = (internalId) => {
+ let idsArray = getInternalIds();
+ const internalIdIndex = idsArray.indexOf(internalId);
+
+ if (internalIdIndex !== -1) {
+ idsArray.splice(internalIdIndex, 1);
+ localStorage.setItem('agi_internal_ids', JSON.stringify(idsArray));
+ localStorage.removeItem('add_database_tab_' + String(internalId));
+ localStorage.removeItem('selected_db_' + String(internalId));
+ localStorage.removeItem('db_name_' + String(internalId));
+ localStorage.removeItem('db_collections_' + String(internalId));
+ localStorage.removeItem('pinecone_api_' + String(internalId));
+ localStorage.removeItem('pinecone_env_' + String(internalId));
+ localStorage.removeItem('qdrant_api_' + String(internalId));
+ localStorage.removeItem('qdrant_url_' + String(internalId));
+ localStorage.removeItem('qdrant_port_' + String(internalId));
+ }
+};
+
+const removeDatabaseInternalId = (internalId) => {
+ let idsArray = getInternalIds();
+ const internalIdIndex = idsArray.indexOf(internalId);
+
+ if (internalIdIndex !== -1) {
+ idsArray.splice(internalIdIndex, 1);
+ localStorage.setItem('agi_internal_ids', JSON.stringify(idsArray));
+ localStorage.removeItem('db_details_collections_' + String(internalId));
+ }
}
export const resetLocalStorage = (contentType, internalId) => {
@@ -302,10 +370,25 @@ export const resetLocalStorage = (contentType, internalId) => {
case 'Toolkits':
removeToolkitsInternalId(internalId);
break;
+ case 'Knowledge':
+ case 'Add_Knowledge':
+ removeKnowledgeInternalId(internalId);
+ break;
+ case 'Add_Database':
+ removeAddDatabaseInternalId(internalId);
+ break;
+ case 'Database':
+ removeDatabaseInternalId(internalId);
+ break;
+ case 'Settings':
+ localStorage.removeItem('settings_tab');
+ break;
default:
break;
}
-}
+};
export const createInternalId = () => {
let newId = 1;
@@ -326,56 +409,49 @@ export const createInternalId = () => {
}
return newId;
-}
+};
export const returnToolkitIcon = (toolkitName) => {
- const toolkitData = [
- {name: 'Jira Toolkit', imageSrc: '/images/jira_icon.svg'},
- {name: 'Email Toolkit', imageSrc: '/images/gmail_icon.svg'},
- {name: 'Google Calendar Toolkit', imageSrc: '/images/google_calender_icon.svg'},
- {name: 'GitHub Toolkit', imageSrc: '/images/github_icon.svg'},
- {name: 'Google Search Toolkit', imageSrc: '/images/google_search_icon.svg'},
- {name: 'Searx Toolkit', imageSrc: '/images/searx_icon.svg'},
- {name: 'Slack Toolkit', imageSrc: '/images/slack_icon.svg'},
- {name: 'Web Scrapper Toolkit', imageSrc: '/images/webscraper_icon.svg'},
- {name: 'Twitter Toolkit', imageSrc: '/images/twitter_icon.svg'},
- {name: 'Google SERP Toolkit', imageSrc: '/images/google_serp_icon.svg'},
- {name: 'File Toolkit', imageSrc: '/images/filemanager_icon.svg'},
- {name: 'CodingToolkit', imageSrc: '/images/app-logo-light.png'},
- {name: 'Image Generation Toolkit', imageSrc: '/images/app-logo-light.png'},
- {name: 'DuckDuckGo Search Toolkit', imageSrc: '/images/duckduckgo_icon.png'},
- ];
-
- const toolkit = toolkitData.find((tool) => tool.name === toolkitName);
- return toolkit ? toolkit.imageSrc : '/images/custom_tool.svg';
-}
+ return toolkitData[toolkitName] || '/images/custom_tool.svg';
+};
export const returnResourceIcon = (file) => {
- let fileIcon;
- const fileTypeIcons = {
- 'application/pdf': '/images/pdf_file.svg',
- 'application/txt': '/images/txt_file.svg',
- 'text/plain': '/images/txt_file.svg',
- };
-
- if (file.type.includes('image')) {
- fileIcon = '/images/img_file.svg';
- } else {
- fileIcon = fileTypeIcons[file.type] || '/images/default_file.svg';
+ const fileType = file.type;
+
+ switch (true) {
+ case fileType.includes('image'):
+ return '/images/img_file.svg';
+ case fileType === 'application/pdf':
+ return '/images/pdf_file.svg';
+ case fileType === 'application/txt' || fileType === 'text/plain':
+ return '/images/txt_file.svg';
+ default:
+ return '/images/default_file.svg';
}
+};
+
+export const returnDatabaseIcon = (database) => {
+ const dbTypeIcons = {
+ 'Pinecone': '/images/pinecone.svg',
+ 'Qdrant': '/images/qdrant.svg'
+ };
- return fileIcon;
+ return dbTypeIcons[database];
};
export const convertToTitleCase = (str) => {
- if (str === null || str === '') {
+ if (!str) {
return '';
}
const words = str.toLowerCase().split('_');
const capitalizedWords = words.map((word) => word.charAt(0).toUpperCase() + word.slice(1));
return capitalizedWords.join(' ');
-}
+};
+
+export const preventDefault = (e) => {
+ e.stopPropagation();
+};
export const excludedToolkits = () => {
return ["Thinking Toolkit", "Human Input Toolkit", "Resource Toolkit"];
diff --git a/install_tool_dependencies.sh b/install_tool_dependencies.sh
index 87dc10c2e..7f37ee224 100755
--- a/install_tool_dependencies.sh
+++ b/install_tool_dependencies.sh
@@ -4,7 +4,7 @@
pip install -r /app/requirements.txt
# Loop through the tools directories and install their requirements.txt if they exist
-for tool in /app/superagi/tools/* ; do
+for tool in /app/superagi/tools/* /app/superagi/tools/external_tools/* /app/superagi/tools/marketplace_tools/* ; do
if [ -d "$tool" ] && [ -f "$tool/requirements.txt" ]; then
echo "Installing requirements for tool: $(basename "$tool")"
pip install -r "$tool/requirements.txt"
diff --git a/main.py b/main.py
index 8fe2d2be7..7aa70033e 100644
--- a/main.py
+++ b/main.py
@@ -42,7 +42,12 @@
from superagi.controllers.user import router as user_router
from superagi.controllers.agent_execution_config import router as agent_execution_config
from superagi.controllers.analytics import router as analytics_router
-from superagi.helper.tool_helper import register_toolkits
+from superagi.controllers.knowledges import router as knowledges_router
+from superagi.controllers.knowledge_configs import router as knowledge_configs_router
+from superagi.controllers.vector_dbs import router as vector_dbs_router
+from superagi.controllers.vector_db_indices import router as vector_db_indices_router
+from superagi.controllers.marketplace_stats import router as marketplace_stats_router
+from superagi.helper.tool_helper import register_toolkits, register_marketplace_toolkits
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
from superagi.llms.openai import OpenAi
@@ -114,8 +119,13 @@
app.include_router(twitter_oauth_router, prefix="/twitter")
app.include_router(agent_execution_config, prefix="/agent_executions_configs")
app.include_router(analytics_router, prefix="/analytics")
-
app.include_router(google_oauth_router, prefix="/google")
+app.include_router(knowledges_router, prefix="/knowledges")
+app.include_router(knowledge_configs_router, prefix="/knowledge_configs")
+app.include_router(vector_dbs_router, prefix="/vector_dbs")
+app.include_router(vector_db_indices_router, prefix="/vector_db_indices")
+app.include_router(marketplace_stats_router, prefix="/marketplace")
+
# in production you can use Settings management
# from pydantic to get secret key from .env
@@ -128,6 +138,8 @@ def create_access_token(email, Authorize: AuthJWT = Depends()):
expiry_time_hours = superagi.config.config.get_config("JWT_EXPIRY")
if type(expiry_time_hours) == str:
expiry_time_hours = int(expiry_time_hours)
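+    # Fall back to a 200-hour token lifetime when JWT_EXPIRY is not set in the config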
+ if expiry_time_hours is None:
+ expiry_time_hours = 200
expires = timedelta(hours=expiry_time_hours)
access_token = Authorize.create_access_token(subject=email, expires_time=expires)
return access_token
@@ -180,14 +192,14 @@ def build_single_step_agent():
agent_workflow_id=agent_workflow.id, output_type="tools",
step_type="TRIGGER",
history_enabled=True,
- completion_prompt="Determine which next tool to use, and respond using the format specified above:")
+                                              completion_prompt="Determine which next tool to use, and respond with only valid JSON conforming to the above schema")
session.add(first_step)
session.commit()
else:
first_step.prompt = output["prompt"]
first_step.variables = str(output["variables"])
first_step.output_type = "tools"
- first_step.completion_prompt = "Determine which next tool to use, and respond using the format specified above:"
+        first_step.completion_prompt = "Determine which next tool to use, and respond with only valid JSON conforming to the above schema"
session.commit()
first_step.next_step_id = first_step.id
session.commit()
@@ -314,17 +326,26 @@ def build_action_based_agents():
workflow_step2.next_step_id = workflow_step2.id
session.commit()
- def check_toolkit_registration():
+ def register_toolkit_for_all_organisation():
organizations = session.query(Organisation).all()
for organization in organizations:
register_toolkits(session, organization)
logger.info("Successfully registered local toolkits for all Organisations!")
+ def register_toolkit_for_master_organisation():
+ marketplace_organisation_id = superagi.config.config.get_config("MARKETPLACE_ORGANISATION_ID")
+ marketplace_organisation = session.query(Organisation).filter(
+ Organisation.id == marketplace_organisation_id).first()
+ if marketplace_organisation is not None:
+ register_marketplace_toolkits(session, marketplace_organisation)
+
build_single_step_agent()
build_task_based_agents()
build_action_based_agents()
if env != "PROD":
- check_toolkit_registration()
+ register_toolkit_for_all_organisation()
+ else:
+ register_toolkit_for_master_organisation()
session.close()
@@ -396,7 +417,6 @@ def github_auth_handler(code: str = Query(...), Authorize: AuthJWT = Depends()):
db.session.add(user)
db.session.commit()
jwt_token = create_access_token(user_email, Authorize)
-
redirect_url_success = f"{frontend_url}?access_token={jwt_token}"
return RedirectResponse(url=redirect_url_success)
else:
diff --git a/migrations/versions/71e3980d55f5_knowledge_and_vector_dbs.py b/migrations/versions/71e3980d55f5_knowledge_and_vector_dbs.py
new file mode 100644
index 000000000..ebc0660b0
--- /dev/null
+++ b/migrations/versions/71e3980d55f5_knowledge_and_vector_dbs.py
@@ -0,0 +1,87 @@
+"""Knowledge and Vector dbs
+
+Revision ID: 71e3980d55f5
+Revises: cac478732572
+Create Date: 2023-07-26 07:18:06.492832
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '71e3980d55f5'
+down_revision = 'cac478732572'
+branch_labels = None
+depends_on = None
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('knowledge_configs',
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+ sa.Column('knowledge_id', sa.Integer(), nullable=False),
+ sa.Column('key', sa.String(), nullable=True),
+ sa.Column('value', sa.Text(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_table('knowledges',
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+ sa.Column('name', sa.String(), nullable=False),
+ sa.Column('description', sa.String(), nullable=True),
+ sa.Column('vector_db_index_id', sa.Integer(), nullable=True),
+ sa.Column('organisation_id', sa.Integer(), nullable=True),
+ sa.Column('contributed_by', sa.String(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_table('marketplace_stats',
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+ sa.Column('reference_id', sa.Integer(), nullable=True),
+ sa.Column('reference_name', sa.String(), nullable=True),
+ sa.Column('key', sa.String(), nullable=True),
+ sa.Column('value', sa.Integer(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_table('vector_db_configs',
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+ sa.Column('vector_db_id', sa.Integer(), nullable=False),
+ sa.Column('key', sa.String(), nullable=True),
+ sa.Column('value', sa.Text(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_table('vector_db_indices',
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+ sa.Column('name', sa.String(), nullable=False),
+ sa.Column('vector_db_id', sa.Integer(), nullable=True),
+ sa.Column('dimensions', sa.Integer(), nullable=True),
+ sa.Column('state', sa.String(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ op.create_table('vector_dbs',
+ sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+ sa.Column('name', sa.String(), nullable=False),
+ sa.Column('db_type', sa.String(), nullable=True),
+ sa.Column('organisation_id', sa.Integer(), nullable=True),
+ sa.Column('created_at', sa.DateTime(), nullable=True),
+ sa.Column('updated_at', sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('vector_dbs')
+ op.drop_table('vector_db_indices')
+ op.drop_table('vector_db_configs')
+ op.drop_table('knowledges')
+ op.drop_table('knowledge_configs')
\ No newline at end of file
diff --git a/superagi/apm/tools_handler.py b/superagi/apm/tools_handler.py
index 71b0937ce..72048d529 100644
--- a/superagi/apm/tools_handler.py
+++ b/superagi/apm/tools_handler.py
@@ -4,14 +4,22 @@
from sqlalchemy.orm import Session
from superagi.models.events import Event
+from superagi.models.tool import Tool
+from superagi.models.toolkit import Toolkit
class ToolsHandler:
-
def __init__(self, session: Session, organisation_id: int):
self.session = session
self.organisation_id = organisation_id
+ def get_tool_and_toolkit(self):
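+        # Build a {tool_name: toolkit_name} lookup so each usage row can be tagged with
+        # its parent toolkit, e.g. {"WriteFile": "File Toolkit"} (illustrative names).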
+ tools_and_toolkits = self.session.query(
+ Tool.name.label('tool_name'), Toolkit.name.label('toolkit_name')).join(
+ Toolkit, Tool.toolkit_id == Toolkit.id).all()
+
+ return {item.tool_name: item.toolkit_name for item in tools_and_toolkits}
+
def calculate_tool_usage(self) -> List[Dict[str, int]]:
tool_usage = []
tool_used_subquery = self.session.query(
@@ -32,15 +40,18 @@ def calculate_tool_usage(self) -> List[Dict[str, int]]:
query = self.session.query(
agent_count.c.tool_name,
agent_count.c.unique_agents,
- total_usage.c.total_usage
+ total_usage.c.total_usage,
).join(total_usage, total_usage.c.tool_name == agent_count.c.tool_name)
+ tool_and_toolkit = self.get_tool_and_toolkit()
+
result = query.all()
tool_usage = [{
'tool_name': row.tool_name,
'unique_agents': row.unique_agents,
- 'total_usage': row.total_usage
+ 'total_usage': row.total_usage,
+ 'toolkit': tool_and_toolkit.get(row.tool_name, None)
} for row in result]
return tool_usage
\ No newline at end of file
diff --git a/superagi/controllers/agent.py b/superagi/controllers/agent.py
index b287b05bc..aa4056895 100644
--- a/superagi/controllers/agent.py
+++ b/superagi/controllers/agent.py
@@ -30,6 +30,7 @@
import json
from superagi.models.toolkit import Toolkit
+from superagi.models.knowledges import Knowledges
from sqlalchemy import func
# from superagi.types.db import AgentOut, AgentIn
@@ -287,6 +288,15 @@ def create_and_schedule_agent(agent_config_schedule: AgentConfigSchedule,
if agent_schedule.id is None:
raise HTTPException(status_code=500, detail="Failed to schedule agent")
+    agent = db.session.query(Agent).filter(Agent.id == db_agent.id).first()
+ organisation = agent.get_agent_organisation(db.session)
+
+ EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_config.name,
+ 'model': agent_config.model}, db_agent.id,
+ organisation.id if organisation else 0)
+
+ db.session.commit()
+
return {
"id": db_agent.id,
"name": db_agent.name,
@@ -471,13 +481,18 @@ def get_agent_configuration(agent_id: int,
AgentExecution.agent_id == agent_id).scalar()
total_tokens = db.session.query(func.sum(AgentExecution.num_of_tokens)).filter(
AgentExecution.agent_id == agent_id).scalar()
-
+
+ name = ""
# Construct the JSON response
response = {result.key: result.value for result in results}
+    if 'knowledge' in response and response['knowledge'] != 'None':
+ knowledge = db.session.query(Knowledges).filter(Knowledges.id == response['knowledge']).first()
+ name = knowledge.name if knowledge is not None else ""
response = merge(response, {"name": agent.name, "description": agent.description,
# Query the AgentConfiguration table for the speci
"goal": eval(response["goal"]),
"instruction": eval(response.get("instruction", '[]')),
+ "knowledge_name": name,
"calls": total_calls,
"tokens": total_tokens,
"constraints": eval(response.get("constraints")),
@@ -511,7 +526,8 @@ def delete_agent(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
db_agent = db.session.query(Agent).filter(Agent.id == agent_id).first()
db_agent_executions = db.session.query(AgentExecution).filter(AgentExecution.agent_id == agent_id).all()
-
+ db_agent_schedule = db.session.query(AgentSchedule).filter(AgentSchedule.agent_id == agent_id, AgentSchedule.status == "SCHEDULED").first()
+
if not db_agent or db_agent.is_deleted:
raise HTTPException(status_code=404, detail="agent not found")
@@ -522,4 +538,8 @@ def delete_agent(agent_id: int, Authorize: AuthJWT = Depends(check_auth)):
for db_agent_execution in db_agent_executions:
db_agent_execution.status = "TERMINATED"
+ if db_agent_schedule:
+ # Updating the schedule status to STOPPED
+ db_agent_schedule.status = "STOPPED"
+
db.session.commit()
diff --git a/superagi/controllers/agent_template.py b/superagi/controllers/agent_template.py
index 140622aee..100bee8d0 100644
--- a/superagi/controllers/agent_template.py
+++ b/superagi/controllers/agent_template.py
@@ -13,6 +13,7 @@
from superagi.models.agent_template_config import AgentTemplateConfig
from superagi.models.agent_workflow import AgentWorkflow
from superagi.models.tool import Tool
+import json
# from superagi.types.db import AgentTemplateIn, AgentTemplateOut
router = APIRouter()
@@ -144,6 +145,59 @@ def update_agent_template(agent_template_id: int,
return db_agent_template
+@router.put("/update_agent_template/{agent_template_id}", status_code=200)
+def edit_agent_template(agent_template_id: int,
+ updated_agent_configs: dict,
+ organisation=Depends(get_user_organisation)):
+
+ """
+ Update the details of an agent template.
+
+ Args:
+ agent_template_id (int): The ID of the agent template to update.
+        updated_agent_configs (dict): The updated agent configurations.
+ organisation (Depends): Dependency to get the user organisation.
+
+ Returns:
+        None: Responds with HTTP 200 when the agent template is successfully updated.
+
+ Raises:
+ HTTPException (status_code=404): If the agent template is not found.
+ """
+
+ db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id,
+ AgentTemplate.id == agent_template_id).first()
+ if db_agent_template is None:
+ raise HTTPException(status_code=404, detail="Agent Template not found")
+
+ db_agent_template.name = updated_agent_configs["name"]
+ db_agent_template.description = updated_agent_configs["description"]
+
+ db.session.commit()
+
+ agent_config_values = updated_agent_configs.get('agent_configs', {})
+
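+    # Upsert each config: serialise lists/dicts to JSON, update the row when the key
+    # already exists, otherwise insert a new AgentTemplateConfig.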
+ for key, value in agent_config_values.items():
+ if isinstance(value, (list, dict)):
+ value = json.dumps(value)
+ config = db.session.query(AgentTemplateConfig).filter(
+ AgentTemplateConfig.agent_template_id == agent_template_id,
+ AgentTemplateConfig.key == key
+ ).first()
+
+ if config is not None:
+ config.value = value
+ else:
+ new_config = AgentTemplateConfig(
+ agent_template_id=agent_template_id,
+ key=key,
+                value=value
+ )
+ db.session.add(new_config)
+
+ db.session.commit()
+
@router.post("/save_agent_as_template/{agent_id}")
def save_agent_as_template(agent_id: str,
diff --git a/superagi/controllers/knowledge_configs.py b/superagi/controllers/knowledge_configs.py
new file mode 100644
index 000000000..86a2abc63
--- /dev/null
+++ b/superagi/controllers/knowledge_configs.py
@@ -0,0 +1,15 @@
+from fastapi_sqlalchemy import db
+from fastapi import HTTPException, Depends, Query, status
+from fastapi import APIRouter
+from superagi.config.config import get_config
+from superagi.helper.auth import check_auth
+from superagi.models.knowledge_configs import KnowledgeConfigs
+from fastapi_jwt_auth import AuthJWT
+
+router = APIRouter()
+
+@router.get("/marketplace/details/{knowledge_id}")
+def get_marketplace_knowledge_configs(knowledge_id: int, Authorize: AuthJWT = Depends(check_auth)):
+ knowledge_configs = db.session.query(KnowledgeConfigs).filter(KnowledgeConfigs.knowledge_id == knowledge_id).all()
+ return knowledge_configs
+
diff --git a/superagi/controllers/knowledges.py b/superagi/controllers/knowledges.py
new file mode 100644
index 000000000..3afac761f
--- /dev/null
+++ b/superagi/controllers/knowledges.py
@@ -0,0 +1,166 @@
+from fastapi_sqlalchemy import db
+from fastapi import HTTPException, Depends, Query, status
+from fastapi import APIRouter
+from datetime import datetime
+from superagi.config.config import get_config
+from superagi.helper.auth import get_user_organisation
+from superagi.models.knowledges import Knowledges
+from superagi.models.marketplace_stats import MarketPlaceStats
+from superagi.models.knowledge_configs import KnowledgeConfigs
+from superagi.models.vector_db_indices import VectordbIndices
+from superagi.models.vector_dbs import Vectordbs
+from superagi.helper.s3_helper import S3Helper
+from superagi.models.vector_db_configs import VectordbConfigs
+from superagi.vector_store.vector_factory import VectorFactory
+from superagi.vector_embeddings.vector_embedding_factory import VectorEmbeddingFactory
+from superagi.helper.time_helper import get_time_difference
+
+router = APIRouter()
+
+@router.get("/get/list")
+def get_knowledge_list(
+ page: int = Query(None, title="Page Number"),
+ organisation = Depends(get_user_organisation)
+):
+ """
+ Get Marketplace Knowledge list.
+
+ Args:
+ page (int, optional): The page number for pagination. Defaults to None.
+
+ Returns:
+ dict: The response containing the marketplace list.
+
+ """
+
+ marketplace_knowledges = Knowledges.fetch_marketplace_list(page)
+ marketplace_knowledges_with_install = Knowledges.get_knowledge_install_details(db.session, marketplace_knowledges, organisation)
+ for knowledge in marketplace_knowledges_with_install:
+ knowledge["install_number"] = MarketPlaceStats.get_knowledge_installation_number(knowledge["id"])
+ return marketplace_knowledges_with_install
+
+@router.get("/marketplace/list/{page}")
+def get_marketplace_knowledge_list(page: int = 0):
+ organisation_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
+ page_size = 30
+
+    # Fetch the marketplace organisation's knowledges
+    query = db.session.query(Knowledges).filter(Knowledges.organisation_id == organisation_id)
+
+    if page < 0:
+        # A negative page returns the full, unpaginated list
+        knowledges = query.all()
+    else:
+        # Paginate the results
+        knowledges = query.offset(page * page_size).limit(page_size).all()
+
+ return knowledges
+
+@router.get("/user/list")
+def get_user_knowledge_list(organisation = Depends(get_user_organisation)):
+ marketplace_knowledges = Knowledges.fetch_marketplace_list(page=-1)
+    marketplace_names = {knowledge.name for knowledge in marketplace_knowledges}
+    user_knowledge_list = Knowledges.get_organisation_knowledges(db.session, organisation)
+    for user_knowledge in user_knowledge_list:
+        user_knowledge["is_marketplace"] = user_knowledge["name"] in marketplace_names
+    return user_knowledge_list
+
+@router.get("/marketplace/get/details/{knowledge_name}")
+def get_knowledge_details(knowledge_name: str):
+ knowledge_data = Knowledges.fetch_knowledge_details_marketplace(knowledge_name)
+ knowledge_config_data = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(knowledge_data["id"])
+ knowledge_data_with_config = knowledge_data | knowledge_config_data
+ knowledge_data_with_config["install_number"] = MarketPlaceStats.get_knowledge_installation_number(knowledge_data_with_config["id"])
+ update_time = str(knowledge_data_with_config["updated_at"])
+ update_time = datetime.strptime(update_time, "%Y-%m-%dT%H:%M:%S.%f")
+ knowledge_data_with_config["updated_at"] = datetime.strftime(update_time, '%d %B %Y')
+ return knowledge_data_with_config
+
+@router.get("/marketplace/details/{knowledge_name}")
+def get_marketplace_knowledge_details(knowledge_name: str):
+ organisation_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
+ knowledge_details = db.session.query(Knowledges).filter(Knowledges.name == knowledge_name, Knowledges.organisation_id == organisation_id).first()
+ return knowledge_details
+
+@router.get("/user/get/details/{knowledge_id}")
+def get_user_knowledge_details(knowledge_id: int):
+ knowledge_data = Knowledges.get_knowledge_from_id(db.session, knowledge_id)
+ vector_database_index = VectordbIndices.get_vector_index_from_id(db.session, knowledge_data.vector_db_index_id)
+ vector_database = Vectordbs.get_vector_db_from_id(db.session, vector_database_index.vector_db_id)
+ knowledge = {
+ "name": knowledge_data.name,
+ "description": knowledge_data.description,
+ "vector_database_index": {
+ "id": vector_database_index.id,
+ "name": vector_database_index.name
+ },
+ "vector_database": vector_database.name,
+ "installation_type": vector_database_index.state
+ }
+ knowledge_config = KnowledgeConfigs.get_knowledge_config_from_knowledge_id(db.session, knowledge_id)
+ knowledge_data_with_config = knowledge | knowledge_config
+ return knowledge_data_with_config
+
+@router.post("/add_or_update/data")
+def add_update_user_knowledge(knowledge_data: dict, organisation = Depends(get_user_organisation)):
+ knowledge_data["organisation_id"] = organisation.id
+ knowledge_data["contributed_by"] = organisation.name
+ knowledge = Knowledges.add_update_knowledge(db.session, knowledge_data)
+ if not knowledge:
+ raise HTTPException(status_code=404, detail="Knowledge not found")
+ return {"id": knowledge.id}
+
+
+@router.post("/delete/{knowledge_id}")
+def delete_user_knowledge(knowledge_id: int):
+ try:
+ Knowledges.delete_knowledge(db.session, knowledge_id)
+ except:
+ raise HTTPException(status_code=404, detail="Knowledge not found")
+
+@router.get("/install/{knowledge_name}/index/{vector_db_index_id}")
+def install_selected_knowledge(knowledge_name: str, vector_db_index_id: int, organisation = Depends(get_user_organisation)):
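+    # Install flow: read the knowledge chunks from S3, embed and upsert them into the
+    # selected vector index, mirror the knowledge and its configs locally, mark the
+    # index as "Marketplace", and bump the install counter.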
+ vector_db_index = VectordbIndices.get_vector_index_from_id(db.session, vector_db_index_id)
+ selected_knowledge = Knowledges.fetch_knowledge_details_marketplace(knowledge_name)
+ selected_knowledge_config = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(selected_knowledge['id'])
+ file_chunks = S3Helper().get_json_file(selected_knowledge_config["file_path"])
+ vector = Vectordbs.get_vector_db_from_id(db.session, vector_db_index.vector_db_id)
+ db_creds = VectordbConfigs.get_vector_db_config_from_db_id(db.session, vector.id)
+ upsert_data = VectorEmbeddingFactory.build_vector_storage(vector.db_type, file_chunks).get_vector_embeddings_from_chunks()
+ try:
+ vector_db_storage = VectorFactory.build_vector_storage(vector.db_type, vector_db_index.name, **db_creds)
+ vector_db_storage.add_embeddings_to_vector_db(upsert_data)
+ except Exception as err:
+        raise HTTPException(status_code=400, detail=str(err))
+ selected_knowledge_data = {
+ "id": -1,
+ "name": selected_knowledge["name"],
+ "description": selected_knowledge["description"],
+ "index_id": vector_db_index_id,
+ "organisation_id": organisation.id,
+ "contributed_by": selected_knowledge["contributed_by"],
+ }
+ new_knowledge = Knowledges.add_update_knowledge(db.session, selected_knowledge_data)
+    selected_knowledge_config.pop('file_path')
+    KnowledgeConfigs.add_update_knowledge_config(db.session, new_knowledge.id, selected_knowledge_config)
+ VectordbIndices.update_vector_index_state(db.session, vector_db_index_id, "Marketplace")
+ install_number = MarketPlaceStats.get_knowledge_installation_number(selected_knowledge["id"])
+ MarketPlaceStats.update_knowledge_install_number(db.session, selected_knowledge["id"], int(install_number) + 1)
+
+@router.post("/uninstall/{knowledge_name}")
+def uninstall_selected_knowledge(knowledge_name: str, organisation = Depends(get_user_organisation)):
+ knowledge = db.session.query(Knowledges).filter(Knowledges.name == knowledge_name, Knowledges.organisation_id == organisation.id).first()
+ knowledge_config = KnowledgeConfigs.get_knowledge_config_from_knowledge_id(db.session, knowledge.id)
+ vector_ids = eval(knowledge_config["vector_ids"])
+ vector_db_index = VectordbIndices.get_vector_index_from_id(db.session, knowledge.vector_db_index_id)
+ vector = Vectordbs.get_vector_db_from_id(db.session, vector_db_index.vector_db_id)
+ db_creds = VectordbConfigs.get_vector_db_config_from_db_id(db.session, vector.id)
+ try:
+ vector_db_storage = VectorFactory.build_vector_storage(vector.db_type, vector_db_index.name, **db_creds)
+ vector_db_storage.delete_embeddings_from_vector_db(vector_ids)
+ except Exception as err:
+        raise HTTPException(status_code=400, detail=str(err))
+ KnowledgeConfigs.delete_knowledge_config(db.session, knowledge.id)
+ Knowledges.delete_knowledge(db.session, knowledge.id)
\ No newline at end of file
diff --git a/superagi/controllers/marketplace_stats.py b/superagi/controllers/marketplace_stats.py
new file mode 100644
index 000000000..1fca38a2f
--- /dev/null
+++ b/superagi/controllers/marketplace_stats.py
@@ -0,0 +1,17 @@
+from fastapi_sqlalchemy import db
+from fastapi import HTTPException, Depends, Query, status
+from fastapi import APIRouter
+from superagi.config.config import get_config
+from superagi.models.marketplace_stats import MarketPlaceStats
+from superagi.models.vector_dbs import Vectordbs
+
+router = APIRouter()
+
+@router.get("/knowledge/downloads/{knowledge_id}")
+def count_knowledge_downloads(knowledge_id: int):
+ download_number = db.session.query(MarketPlaceStats).filter(MarketPlaceStats.reference_id == knowledge_id, MarketPlaceStats.reference_name == "KNOWLEDGE", MarketPlaceStats.key == "download_count").first()
+ if download_number is None:
+ downloads = 0
+ else:
+ downloads = download_number.value
+ return downloads
\ No newline at end of file
diff --git a/superagi/controllers/toolkit.py b/superagi/controllers/toolkit.py
index e04200cd4..f6ae7aa23 100644
--- a/superagi/controllers/toolkit.py
+++ b/superagi/controllers/toolkit.py
@@ -6,10 +6,12 @@
from fastapi_sqlalchemy import db
from superagi.config.config import get_config
from superagi.helper.auth import get_user_organisation
-from superagi.helper.tool_helper import get_readme_content_from_code_link, download_tool,process_files,add_tool_to_json
+from superagi.helper.tool_helper import get_readme_content_from_code_link, download_tool, process_files, \
+ add_tool_to_json
from superagi.helper.github_helper import GithubHelper
from superagi.models.organisation import Organisation
from superagi.models.tool import Tool
+from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
from superagi.types.common import GitHubLinkRequest
@@ -20,7 +22,7 @@
# marketplace_url = "http://localhost:8001/"
-#For internal use
+# For internal use
@router.get("/marketplace/list/{page}")
def get_marketplace_toolkits(
page: int = 0,
@@ -51,7 +53,8 @@ def get_marketplace_toolkits(
toolkit.updated_at = toolkit.updated_at.strftime('%d-%b-%Y').upper()
return toolkits
-#For internal use
+
+# For internal use
@router.get("/marketplace/details/{toolkit_name}")
def get_marketplace_toolkit_detail(toolkit_name: str):
"""
@@ -66,10 +69,14 @@ def get_marketplace_toolkit_detail(toolkit_name: str):
"""
organisation_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
- toolkit = db.session.query(Toolkit).filter(Toolkit.organisation_id == organisation_id, Toolkit.name == toolkit_name).first()
+ toolkit = db.session.query(Toolkit).filter(Toolkit.organisation_id == organisation_id,
+ Toolkit.name == toolkit_name).first()
+ toolkit.tools = db.session.query(Tool).filter(Tool.toolkit_id == toolkit.id).all()
+ toolkit.configs = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit.id).all()
return toolkit
-#For internal use
+
+# For internal use
@router.get("/marketplace/readme/{toolkit_name}")
def get_marketplace_toolkit_readme(toolkit_name: str):
"""
@@ -93,7 +100,8 @@ def get_marketplace_toolkit_readme(toolkit_name: str):
raise HTTPException(status_code=404, detail='ToolKit not found')
return get_readme_content_from_code_link(toolkit.tool_code_link)
-#For internal use
+
+# For internal use
@router.get("/marketplace/tools/{toolkit_name}")
def get_marketplace_toolkit_tools(toolkit_name: str):
"""
@@ -111,7 +119,8 @@ def get_marketplace_toolkit_tools(toolkit_name: str):
"""
organisation_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
- toolkit = db.session.query(Toolkit).filter(Toolkit.name == toolkit_name, Toolkit.organisation_id == organisation_id).first()
+ toolkit = db.session.query(Toolkit).filter(Toolkit.name == toolkit_name,
+ Toolkit.organisation_id == organisation_id).first()
if not toolkit:
raise HTTPException(status_code=404, detail="ToolKit not found")
tools = db.session.query(Tool).filter(Tool.toolkit_id == toolkit.id).first()
@@ -132,15 +141,18 @@ def install_toolkit_from_marketplace(toolkit_name: str,
dict: A message indicating the successful installation of the tool kit.
"""
-
# Check if the tool kit exists
toolkit = Toolkit.fetch_marketplace_detail(search_str="details",
toolkit_name=toolkit_name)
- # download_and_install_tool(GitHubLinkRequest(github_link=toolkit['tool_code_link']),
- # organisation=organisation)
- if not GithubHelper.validate_github_link(toolkit['tool_code_link']):
- raise HTTPException(status_code=400, detail="Invalid Github link")
- add_tool_to_json(toolkit['tool_code_link'])
+ db_toolkit = Toolkit.add_or_update(session=db.session, name=toolkit['name'], description=toolkit['description'],
+ tool_code_link=toolkit['tool_code_link'], organisation_id=organisation.id,
+ show_toolkit=toolkit['show_toolkit'])
+ for tool in toolkit['tools']:
+ Tool.add_or_update(session=db.session, tool_name=tool['name'], description=tool['description'],
+ folder_name=tool['folder_name'], class_name=tool['class_name'], file_name=tool['file_name'],
+ toolkit_id=db_toolkit.id)
+ for config in toolkit['configs']:
+ ToolConfig.add_or_update(session=db.session, toolkit_id=db_toolkit.id, key=config['key'], value=config['value'])
return {"message": "ToolKit installed successfully"}
@@ -228,6 +240,7 @@ def get_installed_toolkit_readme(toolkit_name: str, organisation: Organisation =
readme_content = get_readme_content_from_code_link(toolkit.tool_code_link)
return readme_content
+
# Following APIs will be used to get marketplace related information
@router.get("/get")
def handle_marketplace_operations(
@@ -289,4 +302,4 @@ def get_installed_toolkit_list(organisation: Organisation = Depends(get_user_org
toolkit_tools = db.session.query(Tool).filter(Tool.toolkit_id == toolkit.id).all()
toolkit.tools = toolkit_tools
- return toolkits
\ No newline at end of file
+ return toolkits
diff --git a/superagi/controllers/types/agent_with_config.py b/superagi/controllers/types/agent_with_config.py
index 59f234b00..b5b19e4d9 100644
--- a/superagi/controllers/types/agent_with_config.py
+++ b/superagi/controllers/types/agent_with_config.py
@@ -19,3 +19,4 @@ class AgentConfigInput(BaseModel):
LTM_DB: str
max_iterations: int
user_timezone: Optional[str]
+ knowledge: Optional[int]
diff --git a/superagi/controllers/vector_db_indices.py b/superagi/controllers/vector_db_indices.py
new file mode 100644
index 000000000..667bc953e
--- /dev/null
+++ b/superagi/controllers/vector_db_indices.py
@@ -0,0 +1,45 @@
+from fastapi_sqlalchemy import db
+from fastapi import HTTPException, Depends, Query
+from fastapi import APIRouter
+from superagi.helper.auth import get_user_organisation
+from superagi.models.vector_dbs import Vectordbs
+from superagi.models.vector_db_indices import VectordbIndices
+from superagi.models.knowledges import Knowledges
+from superagi.models.knowledge_configs import KnowledgeConfigs
+
+router = APIRouter()
+
+@router.get("/marketplace/valid_indices/{knowledge_name}")
+def get_marketplace_valid_indices(knowledge_name: str, organisation = Depends(get_user_organisation)):
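+    # An index qualifies for a marketplace knowledge only when its dimensions match the
+    # knowledge embeddings and it is not already holding custom data.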
+ vector_dbs = Vectordbs.get_vector_db_from_organisation(db.session, organisation)
+ knowledge = Knowledges.fetch_knowledge_details_marketplace(knowledge_name)
+ knowledge_with_config = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(knowledge['id'])
+ pinecone = []
+ qdrant = []
+ for vector_db in vector_dbs:
+ indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db.id)
+ for index in indices:
+ data = {"id": index.id, "name": index.name}
+ data["is_valid_dimension"] = True if index.dimensions == int(knowledge_with_config["dimensions"]) else False
+ data["is_valid_state"] = True if index.state != "Custom" else False
+ if vector_db.db_type == "Pinecone":
+ pinecone.append(data)
+ if vector_db.db_type == "Qdrant":
+ qdrant.append(data)
+ return {"pinecone": pinecone, "qdrant": qdrant}
+
+@router.get("/user/valid_indices")
+def get_user_valid_indices(organisation = Depends(get_user_organisation)):
+ vector_dbs = Vectordbs.get_vector_db_from_organisation(db.session, organisation)
+ pinecone = []
+ qdrant = []
+ for vector_db in vector_dbs:
+ indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db.id)
+ for index in indices:
+ data = {"id": index.id, "name": index.name}
+ data["is_valid_state"] = True if index.state == "Custom" else False
+ if vector_db.db_type == "Pinecone":
+ pinecone.append(data)
+ if vector_db.db_type == "Qdrant":
+ qdrant.append(data)
+ return {"pinecone": pinecone, "qdrant": qdrant}
\ No newline at end of file
diff --git a/superagi/controllers/vector_dbs.py b/superagi/controllers/vector_dbs.py
new file mode 100644
index 000000000..66e49ff63
--- /dev/null
+++ b/superagi/controllers/vector_dbs.py
@@ -0,0 +1,128 @@
+from fastapi_sqlalchemy import db
+from fastapi import HTTPException, Depends
+from fastapi import APIRouter
+from superagi.config.config import get_config
+from datetime import datetime
+from superagi.helper.time_helper import get_time_difference
+from superagi.models.vector_dbs import Vectordbs
+from superagi.helper.auth import get_user_organisation
+from superagi.models.vector_db_configs import VectordbConfigs
+from superagi.models.vector_db_indices import VectordbIndices
+from superagi.vector_store.vector_factory import VectorFactory
+from superagi.models.knowledges import Knowledges
+
+router = APIRouter()
+
+@router.get("/get/list")
+def get_vector_db_list():
+ marketplace_vector_dbs = Vectordbs.fetch_marketplace_list()
+ return marketplace_vector_dbs
+
+@router.get("/marketplace/list")
+def get_marketplace_vectordb_list():
+ organisation_id = int(get_config("MARKETPLACE_ORGANISATION_ID"))
+ vector_dbs = db.session.query(Vectordbs).filter(Vectordbs.organisation_id == organisation_id).all()
+ return vector_dbs
+
+@router.get("/user/list")
+def get_user_connected_vector_db_list(organisation = Depends(get_user_organisation)):
+ vector_db_list = Vectordbs.get_vector_db_from_organisation(db.session, organisation)
+ if vector_db_list:
+ for vector in vector_db_list:
+ vector.updated_at = get_time_difference(vector.updated_at, str(datetime.now()))
+ return vector_db_list
+
+@router.get("/db/details/{vector_db_id}")
+def get_vector_db_details(vector_db_id: int):
+ vector_db = Vectordbs.get_vector_db_from_id(db.session, vector_db_id)
+ vector_db_data = {
+ "id": vector_db.id,
+ "name": vector_db.name,
+ "db_type": vector_db.db_type
+ }
+ vector_db_config = VectordbConfigs.get_vector_db_config_from_db_id(db.session, vector_db_id)
+ vector_db_with_config = vector_db_data | vector_db_config
+ indices = db.session.query(VectordbIndices).filter(VectordbIndices.vector_db_id == vector_db_id).all()
+ vector_indices = []
+ for index in indices:
+ vector_indices.append(index.name)
+ vector_db_with_config["indices"] = vector_indices
+ return vector_db_with_config
+
+@router.post("/delete/{vector_db_id}")
+def delete_vector_db(vector_db_id: int):
+ try:
+ vector_indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db_id)
+ for vector_index in vector_indices:
+ Knowledges.delete_knowledge_from_vector_index(db.session, vector_index.id)
+ VectordbIndices.delete_vector_db_index(db.session, vector_index.id)
+ VectordbConfigs.delete_vector_db_configs(db.session, vector_db_id)
+ Vectordbs.delete_vector_db(db.session, vector_db_id)
+ except:
+ raise HTTPException(status_code=404, detail="VectorDb not found")
+
+@router.post("/connect/pinecone")
+def connect_pinecone_vector_db(data: dict, organisation = Depends(get_user_organisation)):
+ db_creds = {
+ "api_key": data["api_key"],
+ "environment": data["environment"]
+ }
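+    # Validate every collection before persisting anything; an index that already
+    # holds vectors is marked "Custom".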
+    index_stats = {}
+    for collection in data["collections"]:
+        try:
+            vector_db_storage = VectorFactory.build_vector_storage("pinecone", collection, **db_creds)
+            index_stats[collection] = vector_db_storage.get_index_stats()
+        except:
+            raise HTTPException(status_code=400, detail="Unable to connect Pinecone")
+    pinecone_db = Vectordbs.add_vector_db(db.session, data["name"], "Pinecone", organisation)
+    VectordbConfigs.add_vector_db_config(db.session, pinecone_db.id, db_creds)
+    for collection in data["collections"]:
+        stats = index_stats[collection]
+        index_state = "Custom" if stats["vector_count"] > 0 else "None"
+        VectordbIndices.add_vector_index(db.session, collection, pinecone_db.id, stats["dimensions"], index_state)
+ return {"id": pinecone_db.id, "name": pinecone_db.name}
+
+@router.post("/connect/qdrant")
+def connect_qdrant_vector_db(data: dict, organisation = Depends(get_user_organisation)):
+ db_creds = {
+ "api_key": data["api_key"],
+ "url": data["url"],
+ "port": data["port"]
+ }
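+    # Same validation as Pinecone: check every collection first; non-empty indices are marked "Custom".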
+    index_stats = {}
+    for collection in data["collections"]:
+        try:
+            vector_db_storage = VectorFactory.build_vector_storage("qdrant", collection, **db_creds)
+            index_stats[collection] = vector_db_storage.get_index_stats()
+        except:
+            raise HTTPException(status_code=400, detail="Unable to connect Qdrant")
+    qdrant_db = Vectordbs.add_vector_db(db.session, data["name"], "Qdrant", organisation)
+    VectordbConfigs.add_vector_db_config(db.session, qdrant_db.id, db_creds)
+    for collection in data["collections"]:
+        stats = index_stats[collection]
+        index_state = "Custom" if stats["vector_count"] > 0 else "None"
+        VectordbIndices.add_vector_index(db.session, collection, qdrant_db.id, stats["dimensions"], index_state)
+
+ return {"id": qdrant_db.id, "name": qdrant_db.name}
+
+@router.put("/update/vector_db/{vector_db_id}")
+def update_vector_db(new_indices: list, vector_db_id: int):
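+    # Reconcile indices: drop existing indices that are missing from new_indices,
+    # then register any newly added ones with fresh stats from the vector store.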
+ vector_db = Vectordbs.get_vector_db_from_id(db.session, vector_db_id)
+ existing_indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db_id)
+ existing_index_names = []
+ for index in existing_indices:
+ if index.name not in new_indices:
+ VectordbIndices.delete_vector_db_index(db.session, vector_index_id=index.id)
+ existing_index_names.append(index.name)
+ existing_index_names = set(existing_index_names)
+ new_indices_names = set(new_indices)
+ added_indices = new_indices_names - existing_index_names
+ for index in added_indices:
+ db_creds = VectordbConfigs.get_vector_db_config_from_db_id(db.session, vector_db_id)
+ try:
+ vector_db_storage = VectorFactory.build_vector_storage(vector_db.db_type, index, **db_creds)
+ vector_db_index_stats = vector_db_storage.get_index_stats()
+ index_state = "Custom" if vector_db_index_stats["vector_count"] > 0 else "None"
+ except:
+ raise HTTPException(status_code=400, detail="Unable to update vector db")
+ VectordbIndices.add_vector_index(db.session, index, vector_db_id, vector_db_index_stats["dimensions"], index_state)
\ No newline at end of file
diff --git a/superagi/helper/resource_helper.py b/superagi/helper/resource_helper.py
index cd39eeb14..961cbd998 100644
--- a/superagi/helper/resource_helper.py
+++ b/superagi/helper/resource_helper.py
@@ -1,16 +1,17 @@
+import os
+
from superagi.config.config import get_config
+from superagi.helper.s3_helper import S3Helper
+from superagi.lib.logger import logger
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.models.resource import Resource
-import os
-import datetime
-from superagi.lib.logger import logger
from superagi.types.storage_types import StorageType
class ResourceHelper:
@classmethod
- def make_written_file_resource(cls, file_name: str, agent: Agent, agent_execution: AgentExecution):
+ def make_written_file_resource(cls, file_name: str, agent: Agent, agent_execution: AgentExecution, session):
"""
Function to create a Resource object for a written file.
@@ -18,6 +19,7 @@ def make_written_file_resource(cls, file_name: str, agent: Agent, agent_executio
file_name (str): The name of the file.
agent (Agent): Agent related to resource.
agent_execution(AgentExecution): Agent Execution related to a resource
+ session (Session): The database session.
Returns:
Resource: The Resource object.
@@ -46,13 +48,36 @@ def make_written_file_resource(cls, file_name: str, agent: Agent, agent_executio
logger.info("make_written_file_resource:", final_path)
if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
file_path = "resources" + file_path
- resource = Resource(name=file_name, path=file_path, storage_type=storage_type.value,
- size=file_size,
- type=file_type,
- channel="OUTPUT",
- agent_id=agent.id,
- agent_execution_id=agent_execution.id)
- return resource
+ existing_resource = session.query(Resource).filter_by(
+ name=file_name,
+ path=file_path,
+ storage_type=storage_type.value,
+ type=file_type,
+ channel="OUTPUT",
+ agent_id=agent.id,
+ agent_execution_id=agent_execution.id
+ ).first()
+
+ if existing_resource:
+ # Update the existing resource attributes
+ existing_resource.size = file_size
+ session.commit()
+ return existing_resource
+ else:
+ resource = Resource(
+ name=file_name,
+ path=file_path,
+ storage_type=storage_type.value,
+ size=file_size,
+ type=file_type,
+ channel="OUTPUT",
+ agent_id=agent.id,
+ agent_execution_id=agent_execution.id
+ )
+ session.add(resource)
+ session.commit()
+ return resource
@classmethod
def get_formatted_agent_level_path(cls, agent: Agent, path) -> object:
@@ -118,6 +143,15 @@ def get_agent_write_resource_path(cls, file_name: str, agent: Agent, agent_execu
final_path = root_dir + file_name
return final_path
+ @staticmethod
+ def __check_file_path_exists(path):
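+        # True when the resource is missing from the configured backend: an S3 object
+        # that does not exist, or a local path that does not exist for FILE storage.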
+ return (StorageType.get_storage_type(get_config("STORAGE_TYPE",
+ StorageType.FILE.value)) is StorageType.S3 and
+ not S3Helper().check_file_exists_in_s3(path)) or (
+ StorageType.get_storage_type(
+ get_config("STORAGE_TYPE", StorageType.FILE.value)) is StorageType.FILE
+ and not os.path.exists(path))
+
@classmethod
def get_agent_read_resource_path(cls, file_name, agent: Agent, agent_execution: AgentExecution):
"""Get agent resource path to read files i.e. both input and output directory
@@ -128,13 +162,13 @@ def get_agent_read_resource_path(cls, file_name, agent: Agent, agent_execution:
agent (Agent): The agent corresponding to resource.
agent_execution (AgentExecution): The agent execution corresponding to the resource.
"""
- output_root_dir = ResourceHelper.get_root_output_dir()
final_path = ResourceHelper.get_root_input_dir() + file_name
if "{agent_id}" in final_path:
final_path = ResourceHelper.get_formatted_agent_level_path(
agent=agent,
path=final_path)
- if final_path is None or not os.path.exists(final_path):
+ output_root_dir = ResourceHelper.get_root_output_dir()
+ if final_path is None or cls.__check_file_path_exists(final_path):
if output_root_dir is not None:
final_path = ResourceHelper.get_root_output_dir() + file_name
if "{agent_id}" in final_path:
diff --git a/superagi/helper/s3_helper.py b/superagi/helper/s3_helper.py
index 2dee0a84e..8fe481a2f 100644
--- a/superagi/helper/s3_helper.py
+++ b/superagi/helper/s3_helper.py
@@ -2,6 +2,8 @@
from superagi.config.config import get_config
from fastapi import HTTPException
from superagi.lib.logger import logger
+import json
+
class S3Helper:
def __init__(self):
@@ -9,12 +11,22 @@ def __init__(self):
Initialize the S3Helper class.
Using the AWS credentials from the configuration file, create a boto3 client.
"""
- self.s3 = boto3.client(
+ self.s3 = S3Helper.__get_s3_client()
+ self.bucket_name = get_config("BUCKET_NAME")
+
+ @classmethod
+ def __get_s3_client(cls):
+ """
+ Get an S3 client.
+
+ Returns:
+            s3 (botocore.client.S3): A boto3 S3 client built from the configured AWS credentials.
+ """
+ return boto3.client(
's3',
aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"),
)
- self.bucket_name = get_config("BUCKET_NAME")
def upload_file(self, file, path):
"""
@@ -35,3 +47,32 @@ def upload_file(self, file, path):
logger.info("File uploaded to S3 successfully!")
except:
raise HTTPException(status_code=500, detail="AWS credentials not found. Check your configuration.")
+
+ def check_file_exists_in_s3(self, file_path):
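+        # list_objects_v2 with the "resources" prefix doubles as a cheap existence check.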
+ response = self.s3.list_objects_v2(Bucket=get_config("BUCKET_NAME"), Prefix="resources" + file_path)
+ return 'Contents' in response
+
+ def read_from_s3(self, file_path):
+ file_path = "resources" + file_path
+ logger.info(f"Reading file from s3: {file_path}")
+ response = self.s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=file_path)
+ if response['ResponseMetadata']['HTTPStatusCode'] == 200:
+ return response['Body'].read().decode('utf-8')
+ raise Exception(f"Error read_from_s3: {response}")
+
+ def get_json_file(self, path):
+ """
+ Get a JSON file from S3.
+ Args:
+ path (str): The path to the JSON file.
+ Raises:
+ HTTPException: If the AWS credentials are not found.
+ Returns:
+ dict: The JSON file.
+ """
+ try:
+ obj = self.s3.get_object(Bucket=self.bucket_name, Key=path)
+ s3_response = obj['Body'].read().decode('utf-8')
+ return json.loads(s3_response)
+        except:
+            raise HTTPException(status_code=500, detail="Failed to fetch the JSON file from S3. Check your AWS configuration.")
diff --git a/superagi/helper/tool_helper.py b/superagi/helper/tool_helper.py
index e8f3d7f1b..b9eaf0865 100644
--- a/superagi/helper/tool_helper.py
+++ b/superagi/helper/tool_helper.py
@@ -118,13 +118,17 @@ def load_module_from_file(file_path):
return module
-def init_tools(folder_path, session, tool_name_to_toolkit):
+def init_tools(folder_paths, session, tool_name_to_toolkit):
# Iterate over all subfolders
- for folder_name in os.listdir(folder_path):
- folder_dir = os.path.join(folder_path, folder_name)
- # Iterate over all files in the subfolder
- if os.path.isdir(folder_dir):
- # sys.path.append(os.path.abspath('superagi/tools/email'))
+ for folder_path in folder_paths:
+ if not os.path.exists(folder_path):
+ continue
+ for folder_name in os.listdir(folder_path):
+ folder_dir = os.path.join(folder_path, folder_name)
+ # Iterate over all files in the subfolder
+ if not os.path.isdir(folder_dir):
+ continue
+ # sys.path.append(os.path.abspath('superagi/tools/email'))
sys.path.append(folder_dir)
for file_name in os.listdir(folder_dir):
file_path = os.path.join(folder_dir, file_name)
@@ -147,15 +151,19 @@ def update_base_tool_class_info(classes, file_name, folder_name, session, tool_n
description=tool_description)
-def init_toolkits(code_link, existing_toolkits, folder_path, organisation, session):
+def init_toolkits(code_link, existing_toolkits, folder_paths, organisation, session):
tool_name_to_toolkit = {}
new_toolkits = []
# Iterate over all subfolders
- for folder_name in os.listdir(folder_path):
- folder_dir = os.path.join(folder_path, folder_name)
+ for folder_path in folder_paths:
+ if not os.path.exists(folder_path):
+ continue
+ for folder_name in os.listdir(folder_path):
+ folder_dir = os.path.join(folder_path, folder_name)
- if os.path.isdir(folder_dir):
- # sys.path.append(os.path.abspath('superagi/tools/email'))
+ if not os.path.isdir(folder_dir):
+ continue
+ # sys.path.append(os.path.abspath('superagi/tools/email'))
sys.path.append(folder_dir)
# Iterate over all files in the subfolder
for file_name in os.listdir(folder_dir):
@@ -214,14 +222,16 @@ def update_base_toolkit_info(classes, code_link, folder_name, new_toolkits, orga
return tool_name_to_toolkit
-def process_files(folder_path, session, organisation, code_link=None):
+def process_files(folder_paths, session, organisation, code_link=None):
existing_toolkits = session.query(Toolkit).filter(Toolkit.organisation_id == organisation.id).all()
- tool_name_to_toolkit = init_toolkits(code_link, existing_toolkits, folder_path, organisation, session)
- init_tools(folder_path, session, tool_name_to_toolkit)
+ tool_name_to_toolkit = init_toolkits(code_link, existing_toolkits, folder_paths, organisation, session)
+ init_tools(folder_paths, session, tool_name_to_toolkit)
def get_readme_content_from_code_link(tool_code_link):
+ if tool_code_link is None:
+ return None
parsed_url = urlparse(tool_code_link)
path_parts = parsed_url.path.split("/")
@@ -240,13 +250,18 @@ def get_readme_content_from_code_link(tool_code_link):
def register_toolkits(session, organisation):
- folder_path = get_config("TOOLS_DIR")
- if folder_path is None:
- folder_path = "superagi/tools"
+ tool_paths = ["superagi/tools", "superagi/tools/external_tools"]
+ # if get_config("ENV", "DEV") == "PROD":
+ # tool_paths.append("superagi/tools/marketplace_tools")
if organisation is not None:
- process_files(folder_path, session, organisation)
- logger.info(f"Toolkits Registered Successfully for Organisation ID : {organisation.id}!")
+ process_files(tool_paths, session, organisation)
+ logger.info(f"Toolkits Registered Successfully for Organisation ID : {organisation.id}!")
+def register_marketplace_toolkits(session, organisation):
+ tool_paths = ["superagi/tools", "superagi/tools/external_tools","superagi/tools/marketplace_tools"]
+ if organisation is not None:
+ process_files(tool_paths, session, organisation)
+ logger.info(f"Marketplace Toolkits Registered Successfully for Organisation ID : {organisation.id}!")
def extract_repo_name(repo_link):
# Extract the repository name from the link
@@ -273,10 +288,12 @@ def add_tool_to_json(repo_link):
def handle_tools_import():
- folder_path = get_config("TOOLS_DIR")
- if folder_path is None:
- folder_path = "superagi/tools"
- for folder_name in os.listdir(folder_path):
- folder_dir = os.path.join(folder_path, folder_name)
- if os.path.isdir(folder_dir):
- sys.path.append(folder_dir)
\ No newline at end of file
+ print("Handling tools import")
+ tool_paths = ["superagi/tools", "superagi/tools/marketplace_tools", "superagi/tools/external_tools"]
+ for tool_path in tool_paths:
+ if not os.path.exists(tool_path):
+ continue
+ for folder_name in os.listdir(tool_path):
+ folder_dir = os.path.join(tool_path, folder_name)
+ if os.path.isdir(folder_dir):
+ sys.path.append(folder_dir)
diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py
index 751944b33..3c2aad9a7 100644
--- a/superagi/jobs/agent_executor.py
+++ b/superagi/jobs/agent_executor.py
@@ -1,10 +1,12 @@
import importlib
+import os
from datetime import datetime, timedelta
from sqlalchemy.orm import sessionmaker
import superagi.worker
from superagi.agent.super_agi import SuperAgi
+from superagi.apm.event_handler import EventHandler
from superagi.config.config import get_config
from superagi.helper.encyption_helper import decrypt_data
from superagi.lib.logger import logger
@@ -32,7 +34,7 @@
from superagi.types.vector_store_types import VectorStoreType
from superagi.vector_store.embedding.openai import OpenAiEmbedding
from superagi.vector_store.vector_factory import VectorFactory
-from superagi.apm.event_handler import EventHandler
+
# from superagi.helper.tool_helper import get_tool_config_by_key
engine = connect_db()
@@ -82,18 +84,18 @@ def create_object(tool, session):
object: The object of the agent usable tool.
"""
file_name = AgentExecutor.validate_filename(filename=tool.file_name)
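+        # Resolve which tools directory contains this tool's folder; local tools win,
+        # then external_tools, then marketplace_tools.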
+        tool_paths = ["superagi/tools", "superagi/tools/external_tools", "superagi/tools/marketplace_tools"]
+        tools_dir = tool_paths[0]
+        for tool_path in tool_paths:
+            if os.path.exists(os.path.join(os.getcwd(), tool_path, tool.folder_name)):
+                tools_dir = tool_path
+                break
- tools_dir = get_config("TOOLS_DIR")
- if tools_dir is None:
- tools_dir = "superagi/tools"
parsed_tools_dir = tools_dir.rstrip("/")
module_name = ".".join(parsed_tools_dir.split("/") + [tool.folder_name, file_name])
# module_name = f"superagi.tools.{folder_name}.{file_name}"
-
# Load the module dynamically
module = importlib.import_module(module_name)
-
# Get the class from the loaded module
obj_class = getattr(module, tool.class_name)
@@ -138,7 +140,7 @@ def get_embedding(cls, model_source, model_api_key):
return None
@staticmethod
- def get_organisation(agent_execution,session):
+ def get_organisation(agent_execution, session):
"""
Get the model API key from the agent execution.
@@ -155,7 +157,6 @@ def get_organisation(agent_execution,session):
return organisation
-
def execute_next_action(self, agent_execution_id):
"""
Execute the next action of the agent execution.
@@ -197,7 +198,12 @@ def execute_next_action(self, agent_execution_id):
db_agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
db_agent_execution.status = "ITERATION_LIMIT_EXCEEDED"
session.commit()
- EventHandler(session=session).create_event('run_iteration_limit_crossed', {'agent_execution_id':db_agent_execution.id,'name': db_agent_execution.name,'tokens_consumed':db_agent_execution.num_of_tokens,"calls":db_agent_execution.num_of_calls}, db_agent_execution.agent_id, organisation.id)
+ EventHandler(session=session).create_event('run_iteration_limit_crossed',
+ {'agent_execution_id': db_agent_execution.id,
+ 'name': db_agent_execution.name,
+ 'tokens_consumed': db_agent_execution.num_of_tokens,
+ "calls": db_agent_execution.num_of_calls},
+ db_agent_execution.agent_id, organisation.id)
logger.info("ITERATION_LIMIT_CROSSED")
return "ITERATION_LIMIT_CROSSED"
@@ -271,7 +277,11 @@ def execute_next_action(self, agent_execution_id):
db_agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
db_agent_execution.status = "COMPLETED"
session.commit()
- EventHandler(session=session).create_event('run_completed', {'agent_execution_id':db_agent_execution.id,'name': db_agent_execution.name,'tokens_consumed':db_agent_execution.num_of_tokens,"calls":db_agent_execution.num_of_calls}, db_agent_execution.agent_id, organisation.id)
+ EventHandler(session=session).create_event('run_completed', {'agent_execution_id': db_agent_execution.id,
+ 'name': db_agent_execution.name,
+ 'tokens_consumed': db_agent_execution.num_of_tokens,
+ "calls": db_agent_execution.num_of_calls},
+ db_agent_execution.agent_id, organisation.id)
elif response["result"] == "WAITING_FOR_PERMISSION":
db_agent_execution = session.query(AgentExecution).filter(AgentExecution.id == agent_execution_id).first()
db_agent_execution.status = "WAITING_FOR_PERMISSION"
@@ -364,15 +374,15 @@ def handle_wait_for_permission(self, agent_execution, spawned_agent, session):
def get_agent_resource_summary(self, agent_id: int, session: Session, model_llm_source: str, default_summary: str):
if ModelSourceType.GooglePalm.value in model_llm_source:
return
- ResourceSummarizer(session=session).generate_agent_summary(agent_id=agent_id,generate_all=True)
+ ResourceSummarizer(session=session).generate_agent_summary(agent_id=agent_id, generate_all=True)
agent_config_resource_summary = session.query(AgentConfiguration). \
filter(AgentConfiguration.agent_id == agent_id,
AgentConfiguration.key == "resource_summary").first()
resource_summary = agent_config_resource_summary.value if agent_config_resource_summary is not None else default_summary
return resource_summary
- def check_for_resource(self,agent_id: int, session: Session):
- resource = session.query(Resource).filter(Resource.agent_id == agent_id,Resource.channel == 'INPUT').first()
+ def check_for_resource(self, agent_id: int, session: Session):
+ resource = session.query(Resource).filter(Resource.agent_id == agent_id, Resource.channel == 'INPUT').first()
if resource is None:
return False
return True
diff --git a/superagi/jobs/scheduling_executor.py b/superagi/jobs/scheduling_executor.py
index 701c33b21..f09afdd0e 100644
--- a/superagi/jobs/scheduling_executor.py
+++ b/superagi/jobs/scheduling_executor.py
@@ -60,7 +60,6 @@ def execute_scheduled_agent(self, agent_id: int, name: str):
organisation = agent.get_agent_organisation(session)
model = session.query(AgentConfiguration.value).filter(AgentConfiguration.agent_id == agent_id).filter(AgentConfiguration.key == 'model').first()[0]
EventHandler(session=session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name}, agent_id, organisation.id if organisation else 0),
- EventHandler(session=session).create_event('agent_created', {'agent_name': agent.name, 'model': model}, agent_id, organisation.id if organisation else 0)
session.commit()
diff --git a/superagi/models/agent.py b/superagi/models/agent.py
index cca1d6e30..f979e2442 100644
--- a/superagi/models/agent.py
+++ b/superagi/models/agent.py
@@ -84,6 +84,7 @@ def fetch_configuration(cls, session, agent_id: int):
"memory_window": None,
"max_iterations": None,
"is_deleted": agent.is_deleted,
+ "knowledge": None
}
if not agent_configurations:
return parsed_config
@@ -105,7 +106,7 @@ def eval_agent_config(cls, key, value):
"""
- if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB", "resource_summary"]:
+ if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB", "resource_summary", "knowledge"]:
return value
elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval"]:
return int(value)
@@ -163,7 +164,8 @@ def create_agent_with_config(cls, db, agent_with_config):
"permission_type": agent_with_config.permission_type,
"LTM_DB": agent_with_config.LTM_DB,
"max_iterations": agent_with_config.max_iterations,
- "user_timezone": agent_with_config.user_timezone
+ "user_timezone": agent_with_config.user_timezone,
+ "knowledge": agent_with_config.knowledge,
}
agent_configurations = [
diff --git a/superagi/models/agent_template.py b/superagi/models/agent_template.py
index da332a408..4060c0bad 100644
--- a/superagi/models/agent_template.py
+++ b/superagi/models/agent_template.py
@@ -97,7 +97,7 @@ def main_keys(cls):
"""
keys_to_fetch = ["goal", "instruction", "agent_type", "constraints", "tools", "exit", "iteration_interval", "model",
- "permission_type", "LTM_DB", "max_iterations"]
+ "permission_type", "LTM_DB", "max_iterations", "knowledge"]
return keys_to_fetch
@classmethod
@@ -193,7 +193,7 @@ def eval_agent_config(cls, key, value):
if key in ["name", "description", "agent_type", "exit", "model", "permission_type", "LTM_DB"]:
return value
- elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval"]:
+ elif key in ["project_id", "memory_window", "max_iterations", "iteration_interval", "knowledge"]:
return int(value)
elif key == "goal" or key == "constraints" or key == "instruction":
return eval(value)
diff --git a/superagi/models/knowledge_configs.py b/superagi/models/knowledge_configs.py
new file mode 100644
index 000000000..8f7e41aab
--- /dev/null
+++ b/superagi/models/knowledge_configs.py
@@ -0,0 +1,74 @@
+from sqlalchemy import Column, Integer, Text, String
+import requests
+from superagi.models.base_model import DBBaseModel
+marketplace_url = "https://app.superagi.com/api"
+# marketplace_url = "http://localhost:8001"
+
+
+class KnowledgeConfigs(DBBaseModel):
+ """
+    Knowledge-related configurations such as model, data_type, tokenizer, chunk_size, chunk_overlap, text_splitter, etc. are stored here.
+ Attributes:
+ id (int): The unique identifier of the knowledge configuration.
+ knowledge_id (int): The identifier of the associated knowledge.
+ key (str): The key of the configuration setting.
+ value (str): The value of the configuration setting.
+ """
+
+ __tablename__ = 'knowledge_configs'
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ knowledge_id = Column(Integer)
+ key = Column(String)
+ value = Column(Text)
+
+ def __repr__(self):
+ """
+ Returns a string representation of the Knowledge Configuration object.
+ Returns:
+ str: String representation of the Knowledge Configuration.
+ """
+ return f"KnowledgeConfiguration(id={self.id}, knowledge_id={self.knowledge_id}, key={self.key}, value={self.value})"
+
+ @classmethod
+ def fetch_knowledge_config_details_marketplace(cls, knowledge_id: int):
+ headers = {'Content-Type': 'application/json'}
+ response = requests.get(
+ marketplace_url + f"/knowledge_configs/marketplace/details/{str(knowledge_id)}",
+ headers=headers, timeout=10)
+ if response.status_code == 200:
+ knowledge_config_data = response.json()
+ configs = {}
+ for knowledge_config in knowledge_config_data:
+ configs[knowledge_config["key"]] = knowledge_config["value"]
+ return configs
+ else:
+            return {}
+
+ @classmethod
+ def add_update_knowledge_config(cls, session, knowledge_id, knowledge_configs):
+ for key, value in knowledge_configs.items():
+ config = KnowledgeConfigs(knowledge_id=knowledge_id, key=key, value=value)
+ session.add(config)
+ session.commit()
+
+ @classmethod
+ def get_knowledge_config_from_knowledge_id(cls, session, knowledge_id):
+ knowledge_configs = session.query(KnowledgeConfigs).filter(KnowledgeConfigs.knowledge_id == knowledge_id).all()
+ configs = {}
+ for knowledge_config in knowledge_configs:
+ configs[knowledge_config.key] = knowledge_config.value
+ return configs
+
+ @classmethod
+ def delete_knowledge_config(cls, session, knowledge_id):
+ session.query(KnowledgeConfigs).filter(KnowledgeConfigs.knowledge_id == knowledge_id).delete()
+ session.commit()
diff --git a/superagi/models/knowledges.py b/superagi/models/knowledges.py
new file mode 100644
index 000000000..0bf5019f8
--- /dev/null
+++ b/superagi/models/knowledges.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from sqlalchemy import Column, Integer, String
+import requests
+
+from superagi.models.base_model import DBBaseModel
+
+marketplace_url = "https://app.superagi.com/api"
+# marketplace_url = "http://localhost:8001"
+
+class Knowledges(DBBaseModel):
+ """
+    Represents a knowledge entity.
+
+    Attributes:
+        id (int): The unique identifier of the knowledge.
+        name (str): The name of the knowledge.
+        description (str): The description of the knowledge.
+        vector_db_index_id (int): The identifier of the vector db index associated with the knowledge.
+        organisation_id (int): The identifier of the associated organisation.
+        contributed_by (str): The contributor of the knowledge.
+ """
+
+ __tablename__ = 'knowledges'
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ name = Column(String)
+ description = Column(String)
+ vector_db_index_id = Column(Integer)
+ organisation_id = Column(Integer)
+ contributed_by = Column(String)
+
+ def __repr__(self):
+ """
+ Returns a string representation of the Knowledge object.
+
+ Returns:
+ str: String representation of the Knowledge.
+
+ """
+ return f"Knowledge(id={self.id}, name='{self.name}', description='{self.description}', " \
+               f"vector_db_index_id={self.vector_db_index_id}, organisation_id={self.organisation_id}, contributed_by='{self.contributed_by}')"
+
+ @classmethod
+ def fetch_marketplace_list(cls, page):
+ headers = {'Content-Type': 'application/json'}
+ response = requests.get(
+ marketplace_url + f"/knowledges/marketplace/list/{str(page)}",
+ headers=headers, timeout=10)
+ if response.status_code == 200:
+ return response.json()
+ else:
+ return []
+
+    @classmethod
+    def get_knowledge_install_details(cls, session, marketplace_knowledges, organisation):
+        installed_knowledges = session.query(Knowledges).filter(Knowledges.organisation_id == organisation.id).all()
+        installed_names = {installed_knowledge.name for installed_knowledge in installed_knowledges}
+        for knowledge in marketplace_knowledges:
+            knowledge["is_installed"] = knowledge["name"] in installed_names
+        return marketplace_knowledges
+
+ @classmethod
+ def get_organisation_knowledges(cls, session, organisation):
+ knowledges = session.query(Knowledges).filter(Knowledges.organisation_id == organisation.id).all()
+ knowledge_data = []
+ for knowledge in knowledges:
+ data = {
+ "id": knowledge.id,
+ "name": knowledge.name,
+ "contributed_by": knowledge.contributed_by
+ }
+ knowledge_data.append(data)
+ return knowledge_data
+
+ @classmethod
+ def fetch_knowledge_details_marketplace(cls, knowledge_name):
+ headers = {'Content-Type': 'application/json'}
+ response = requests.get(
+ marketplace_url + f"/knowledges/marketplace/details/{knowledge_name}",
+ headers=headers, timeout=10)
+ if response.status_code == 200:
+ return response.json()
+ else:
+ return []
+
+ @classmethod
+ def get_knowledge_from_id(cls, session, knowledge_id):
+ knowledge = session.query(Knowledges).filter(Knowledges.id == knowledge_id).first()
+ return knowledge
+
+ @classmethod
+ def add_update_knowledge(cls, session, knowledge_data):
+ knowledge = session.query(Knowledges).filter(Knowledges.id == knowledge_data["id"], Knowledges.organisation_id == knowledge_data["organisation_id"]).first()
+ if knowledge:
+ knowledge.name = knowledge_data["name"]
+ knowledge.description = knowledge_data["description"]
+ knowledge.vector_db_index_id = knowledge_data["index_id"]
+ else:
+ knowledge = Knowledges(name=knowledge_data["name"], description=knowledge_data["description"], vector_db_index_id=knowledge_data["index_id"], organisation_id=knowledge_data["organisation_id"], contributed_by=knowledge_data["contributed_by"])
+ session.add(knowledge)
+ session.commit()
+ return knowledge
+
+ @classmethod
+ def delete_knowledge(cls, session, knowledge_id):
+ session.query(Knowledges).filter(Knowledges.id == knowledge_id).delete()
+ session.commit()
+
+ @classmethod
+ def delete_knowledge_from_vector_index(cls, session, vector_db_index_id):
+ session.query(Knowledges).filter(Knowledges.vector_db_index_id == vector_db_index_id).delete()
+ session.commit()
\ No newline at end of file
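+
+# Illustrative usage (not part of this PR; names hypothetical). Marketplace
+# listings are annotated with an is_installed flag for the current organisation:
+#
+#   listings = Knowledges.fetch_marketplace_list(page=0)
+#   listings = Knowledges.get_knowledge_install_details(session, listings, organisation)
+#   installed = [k["name"] for k in listings if k["is_installed"]]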
diff --git a/superagi/models/marketplace_stats.py b/superagi/models/marketplace_stats.py
new file mode 100644
index 000000000..144635a2b
--- /dev/null
+++ b/superagi/models/marketplace_stats.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from sqlalchemy import Column, Integer, String
+import requests
+from superagi.models.base_model import DBBaseModel
+
+marketplace_url = "https://app.superagi.com/api"
+# marketplace_url = "http://localhost:8001"
+
+class MarketPlaceStats(DBBaseModel):
+ """
+    Represents marketplace statistics for a referenced entity (e.g. a knowledge).
+ Attributes:
+ id (int): The unique identifier of the marketplace stats.
+ reference_id (int): The unique identifier of the reference.
+ reference_name (str): The name of the reference used.
+ key (str): The key for the statistical value.
+ value (int): The value for the specified key.
+ """
+
+ __tablename__ = 'marketplace_stats'
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ reference_id = Column(Integer)
+ reference_name = Column(String)
+ key = Column(String)
+ value = Column(Integer)
+
+    def __repr__(self):
+        """
+        Returns a string representation of the MarketPlaceStats object.
+        """
+        return f"MarketPlaceStats(id={self.id}, reference_id={self.reference_id}, reference_name='{self.reference_name}', " \
+               f"key='{self.key}', value={self.value})"
+
+ @classmethod
+ def get_knowledge_installation_number(cls, knowledge_id: int):
+ headers = {'Content-Type': 'application/json'}
+ response = requests.get(
+ marketplace_url + f"/marketplace/knowledge/downloads/{str(knowledge_id)}",
+ headers=headers, timeout=10)
+ if response.status_code == 200:
+ return response.json()
+ else:
+ return []
+
+ @classmethod
+ def update_knowledge_install_number(cls, session, knowledge_id, install_number):
+ knowledge_install_number = session.query(MarketPlaceStats).filter(MarketPlaceStats.reference_id == knowledge_id, MarketPlaceStats.reference_name == "KNOWLEDGE", MarketPlaceStats.key == "download_count").first()
+ if knowledge_install_number is None:
+            knowledge_install_number = MarketPlaceStats(reference_id=knowledge_id, reference_name="KNOWLEDGE", key="download_count", value=install_number)
+ session.add(knowledge_install_number)
+ else:
+            knowledge_install_number.value = install_number
+ session.commit()
\ No newline at end of file
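+
+# Illustrative usage (not part of this PR), assuming the marketplace endpoint
+# returns a plain download count:
+#
+#   count = MarketPlaceStats.get_knowledge_installation_number(knowledge_id=1)
+#   MarketPlaceStats.update_knowledge_install_number(session, 1, count + 1)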
diff --git a/superagi/models/vector_db_configs.py b/superagi/models/vector_db_configs.py
new file mode 100644
index 000000000..9affd49bc
--- /dev/null
+++ b/superagi/models/vector_db_configs.py
@@ -0,0 +1,49 @@
+from sqlalchemy import Column, Integer, Text, String
+
+from superagi.models.base_model import DBBaseModel
+
+
+class VectordbConfigs(DBBaseModel):
+ """
+ Vector db related configurations like api_key, environment, and url are stored here
+ Attributes:
+ id (int): The unique identifier of the vector db configuration.
+ vector_db_id (int): The identifier of the associated vector db.
+ key (str): The key of the configuration setting.
+ value (str): The value of the configuration setting.
+ """
+
+ __tablename__ = 'vector_db_configs'
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ vector_db_id = Column(Integer)
+ key = Column(String)
+ value = Column(Text)
+
+    def __repr__(self):
+        """
+        Returns a string representation of the Vector db Configuration object.
+        Returns:
+            str: String representation of the Vector db Configuration.
+        """
+        return f"VectordbConfigs(id={self.id}, key={self.key}, value={self.value})"
+
+ @classmethod
+ def get_vector_db_config_from_db_id(cls, session, vector_db_id):
+ vector_db_configs = session.query(VectordbConfigs).filter(VectordbConfigs.vector_db_id == vector_db_id).all()
+ config_data = {}
+ for config in vector_db_configs:
+ config_data[config.key] = config.value
+ return config_data
+
+ @classmethod
+ def add_vector_db_config(cls, session, vector_db_id, db_creds):
+ for key, value in db_creds.items():
+ vector_db_config = VectordbConfigs(vector_db_id=vector_db_id, key=key, value=value)
+ session.add(vector_db_config)
+ session.commit()
+
+ @classmethod
+ def delete_vector_db_configs(cls, session, vector_db_id):
+ session.query(VectordbConfigs).filter(VectordbConfigs.vector_db_id == vector_db_id).delete()
+ session.commit()
\ No newline at end of file
diff --git a/superagi/models/vector_db_indices.py b/superagi/models/vector_db_indices.py
new file mode 100644
index 000000000..5f522973f
--- /dev/null
+++ b/superagi/models/vector_db_indices.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from sqlalchemy import Column, Integer, String
+
+from superagi.models.base_model import DBBaseModel
+
+
+class VectordbIndices(DBBaseModel):
+ """
+ Represents an vector db index.
+ Attributes:
+ id (int): The unique identifier of the index/collection also referred to as class in Weaviate.
+ name (str): The name of the index/collection.
+ vector_db_id (int): The identifier of the associated vector db.
+ """
+
+ __tablename__ = 'vector_db_indices'
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ name = Column(String)
+ vector_db_id = Column(Integer)
+ dimensions = Column(Integer)
+ state = Column(String)
+
+ def __repr__(self):
+ """
+ Returns a string representation of the Vector db index object.
+ Returns:
+ str: String representation of the Vector db index.
+ """
+ return f"VectordbIndices(id={self.id}, name='{self.name}', vector_db_id={self.vector_db_id}, dimensions={self.dimensions}, state={self.state})"
+
+ @classmethod
+ def get_vector_index_from_id(cls, session, vector_db_index_id):
+ vector_db_index = session.query(VectordbIndices).filter(VectordbIndices.id == vector_db_index_id).first()
+ return vector_db_index
+
+ @classmethod
+ def get_vector_indices_from_vectordb(cls, session, vector_db_id):
+ vector_indices = session.query(VectordbIndices).filter(VectordbIndices.vector_db_id == vector_db_id).all()
+ return vector_indices
+
+ @classmethod
+ def delete_vector_db_index(cls, session, vector_index_id):
+ session.query(VectordbIndices).filter(VectordbIndices.id == vector_index_id).delete()
+ session.commit()
+
+ @classmethod
+ def add_vector_index(cls, session, index_name, vector_db_id, dimensions, state):
+ vector_index = VectordbIndices(name=index_name, vector_db_id=vector_db_id, dimensions=dimensions, state=state)
+ session.add(vector_index)
+ session.commit()
+
+ @classmethod
+ def update_vector_index_state(cls, session, index_id, state):
+ vector_index = session.query(VectordbIndices).filter(VectordbIndices.id == index_id).first()
+ vector_index.state = state
+ session.commit()
\ No newline at end of file
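+
+# Illustrative lifecycle (not part of this PR; values hypothetical): an index is
+# registered once and its state updated as it is linked to marketplace knowledge:
+#
+#   VectordbIndices.add_vector_index(session, "my-index", vector_db_id=1, dimensions=1536, state="Custom")
+#   VectordbIndices.update_vector_index_state(session, index_id=1, state="Marketplace")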
diff --git a/superagi/models/vector_dbs.py b/superagi/models/vector_dbs.py
new file mode 100644
index 000000000..a7735b90f
--- /dev/null
+++ b/superagi/models/vector_dbs.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+import requests
+
+from sqlalchemy import Column, Integer, String
+
+from superagi.models.base_model import DBBaseModel
+
+marketplace_url = "https://app.superagi.com/api"
+# marketplace_url = "http://localhost:8001"
+
+class Vectordbs(DBBaseModel):
+ """
+ Represents an vector db entity.
+ Attributes:
+ id (int): The unique identifier of the agent.
+ name (str): The name of the database.
+ db_type (str): The name of the db agent.
+ organisation_id (int): The identifier of the associated organisation.
+ """
+
+ __tablename__ = 'vector_dbs'
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ name = Column(String)
+ db_type = Column(String)
+ organisation_id = Column(Integer)
+
+ def __repr__(self):
+ """
+ Returns a string representation of the Vector db object.
+ Returns:
+ str: String representation of the Vector db.
+ """
+ return f"Vector(id={self.id}, name='{self.name}', db_type='{self.db_type}' organisation_id={self.organisation_id}, updated_at={self.updated_at})"
+
+ @classmethod
+ def get_vector_db_from_id(cls, session, vector_db_id):
+ vector_db = session.query(Vectordbs).filter(Vectordbs.id == vector_db_id).first()
+ return vector_db
+
+ @classmethod
+ def fetch_marketplace_list(cls):
+ headers = {'Content-Type': 'application/json'}
+ response = requests.get(
+ marketplace_url + f"/vector_dbs/marketplace/list",
+ headers=headers, timeout=10)
+ if response.status_code == 200:
+ return response.json()
+ else:
+ return []
+
+ @classmethod
+ def get_vector_db_from_organisation(cls, session, organisation):
+ vector_db_list = session.query(Vectordbs).filter(Vectordbs.organisation_id == organisation.id).all()
+ return vector_db_list
+
+ @classmethod
+ def add_vector_db(cls, session, name, db_type, organisation):
+ vector_db = Vectordbs(name=name, db_type=db_type, organisation_id=organisation.id)
+ session.add(vector_db)
+ session.commit()
+ return vector_db
+
+ @classmethod
+ def delete_vector_db(cls, session, vector_db_id):
+ session.query(Vectordbs).filter(Vectordbs.id == vector_db_id).delete()
+ session.commit()
\ No newline at end of file
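+
+# Illustrative usage (not part of this PR; credentials hypothetical): a vector db
+# row holds the identity, while VectordbConfigs holds its credentials:
+#
+#   vector_db = Vectordbs.add_vector_db(session, "my-pinecone", "Pinecone", organisation)
+#   VectordbConfigs.add_vector_db_config(session, vector_db.id, {"api_key": "...", "environment": "us-east-1"})
+#   creds = VectordbConfigs.get_vector_db_config_from_db_id(session, vector_db.id)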
diff --git a/superagi/resource_manager/file_manager.py b/superagi/resource_manager/file_manager.py
index 0ccc30950..4c20ba16d 100644
--- a/superagi/resource_manager/file_manager.py
+++ b/superagi/resource_manager/file_manager.py
@@ -1,21 +1,18 @@
import csv
-
from sqlalchemy.orm import Session
-
+from superagi.config.config import get_config
+import os
from superagi.helper.resource_helper import ResourceHelper
from superagi.helper.s3_helper import S3Helper
from superagi.lib.logger import logger
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
from superagi.types.storage_types import StorageType
-
-
class FileManager:
def __init__(self, session: Session, agent_id: int = None, agent_execution_id: int = None):
self.session = session
self.agent_id = agent_id
self.agent_execution_id = agent_execution_id
-
def write_binary_file(self, file_name: str, data):
if self.agent_id is not None:
final_path = ResourceHelper.get_agent_write_resource_path(file_name,
@@ -26,7 +23,6 @@ def write_binary_file(self, file_name: str, data):
self.agent_execution_id))
else:
final_path = ResourceHelper.get_resource_path(file_name)
-
try:
with open(final_path, mode="wb") as img:
img.write(data)
@@ -36,7 +32,6 @@ def write_binary_file(self, file_name: str, data):
return f"Binary {file_name} saved successfully"
except Exception as err:
return f"Error write_binary_file: {err}"
-
def write_to_s3(self, file_name, final_path):
with open(final_path, 'rb') as img:
resource = ResourceHelper.make_written_file_resource(file_name=file_name,
@@ -44,14 +39,11 @@ def write_to_s3(self, file_name, final_path):
self.agent_id),
agent_execution=AgentExecution
.get_agent_execution_from_id(self.session,
- self.agent_execution_id))
- if resource is not None:
- self.session.add(resource)
- self.session.commit()
- self.session.flush()
- if resource.storage_type == StorageType.S3.value:
- s3_helper = S3Helper()
- s3_helper.upload_file(img, path=resource.path)
+ self.agent_execution_id),
+ session=self.session)
+ if resource.storage_type == StorageType.S3.value:
+ s3_helper = S3Helper()
+ s3_helper.upload_file(img, path=resource.path)
def write_file(self, file_name: str, content):
if self.agent_id is not None:
@@ -63,7 +55,6 @@ def write_file(self, file_name: str, content):
self.agent_execution_id))
else:
final_path = ResourceHelper.get_resource_path(file_name)
-
try:
with open(final_path, mode="w") as file:
file.write(content)
@@ -73,7 +64,6 @@ def write_file(self, file_name: str, content):
return f"{file_name} - File written successfully"
except Exception as err:
return f"Error write_file: {err}"
-
def write_csv_file(self, file_name: str, csv_data):
if self.agent_id is not None:
final_path = ResourceHelper.get_agent_write_resource_path(file_name,
@@ -84,7 +74,6 @@ def write_csv_file(self, file_name: str, csv_data):
self.agent_execution_id))
else:
final_path = ResourceHelper.get_resource_path(file_name)
-
try:
with open(final_path, mode="w", newline="") as file:
writer = csv.writer(file, lineterminator="\n")
@@ -94,6 +83,7 @@ def write_csv_file(self, file_name: str, csv_data):
return f"{file_name} - File written successfully"
except Exception as err:
return f"Error write_csv_file: {err}"
+
def get_agent_resource_path(self, file_name: str):
return ResourceHelper.get_agent_write_resource_path(file_name, agent=Agent.get_agent_from_id(self.session,
@@ -101,3 +91,34 @@ def get_agent_resource_path(self, file_name: str):
agent_execution=AgentExecution
.get_agent_execution_from_id(self.session,
self.agent_execution_id))
+ def read_file(self, file_name: str):
+ if self.agent_id is not None:
+ final_path = self.get_agent_resource_path(file_name)
+ else:
+ final_path = ResourceHelper.get_resource_path(file_name)
+
+ try:
+ with open(final_path, mode="r") as file:
+ content = file.read()
+ logger.info(f"{file_name} - File read successfully")
+ return content
+ except Exception as err:
+ return f"Error while reading file {file_name}: {err}"
+ def get_files(self):
+ """
+ Gets all file names generated by the CodingTool.
+ Returns:
+ A list of file names.
+ """
+
+ if self.agent_id is not None:
+ final_path = self.get_agent_resource_path("")
+ else:
+ final_path = ResourceHelper.get_resource_path("")
+ try:
+ # List all files in the directory
+ files = os.listdir(final_path)
+ except Exception as err:
+ logger.error(f"Error while accessing files in {final_path}: {err}")
+ files = []
+ return files
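+
+# Illustrative usage (not part of this PR): list the files an agent has written,
+# then read one back; paths resolve via ResourceHelper based on agent/execution ids:
+#
+#   manager = FileManager(session, agent_id=1, agent_execution_id=1)
+#   for name in manager.get_files():
+#       print(name, len(manager.read_file(name)))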
diff --git a/superagi/tool_manager.py b/superagi/tool_manager.py
index 674dd0e03..dea5f4bb0 100644
--- a/superagi/tool_manager.py
+++ b/superagi/tool_manager.py
@@ -5,6 +5,7 @@
import zipfile
import json
+
def parse_github_url(github_url):
parts = github_url.split('/')
owner = parts[3]
@@ -12,11 +13,11 @@ def parse_github_url(github_url):
branch = "main"
return f"{owner}/{repo}/{branch}"
+
def download_tool(tool_url, target_folder):
parsed_url = parse_github_url(tool_url)
parts = parsed_url.split("/")
owner, repo, branch, path = parts[0], parts[1], parts[2], "/".join(parts[3:])
-
archive_url = f"https://api.github.com/repos/{owner}/{repo}/zipball/{branch}"
response = requests.get(archive_url)
@@ -31,19 +32,15 @@ def download_tool(tool_url, target_folder):
for member in members:
archive_folder = f"{owner}-{repo}"
target_name = member.replace(f"{archive_folder}/", "", 1)
-
# Skip the unique hash folder while extracting:
segments = target_name.split('/', 1)
if len(segments) > 1:
target_name = segments[1]
else:
continue
-
target_path = os.path.join(target_folder, target_name)
-
if not target_name:
continue
-
if member.endswith('/'):
os.makedirs(target_path, exist_ok=True)
else:
@@ -53,6 +50,54 @@ def download_tool(tool_url, target_folder):
os.remove(tool_zip_file_path)
+def download_marketplace_tool(tool_url, target_folder):
+ parsed_url = tool_url.split("/")
+ owner, repo = parsed_url[3], parsed_url[4]
+ archive_url = f"https://api.github.com/repos/{owner}/{repo}/zipball/main"
+ response = requests.get(archive_url)
+ tool_zip_file_path = os.path.join(target_folder, 'tool.zip')
+
+ with open(tool_zip_file_path, 'wb') as f:
+ f.write(response.content)
+
+ with zipfile.ZipFile(tool_zip_file_path, 'r') as z:
+ for member in z.namelist():
+ archive_folder, target_name = member.split('/', 1)
+ target_name = os.path.join(target_folder, target_name)
+ if member.endswith('/'):
+ os.makedirs(target_name, exist_ok=True)
+ elif not target_name.endswith('.md'):
+ with open(target_name, 'wb') as outfile, z.open(member) as infile:
+ outfile.write(infile.read())
+
+ os.remove(tool_zip_file_path)
+
+
+def get_marketplace_tool_links(repo_url):
+ folder_links = {}
+ api_url = f"https://api.github.com/repos/{repo_url}/contents"
+ response = requests.get(api_url)
+ contents = response.json()
+
+ for content in contents:
+ if content["type"] == "dir":
+ folder_name = content["name"]
+ folder_link = f"https://github.com/{repo_url}/tree/main/{folder_name}"
+ folder_links[folder_name] = folder_link
+
+ return folder_links
+
+
+def update_tools_json(existing_tools_json_path, folder_links):
+ with open(existing_tools_json_path, "r") as file:
+ tools_data = json.load(file)
+ if "tools" not in tools_data:
+ tools_data["tools"] = {}
+ tools_data["tools"].update(folder_links)
+ with open(existing_tools_json_path, "w") as file:
+ json.dump(tools_data, file, indent=4)
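+
+# Resulting tools.json shape after update_tools_json (illustrative; tool names
+# and URLs are hypothetical):
+#
+# {
+#     "tools": {
+#         "email": "https://github.com/TransformerOptimus/SuperAGI-Tools/tree/main/email",
+#         "duck_duck_go": "https://github.com/TransformerOptimus/SuperAGI-Tools/tree/main/duck_duck_go"
+#     }
+# }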
+
+
def load_tools_config():
tool_config_path = str(Path(__file__).parent.parent)
with open(tool_config_path + "/tools.json", "r") as f:
@@ -60,15 +105,35 @@ def load_tools_config():
return config["tools"]
+def load_marketplace_tools():
+ marketplace_url = "TransformerOptimus/SuperAGI-Tools"
+ tools_config_path = str(Path(__file__).parent.parent)
+ tools_json_path = tools_config_path + "/tools.json"
+ # Get folder links from the repository
+ marketplace_tool_urls = get_marketplace_tool_links(marketplace_url)
+ # Update existing tools.json file
+ update_tools_json(tools_json_path, marketplace_tool_urls)
+
+
+def is_marketplace_url(url):
+ return url.startswith("https://github.com/TransformerOptimus/SuperAGI-Tools/tree")
+
def download_and_extract_tools():
tools_config = load_tools_config()
for tool_name, tool_url in tools_config.items():
- tool_folder = os.path.join("superagi", "tools", tool_name)
- if not os.path.exists(tool_folder):
- os.makedirs(tool_folder)
- download_tool(tool_url, tool_folder)
+ if is_marketplace_url(tool_url):
+            tool_folder = os.path.join("superagi", "tools", "marketplace_tools")
+ if not os.path.exists(tool_folder):
+ os.makedirs(tool_folder)
+ download_marketplace_tool(tool_url, tool_folder)
+ else:
+ tool_folder = os.path.join("superagi/tools/external_tools", tool_name)
+ if not os.path.exists(tool_folder):
+ os.makedirs(tool_folder)
+ download_tool(tool_url, tool_folder)
if __name__ == "__main__":
- download_and_extract_tools()
\ No newline at end of file
+ load_marketplace_tools()
+ download_and_extract_tools()
diff --git a/superagi/tools/code/coding_toolkit.py b/superagi/tools/code/coding_toolkit.py
index f96f9e800..248af2536 100644
--- a/superagi/tools/code/coding_toolkit.py
+++ b/superagi/tools/code/coding_toolkit.py
@@ -5,6 +5,7 @@
from superagi.tools.code.write_code import CodingTool
from superagi.tools.code.write_spec import WriteSpecTool
from superagi.tools.code.write_test import WriteTestTool
+from superagi.tools.code.improve_code import ImproveCodeTool
class CodingToolkit(BaseToolkit, ABC):
@@ -12,7 +13,7 @@ class CodingToolkit(BaseToolkit, ABC):
description: str = "Coding Tool kit contains all tools related to coding tasks"
def get_tools(self) -> List[BaseTool]:
- return [CodingTool(), WriteSpecTool(), WriteTestTool()]
+ return [CodingTool(), WriteSpecTool(), WriteTestTool(), ImproveCodeTool()]
def get_env_keys(self) -> List[str]:
return []
diff --git a/superagi/tools/code/improve_code.py b/superagi/tools/code/improve_code.py
new file mode 100644
index 000000000..c54856811
--- /dev/null
+++ b/superagi/tools/code/improve_code.py
@@ -0,0 +1,99 @@
+import re
+from typing import Type, Optional, List
+
+from pydantic import BaseModel, Field
+
+from superagi.agent.agent_prompt_builder import AgentPromptBuilder
+from superagi.helper.prompt_reader import PromptReader
+from superagi.helper.token_counter import TokenCounter
+from superagi.lib.logger import logger
+from superagi.llms.base_llm import BaseLlm
+from superagi.resource_manager.file_manager import FileManager
+from superagi.tools.base_tool import BaseTool
+from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
+
+
+class ImproveCodeSchema(BaseModel):
+ pass
+
+
+class ImproveCodeTool(BaseTool):
+ """
+ Used to improve the already generated code by reading the code from the files
+
+ Attributes:
+ llm: LLM used for code generation.
+ name : The name of the tool.
+ description : The description of the tool.
+ resource_manager: Manages the file resources.
+ """
+ llm: Optional[BaseLlm] = None
+ agent_id: int = None
+ name = "ImproveCodeTool"
+ description = (
+ "This tool improves the generated code."
+ )
+ args_schema: Type[ImproveCodeSchema] = ImproveCodeSchema
+ resource_manager: Optional[FileManager] = None
+ tool_response_manager: Optional[ToolResponseQueryManager] = None
+ goals: List[str] = []
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def _execute(self) -> str:
+ """
+ Execute the improve code tool.
+
+ Returns:
+ Improved code or error message.
+ """
+ # Get all file names that the CodingTool has written
+ file_names = self.resource_manager.get_files()
+ logger.info(file_names)
+ # Loop through each file
+        for file_name in file_names:
+            # Only improve source files; skip text, shell and json artifacts
+            if not file_name.endswith(('.txt', '.sh', '.json')):
+ # Read the file content
+ content = self.resource_manager.read_file(file_name)
+
+                # Generate a prompt from improve_code.txt
+                prompt = PromptReader.read_tools_prompt(__file__, "improve_code.txt")
+
+                # Fill the goals and file content placeholders; the template
+                # already embeds the code once, so it is not appended again
+                prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals))
+                prompt = prompt.replace("{content}", content)
+
+                # Use LLM to generate improved code
+                result = self.llm.chat_completion([{'role': 'system', 'content': prompt}])
+
+                # Extract the response; bail out early if the LLM call failed
+                response = result.get('response')
+                if not response:
+                    logger.error("LLM response not available")
+                    return "Error: LLM response not available"
+
+                choices = response.get('choices')
+                if not choices:
+                    logger.error("LLM response contained no choices")
+                    return "Error: LLM response contained no choices"
+
+                # Now the message content can be safely extracted
+                improved_content = choices[0]["message"]["content"]
+                parsed_content = re.findall(r"```(?:\w*\n)?(.*?)```", improved_content, re.DOTALL)
+                parsed_content_code = "\n".join(parsed_content)
+
+                # Rewrite the file with the improved content
+                save_result = self.resource_manager.write_file(file_name, parsed_content_code)
+
+                if save_result.startswith("Error"):
+                    return save_result
+
+        return "All code files improved and saved successfully in: " + " ".join(file_names)
\ No newline at end of file
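+
+# Illustrative flow (not part of this PR): the framework wires in llm, goals and
+# resource_manager before calling _execute(); roughly:
+#
+#   tool = ImproveCodeTool(llm=llm, goals=["Build a CLI todo app"], resource_manager=file_manager)
+#   print(tool._execute())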
diff --git a/superagi/tools/code/prompts/improve_code.txt b/superagi/tools/code/prompts/improve_code.txt
new file mode 100644
index 000000000..973525210
--- /dev/null
+++ b/superagi/tools/code/prompts/improve_code.txt
@@ -0,0 +1,26 @@
+You are a super smart developer. You have been tasked with fixing and filling in the functions and classes where only a description of the code is written without the actual code. There might be placeholders in the code that you have to fill in.
+You provide fully functioning, well-formatted code with few comments, that works and has no bugs.
+If the code is already correct and doesn't need changes, just return the same code.
+However, make sure that you only return the improved code, without any additional content.
+
+
+Please structure the improved code as follows:
+
+```
+CODE
+```
+
+Please return the full new code in the same format as the original code.
+Don't write any explanation or description in your response other than the actual code.
+
+Your high-level goal is:
+{goals}
+
+The content of the file you need to improve is:
+{content}
+
+Only return the code and not any other line.
+
+To start, first analyze the existing code. Check for any function with missing logic inside it and fill in the function.
+Make sure that not a single function is empty or contains just comments; there should be function logic inside it.
+Return fully completed functions by filling in the placeholders.
\ No newline at end of file
diff --git a/superagi/tools/code/prompts/write_code.txt b/superagi/tools/code/prompts/write_code.txt
index 17f8c37ae..4256ab321 100644
--- a/superagi/tools/code/prompts/write_code.txt
+++ b/superagi/tools/code/prompts/write_code.txt
@@ -1,4 +1,5 @@
You are a super smart developer who practices good Development for writing code according to a specification.
+Please note that the code should be fully functional. There should be no placeholder in functions or classes in any file.
Your high-level goal is:
{goals}
@@ -15,15 +16,14 @@ functions, methods you'll use and a quick comment on their purpose.
Then you will output the content of each file including ALL code.
Each file must strictly follow a markdown code block format, where the following tokens must be replaced such that
-[FILENAME] is the lowercase file name including the file extension,
+FILENAME is the lowercase file name including the file extension,
[LANG] is the markup code block language for the code's language, and [CODE] is the code:
-[FILENAME]
+FILENAME
```[LANG]
[CODE]
```
You will start with the "entrypoint" file, then go to the ones that are imported by that file, and so on.
-Please note that the code should be fully functional. No placeholders.
Follow a language and framework appropriate best practice file naming convention.
Make sure that files contain all imports, types etc. Make sure that code in different files are compatible with each other.
diff --git a/superagi/tools/code/prompts/write_test.txt b/superagi/tools/code/prompts/write_test.txt
index 6bcabc8b6..a7c65fba3 100644
--- a/superagi/tools/code/prompts/write_test.txt
+++ b/superagi/tools/code/prompts/write_test.txt
@@ -9,7 +9,7 @@ Test Description:
{spec}
Test should follow the following format:
-[FILENAME]
+FILENAME
```[LANG]
[UNIT_TEST_CODE]
```
diff --git a/superagi/tools/code/write_code.py b/superagi/tools/code/write_code.py
index 9c1a001a9..1eed98c36 100644
--- a/superagi/tools/code/write_code.py
+++ b/superagi/tools/code/write_code.py
@@ -40,7 +40,7 @@ class CodingTool(BaseTool):
"Make sure that every detail of the architecture is, in the end, implemented as code. "
"Think step by step and reason yourself to the right decisions to make sure we get it right. "
"You will first lay out the names of the core classes, functions, methods that will be necessary, "
- "as well as a quick comment on their purpose. Then you will output the content of each file including ALL code."
+ "as well as a quick comment on their purpose. Then you will output the content of each file including each function and class and ALL code."
)
args_schema: Type[CodingSchema] = CodingSchema
goals: List[str] = []
@@ -84,6 +84,8 @@ def _execute(self, code_description: str) -> str:
for match in matches:
# Get the filename
file_name = re.sub(r'[<>"|?*]', "", match.group(1))
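+            # If the filename came wrapped (e.g. "[main.py]"), drop the first and last characters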
+ if not file_name[0].isalnum():
+ file_name = file_name[1:-1]
# Get the code
code = match.group(2)
diff --git a/superagi/tools/file/read_file.py b/superagi/tools/file/read_file.py
index b5e92b6b7..0f9d6c730 100644
--- a/superagi/tools/file/read_file.py
+++ b/superagi/tools/file/read_file.py
@@ -4,10 +4,13 @@
from pydantic import BaseModel, Field
from superagi.helper.resource_helper import ResourceHelper
+from superagi.helper.s3_helper import S3Helper
from superagi.models.agent_execution import AgentExecution
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
from superagi.models.agent import Agent
+from superagi.types.storage_types import StorageType
+from superagi.config.config import get_config
class ReadFileSchema(BaseModel):
@@ -42,13 +45,16 @@ def _execute(self, file_name: str):
The file content and the file name
"""
final_path = ResourceHelper.get_agent_read_resource_path(file_name, agent=Agent.get_agent_from_id(
- session=self.toolkit_config.session, agent_id=self.agent_id),
- agent_execution=AgentExecution.get_agent_execution_from_id(
- session=self.toolkit_config.session,
- agent_execution_id=self.agent_execution_id))
+            session=self.toolkit_config.session, agent_id=self.agent_id),
+            agent_execution=AgentExecution.get_agent_execution_from_id(
+                session=self.toolkit_config.session,
+                agent_execution_id=self.agent_execution_id))
+ if StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) == StorageType.S3:
+ return S3Helper().read_from_s3(final_path)
+
if final_path is None or not os.path.exists(final_path):
raise FileNotFoundError(f"File '{file_name}' not found.")
-
directory = os.path.dirname(final_path)
os.makedirs(directory, exist_ok=True)
diff --git a/superagi/tools/image_generation/dalle_image_gen.py b/superagi/tools/image_generation/dalle_image_gen.py
index 56c071d12..e27255a64 100644
--- a/superagi/tools/image_generation/dalle_image_gen.py
+++ b/superagi/tools/image_generation/dalle_image_gen.py
@@ -51,7 +51,7 @@ def _execute(self, prompt: str, image_names: list, size: int = 512, num: int = 2
if size not in [256, 512, 1024]:
size = min([256, 512, 1024], key=lambda x: abs(x - size))
response = OpenAiDalle(api_key=self.get_tool_config("OPENAI_API_KEY"), number_of_results=num).generate_image(
- prompt, size, num)
+ prompt, size)
response = response.__dict__
response = response['_previous']['data']
for i in range(num):
diff --git a/superagi/tools/image_generation/stable_diffusion_image_gen.py b/superagi/tools/image_generation/stable_diffusion_image_gen.py
index 983d7ec5e..6831f5d98 100644
--- a/superagi/tools/image_generation/stable_diffusion_image_gen.py
+++ b/superagi/tools/image_generation/stable_diffusion_image_gen.py
@@ -5,13 +5,15 @@
import requests
from PIL import Image
from pydantic import BaseModel, Field
-
+from superagi.helper.resource_helper import ResourceHelper
from superagi.resource_manager.file_manager import FileManager
from superagi.tools.base_tool import BaseTool
+from superagi.models.agent_execution import AgentExecution
+from superagi.models.agent import Agent
class StableDiffusionImageGenInput(BaseModel):
- prompt: str = Field(..., description="Prompt for Image Generation to be used by Stable Diffusion.")
+ prompt: str = Field(..., description="Prompt for Image Generation to be used by Stable Diffusion. The prompt should be as descriptive as possible and mention all the details of the image to be generated")
height: int = Field(..., description="Height of the image to be Generated. default height is 512")
width: int = Field(..., description="Width of the image to be Generated. default width is 512")
num: int = Field(..., description="Number of Images to be generated. default num is 2")
@@ -35,6 +37,7 @@ class StableDiffusionImageGenTool(BaseTool):
args_schema: Type[BaseModel] = StableDiffusionImageGenInput
description: str = "Generate Images using Stable Diffusion"
agent_id: int = None
+ agent_execution_id: int = None
resource_manager: Optional[FileManager] = None
class Config:
@@ -46,7 +49,6 @@ def _execute(self, prompt: str, image_names: list, width: int = 512, height: int
if api_key is None:
return "Error: Missing Stability API key."
-
response = self.call_stable_diffusion(api_key, width, height, num, prompt, steps)
if response.status_code != 200:
@@ -59,6 +61,7 @@ def _execute(self, prompt: str, image_names: list, width: int = 512, height: int
for artifact in artifacts:
base64_strings.append(artifact['base64'])
+        image_paths = []
for i in range(num):
image_base64 = base64_strings[i]
img_data = base64.b64decode(image_base64)
@@ -69,7 +72,16 @@ def _execute(self, prompt: str, image_names: list, width: int = 512, height: int
self.resource_manager.write_binary_file(image_names[i], img_byte_arr.getvalue())
- return "Images downloaded and saved successfully"
+        for image in image_names:
+            final_path = ResourceHelper.get_agent_read_resource_path(
+                image,
+                agent=Agent.get_agent_from_id(session=self.toolkit_config.session, agent_id=self.agent_id),
+                agent_execution=AgentExecution.get_agent_execution_from_id(
+                    session=self.toolkit_config.session,
+                    agent_execution_id=self.agent_execution_id))
+            image_paths.append(final_path)
+
+        return f"Images downloaded and saved successfully at the following locations: {image_paths}"
def call_stable_diffusion(self, api_key, width, height, num, prompt, steps):
engine_id = self.get_tool_config("ENGINE_ID")
diff --git a/superagi/tools/instagram_tool/README.MD b/superagi/tools/instagram_tool/README.MD
new file mode 100644
index 000000000..293504c06
--- /dev/null
+++ b/superagi/tools/instagram_tool/README.MD
@@ -0,0 +1,42 @@
+# SuperAGI Instagram Tool
+
+The SuperAGI Instagram Tool works with the Stable Diffusion tool: it generates an image and caption based on the goals defined by the user and posts it on their Instagram business account. Currently, it only works on the web app.
+
+## ⚙️ Installation
+
+### 🛠 **Setting Up of SuperAGI**
+Set up SuperAGI by following the instructions given here: (https://github.com/TransformerOptimus/SuperAGI/blob/main/README.MD)
+
+### 🔧 **Instagram tool requirements**
+
+Since the tool uses the official Instagram Graph APIs to post media on user accounts, there are a few requirements:
+
+You will need access to the following:
+
+ 1. An Instagram Business Account or Instagram Creator Account
+ 2. A Facebook Page connected to that account
+ 3. A Facebook Developer account that can perform Tasks on that Page
+ 4. A registered Facebook App with Basic settings configured
+
+Once everything is set up, add the Meta user access token (to be generated from the Facebook developer account), the Facebook page ID (found on the Facebook page connected to the Instagram account, under 'Page transparency' in the 'About' section of the page), and the Stability API key to the corresponding toolkits.
+
+Follow the steps given in this link to set up the Meta requirements: (https://developers.facebook.com/docs/instagram-api/getting-started)
+Follow this link to generate a Stability API key: (https://dreamstudio.com/api/)
+
+### 🔧 **Configuring in SuperAGI Dashboard:**
+
+- You can add your Meta user access token and Facebook page ID on the Instagram Toolkit page, and the Stability API key on the Image Generation Toolkit page.
+
+## Running SuperAGI Instagram Tool
+
+Once everything has been set up, just run or schedule an agent with a goal describing the media to be published, and add the Instagram tool (which will automatically add the Stable Diffusion tool).
+
+## Warning
+
+It is advised to run the Instagram tool in restricted mode, since that allows you to validate the generated photos. You can schedule agent runs (recurring runs are supported as well). Also, only one photo will be posted to your account per run. To post multiple photos, use recurring runs.
\ No newline at end of file
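+
+## Example goal
+
+An agent goal might look like this (illustrative):
+
+`Generate an image of a sunrise over the mountains and post it on Instagram with a suitable caption.`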
diff --git a/superagi/tools/instagram_tool/__init__.py b/superagi/tools/instagram_tool/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/superagi/tools/instagram_tool/instagram.py b/superagi/tools/instagram_tool/instagram.py
new file mode 100644
index 000000000..d6a464a9a
--- /dev/null
+++ b/superagi/tools/instagram_tool/instagram.py
@@ -0,0 +1,226 @@
+import random
+import urllib.parse
+
+import boto3
+import requests
+from typing import Type, Optional
+from pydantic import BaseModel, Field
+
+from superagi.config.config import get_config
+from superagi.llms.base_llm import BaseLlm
+from superagi.tools.base_tool import BaseTool
+from superagi.tools.tool_response_query_manager import ToolResponseQueryManager
+
+class InstagramSchema(BaseModel):
+ photo_description: str = Field(
+ ...,
+ description="description of the photo",
+ )
+
+class InstagramTool(BaseTool):
+ """
+ Instagram tool
+
+ Attributes:
+ name : The name.
+ description : The description.
+ args_schema : The args schema.
+ """
+ llm: Optional[BaseLlm] = None
+ name = "Instagram tool"
+ description = (
+ "A tool for posting an AI generated photo on Instagram"
+ )
+ args_schema: Type[InstagramSchema] = InstagramSchema
+ tool_response_manager: Optional[ToolResponseQueryManager] = None
+    agent_id: int = None
+ class Config:
+ arbitrary_types_allowed = True
+
+ def _execute(self, photo_description: str) -> str:
+ """
+ Execute the Instagram tool.
+
+ Args:
+ photo_description : description of the photo to be posted
+
+ Returns:
+ Image posted successfully message if image has been posted on instagram or error message.
+ """
+        meta_user_access_token = self.get_tool_config("META_USER_ACCESS_TOKEN")
+        facebook_page_id = self.get_tool_config("FACEBOOK_PAGE_ID")
+
+        if meta_user_access_token is None:
+            return "Error: Missing meta user access token."
+
+        if facebook_page_id is None:
+            return "Error: Missing facebook page id."
+
+        # get request for fetching the instagram_business_account_id
+        root_api_url = "https://graph.facebook.com/v17.0/"
+        response = self.get_req_insta_id(root_api_url, facebook_page_id, meta_user_access_token)
+
+        if response.status_code != 200:
+            return f"Non-200 response: {str(response.text)}"
+
+        data = response.json()
+        insta_business_account_id = data["instagram_business_account"]["id"]
+        file_path = self.get_file_path_from_image_generation_tool()
+        # handling case where image generation generates multiple images
+        if file_path == "resources":
+            return "A photo has already been posted on your instagram account. To post multiple photos use recurring runs."
+
+        # the caption is generated (and URL-encoded) inside get_img_url_and_encoded_caption
+        image_url, encoded_caption = self.get_img_url_and_encoded_caption(photo_description, file_path)
+        # post request for getting the media container ID
+        response = self.post_media_container_id(root_api_url, insta_business_account_id, image_url, encoded_caption, meta_user_access_token)
+
+        if response.status_code != 200:
+            return f"Non-200 response: {str(response.text)}"
+
+        data = response.json()
+        container_ID = data["id"]
+        # post request to post the media container on instagram account
+        response = self.post_media(root_api_url, insta_business_account_id, container_ID, meta_user_access_token)
+        if response.status_code != 200:
+            return f"Non-200 response: {str(response.text)}"
+        return "Photo posted successfully!"
+
+ def create_caption(self, photo_description: str) -> str:
+ """
+ Create a caption for the instagram post based on the photo description
+
+ Args:
+ photo_description : Description of the photo to be posted
+
+ Returns:
+            URL-encoded caption for the post
+ """
+ caption_prompt ="""Generate an instagram post caption for the following text `{photo_description}`
+ Attempt to make it as relevant as possible to the description and should be different and unique everytime. Add relevant emojis and hashtags."""
+
+ caption_prompt = caption_prompt.replace("{photo_description}", str(photo_description))
+
+ messages = [{"role": "system", "content": caption_prompt}]
+ result = self.llm.chat_completion(messages, max_tokens=self.max_token_limit)
+        caption = result["content"]
+
+        encoded_caption = urllib.parse.quote(caption)
+
+        return encoded_caption
+
+    def get_image_from_s3(self, s3, file_path):
+ """
+ Gets the image from the s3 bucket
+
+ Args:
+ s3: S3 client
+ file_path: path of the image file in s3
+
+ Returns
+ The image file from s3
+ """
+
+ response = s3.get_object(Bucket=get_config("BUCKET_NAME"), Key=file_path)
+ content = response["Body"].read()
+
+ return content
+
+ def get_file_path_from_image_generation_tool(self):
+ """
+ Parses the output of the previous tool (Stable diffusion) and returns the path of the image file
+
+ Args:
+
+ Returns:
+ The path of the image file generated by the image generation toolkit
+ """
+
+        last_tool_response = self.tool_response_manager.get_last_response()
+        file_path = "resources" + last_tool_response.partition("['")[2].partition("']")[0]
+
+        if ',' in file_path:
+            # Split on the comma and keep the substring before it
+            file_path = file_path.split(',')[0].strip()
+            file_path = file_path[:-1]
+
+ return file_path
+
+ def create_s3_client(self):
+ """
+ Creates an s3 client
+
+ Args:
+
+ Returns:
+ The s3 client
+ """
+
+ s3 = boto3.client(
+ 's3',
+ aws_access_key_id=get_config("AWS_ACCESS_KEY_ID"),
+ aws_secret_access_key=get_config("AWS_SECRET_ACCESS_KEY"),
+ )
+
+ return s3
+
+    def get_img_public_url(self, s3, file_path, content):
+ """
+ Puts the image generated by image generation tool in the s3 bucket and returns the public url of the same
+ Args:
+ s3 : S3 bucket
+ file_path: Path of the image file in s3
+ content: Image file
+
+ Returns:
+ The public url of the image put in s3 bucket
+ """
+
+ bucket_name = get_config("INSTAGRAM_TOOL_BUCKET_NAME")
+ object_key=f"instagram_upload_images/{file_path.split('/')[-1]}{random.randint(0, 1000)}"
+ s3.put_object(Bucket=bucket_name, Key=object_key, Body=content)
+
+ image_url = f"https://{bucket_name}.s3.amazonaws.com/{object_key}"
+ return image_url
+
+    def get_img_url_and_encoded_caption(self, photo_description, file_path):
+        # creating an s3 client
+        s3 = self.create_s3_client()
+
+        # fetching the image from s3 using the file_path
+        content = self.get_image_from_s3(s3, file_path)
+
+        # storing the image in a public bucket and getting the image url
+        image_url = self.get_img_public_url(s3, file_path, content)
+        # generating and URL-encoding the caption, with possible emojis and hashtags
+        encoded_caption = self.create_caption(photo_description)
+
+        return image_url, encoded_caption
+
+    def get_req_insta_id(self, root_api_url, facebook_page_id, meta_user_access_token):
+        url_to_get_acc_id = f"{root_api_url}{facebook_page_id}?fields=instagram_business_account&access_token={meta_user_access_token}"
+        response = requests.get(url_to_get_acc_id)
+
+        return response
+
+    def post_media_container_id(self, root_api_url, insta_business_account_id, image_url, encoded_caption, meta_user_access_token):
+        url_to_create_media_container = f"{root_api_url}{insta_business_account_id}/media?image_url={image_url}&caption={encoded_caption}&access_token={meta_user_access_token}"
+        response = requests.post(url_to_create_media_container)
+
+        return response
+
+    def post_media(self, root_api_url, insta_business_account_id, container_ID, meta_user_access_token):
+        url_to_post_media_container = f"{root_api_url}{insta_business_account_id}/media_publish?creation_id={container_ID}&access_token={meta_user_access_token}"
+        response = requests.post(url_to_post_media_container)
+
+        return response
diff --git a/superagi/tools/instagram_tool/instagram_toolkit.py b/superagi/tools/instagram_tool/instagram_toolkit.py
new file mode 100644
index 000000000..a0b66cefb
--- /dev/null
+++ b/superagi/tools/instagram_tool/instagram_toolkit.py
@@ -0,0 +1,18 @@
+from abc import ABC
+from typing import List
+from superagi.tools.base_tool import BaseTool, BaseToolkit
+from superagi.tools.instagram_tool.instagram import InstagramTool
+
+class InstagramToolkit(BaseToolkit, ABC):
+ name: str = "Instagram Toolkit"
+ description: str = "Toolkit containing tools for posting AI generated photo on Instagram. Posts only one photo in a run "
+
+ def get_tools(self) -> List[BaseTool]:
+ return [InstagramTool()]
+
+ def get_env_keys(self) -> List[str]:
+ return [
+ "META_USER_ACCESS_TOKEN",
+ "FACEBOOK_PAGE_ID"
+ # Add more config keys specific to your project
+ ]
\ No newline at end of file
diff --git a/superagi/tools/knowledge_search/knowledge_search.py b/superagi/tools/knowledge_search/knowledge_search.py
new file mode 100644
index 000000000..45958bf40
--- /dev/null
+++ b/superagi/tools/knowledge_search/knowledge_search.py
@@ -0,0 +1,60 @@
+from superagi.models.agent_config import AgentConfiguration
+
+from superagi.models.knowledges import Knowledges
+from superagi.models.vector_db_indices import VectordbIndices
+from superagi.models.vector_dbs import Vectordbs
+from superagi.models.vector_db_configs import VectordbConfigs
+from superagi.models.toolkit import Toolkit
+from superagi.vector_store.vector_factory import VectorFactory
+from superagi.models.configuration import Configuration
+from superagi.jobs.agent_executor import AgentExecutor
+
+from typing import Any, Type, List
+from pydantic import BaseModel, Field
+
+from superagi.tools.base_tool import BaseTool
+
+
+
+class KnowledgeSearchSchema(BaseModel):
+ query: str = Field(..., description="The query to search required from knowledge search")
+
+
+class KnowledgeSearchTool(BaseTool):
+ name: str = "Knowledge Search"
+ args_schema: Type[BaseModel] = KnowledgeSearchSchema
+ agent_id: int = None
+ description = (
+ "A tool for performing a Knowledge search on knowledge base which might have knowledge of the task you are pursuing."
+ "To find relevant info, use this tool first before using other tools."
+ "If you don't find sufficient info using Knowledge tool, you may use other tools."
+ "If a question is being asked, responding with context from info returned by knowledge tool is prefered."
+ "Input should be a search query."
+ )
+
+ def _execute(self, query: str):
+ session = self.toolkit_config.session
+ toolkit = session.query(Toolkit).filter(Toolkit.id == self.toolkit_config.toolkit_id).first()
+ organisation_id = toolkit.organisation_id
+ knowledge_id = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == self.agent_id, AgentConfiguration.key == "knowledge").first().value
+ knowledge = Knowledges.get_knowledge_from_id(session, knowledge_id)
+ if knowledge is None:
+ return "Selected Knowledge not found"
+ vector_db_index = VectordbIndices.get_vector_index_from_id(session, knowledge.vector_db_index_id)
+ vector_db = Vectordbs.get_vector_db_from_id(session, vector_db_index.vector_db_id)
+ db_creds = VectordbConfigs.get_vector_db_config_from_db_id(session, vector_db.id)
+ model_api_key = Configuration.fetch_configuration(session, organisation_id, "model_api_key")
+ model_source = Configuration.fetch_configuration(session, organisation_id, "model_source")
+ embedding_model = AgentExecutor.get_embedding(model_source, model_api_key)
+        try:
+            # Default to no filters; marketplace knowledge is filtered by name
+            filters = {}
+            if vector_db_index.state == "Marketplace":
+                filters = {"knowledge_name": knowledge.name}
+            vector_db_storage = VectorFactory.build_vector_storage(vector_db.db_type, vector_db_index.name, embedding_model, **db_creds)
+            search_result = vector_db_storage.get_matching_text(query, metadata=filters)
+            return f"Result: \n{search_result['search_res']}"
+        except Exception as err:
+            return f"Error fetching text: {err}"
+
\ No newline at end of file
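+
+# Resolution chain used by _execute (illustrative recap, not part of this PR):
+#   agent config "knowledge" -> Knowledges -> VectordbIndices -> Vectordbs
+#   -> VectordbConfigs (credentials) -> VectorFactory.build_vector_storage(...)
+#   -> get_matching_text(query, metadata=filters)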
diff --git a/superagi/tools/knowledge_search/knowledge_search_toolkit.py b/superagi/tools/knowledge_search/knowledge_search_toolkit.py
new file mode 100644
index 000000000..eb227cab2
--- /dev/null
+++ b/superagi/tools/knowledge_search/knowledge_search_toolkit.py
@@ -0,0 +1,15 @@
+from abc import ABC
+from typing import List
+from superagi.tools.base_tool import BaseTool, BaseToolkit
+from superagi.tools.knowledge_search.knowledge_search import KnowledgeSearchTool
+
+
+class KnowledgeSearchToolkit(BaseToolkit, ABC):
+ name: str = "Knowledge Search Toolkit"
+ description: str = "Toolkit containing tools for performing search on the knowledge base."
+
+ def get_tools(self) -> List[BaseTool]:
+ return [KnowledgeSearchTool()]
+
+ def get_env_keys(self) -> List[str]:
+ return []
\ No newline at end of file
diff --git a/superagi/vector_embeddings/__init__.py b/superagi/vector_embeddings/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/superagi/vector_embeddings/base.py b/superagi/vector_embeddings/base.py
new file mode 100644
index 000000000..aa7edfaeb
--- /dev/null
+++ b/superagi/vector_embeddings/base.py
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+
+
+class VectorEmbeddings(ABC):
+
+    @abstractmethod
+    def get_vector_embeddings_from_chunks(self):
+        """ Returns embeddings for vector dbs from the chunks stored on the instance """
\ No newline at end of file
diff --git a/superagi/vector_embeddings/pinecone.py b/superagi/vector_embeddings/pinecone.py
new file mode 100644
index 000000000..39a0089de
--- /dev/null
+++ b/superagi/vector_embeddings/pinecone.py
@@ -0,0 +1,16 @@
+from superagi.vector_embeddings.base import VectorEmbeddings
+
+class Pinecone(VectorEmbeddings):
+
+ def __init__(self, uuid, embeds, metadata):
+ self.uuid = uuid
+ self.embeds = embeds
+ self.metadata = metadata
+
+ def get_vector_embeddings_from_chunks(self):
+ """ Returns embeddings for vector dbs from final chunks"""
+ result = {}
+ vectors = list(zip(self.uuid, self.embeds, self.metadata))
+ result['vectors'] = vectors
+ return result
\ No newline at end of file
diff --git a/superagi/vector_embeddings/qdrant.py b/superagi/vector_embeddings/qdrant.py
new file mode 100644
index 000000000..8c71e1d5f
--- /dev/null
+++ b/superagi/vector_embeddings/qdrant.py
@@ -0,0 +1,18 @@
+from superagi.vector_embeddings.base import VectorEmbeddings
+
+class Qdrant(VectorEmbeddings):
+
+ def __init__(self, uuid, embeds, metadata):
+ self.uuid = uuid
+ self.embeds = embeds
+ self.metadata = metadata
+
+ def get_vector_embeddings_from_chunks(self):
+ """ Returns embeddings for vector dbs from final chunks"""
+ result = {}
+ result['ids'] = self.uuid
+ result['payload'] = self.metadata
+ result['vectors'] = self.embeds
+
+ return result
\ No newline at end of file
diff --git a/superagi/vector_embeddings/vector_embedding_factory.py b/superagi/vector_embeddings/vector_embedding_factory.py
new file mode 100644
index 000000000..827edb79c
--- /dev/null
+++ b/superagi/vector_embeddings/vector_embedding_factory.py
@@ -0,0 +1,43 @@
+from typing import Optional
+from superagi.vector_embeddings.pinecone import Pinecone
+from superagi.vector_embeddings.qdrant import Qdrant
+from superagi.types.vector_store_types import VectorStoreType
+
+class VectorEmbeddingFactory:
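+ """Builds the VectorEmbeddings payload object (Pinecone or Qdrant) from a map of final chunks."""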
+
+ @classmethod
+ def build_vector_storage(cls, vector_store: VectorStoreType, chunk_json: Optional[dict] = None):
+ """
+ Get the vector embeddings from final chunks.
+ Args:
+ vector_store : The vector store name.
+ Returns:
+ The vector storage object
+ """
+ final_chunks = []
+ uuid = []
+ embeds = []
+ metadata = []
+ vector_store = VectorStoreType.get_vector_store_type(vector_store)
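+ # Flatten the chunk map, then split each chunk into parallel id, embedding and metadata lists.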
+ if chunk_json is not None:
+ for key in chunk_json.keys():
+ final_chunks.append(chunk_json[key])
+
+ for i in range(0, len(final_chunks)):
+ uuid.append(final_chunks[i]["id"])
+ embeds.append(final_chunks[i]["embeds"])
+ data = {
+ 'text': final_chunks[i]['text'],
+ 'chunk': final_chunks[i]['chunk'],
+ 'knowledge_name': final_chunks[i]['knowledge_name']
+ }
+ metadata.append(data)
+
+ if vector_store == VectorStoreType.PINECONE:
+ return Pinecone(uuid, embeds, metadata)
+
+ if vector_store == VectorStoreType.QDRANT:
+ return Qdrant(uuid, embeds, metadata)
+
+ raise ValueError(f"Vector store {vector_store} not supported")
\ No newline at end of file
diff --git a/superagi/vector_store/base.py b/superagi/vector_store/base.py
index 9c58b664b..1f895e78d 100644
--- a/superagi/vector_store/base.py
+++ b/superagi/vector_store/base.py
@@ -15,7 +15,7 @@ def add_texts(
"""Add texts to the vector store."""
@abstractmethod
- def get_matching_text(self, query: str, top_k: int, **kwargs: Any) -> List[Document]:
+ def get_matching_text(self, query: str, top_k: int, metadata: Optional[dict] = None, **kwargs: Any) -> List[Document]:
"""Return docs most similar to query using specified search type."""
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
@@ -24,3 +24,15 @@ def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
texts = [doc.text_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return self.add_texts(texts, metadatas, **kwargs)
+
+ @abstractmethod
+ def get_index_stats(self) -> dict:
+ """Returns stats or information of an index"""
+
+ @abstractmethod
+ def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
+ """Add embeddings to the vector store."""
+
+ @abstractmethod
+ def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
+ """Delete embeddings from the vector store."""
\ No newline at end of file
diff --git a/superagi/vector_store/chromadb.py b/superagi/vector_store/chromadb.py
index 1f678e568..c6ba80b90 100644
--- a/superagi/vector_store/chromadb.py
+++ b/superagi/vector_store/chromadb.py
@@ -99,3 +99,12 @@ def get_matching_text(self, query: str, top_k: int = 5, metadata: Optional[dict]
)
return documents
+
+ def get_index_stats(self) -> dict:
+ pass
+
+ def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
+ pass
+
+ def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
+ pass
\ No newline at end of file
diff --git a/superagi/vector_store/document.py b/superagi/vector_store/document.py
index b940b3bd8..063a04810 100644
--- a/superagi/vector_store/document.py
+++ b/superagi/vector_store/document.py
@@ -4,7 +4,7 @@
class Document(BaseModel):
"""Interface for interacting with a document."""
- text_content: str
+ text_content: str = None
metadata: dict = Field(default_factory=dict)
def __init__(self, text_content, *args, **kwargs):
diff --git a/superagi/vector_store/pinecone.py b/superagi/vector_store/pinecone.py
index 7f229165e..0498c2546 100644
--- a/superagi/vector_store/pinecone.py
+++ b/superagi/vector_store/pinecone.py
@@ -20,8 +20,8 @@ class Pinecone(VectorStore):
def __init__(
self,
index: Any,
- embedding_model: BaseEmbedding,
- text_field: str,
+ embedding_model: Optional[Any] = None,
+ text_field: Optional[str] = 'text',
namespace: Optional[str] = '',
):
try:
@@ -73,10 +73,10 @@ def add_texts(
metadata[self.text_field] = text
vectors.append((id, self.embedding_model.get_embedding(text), metadata))
- self.index.upsert(vectors, namespace=namespace, batch_size=batch_size)
+ self.add_embeddings_to_vector_db({"vectors": vectors})
return ids
- def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[Document]:
+ def get_matching_text(self, query: str, top_k: int = 5, metadata: Optional[dict] = None, **kwargs: Any) -> List[Document]:
"""
Return docs most similar to query using specified search type.
@@ -89,18 +89,60 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D
The list of documents most similar to the query
"""
namespace = kwargs.get("namespace", self.namespace)
-
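+ # Convert each metadata key into a Pinecone equality ($eq) filter for the query.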
+ filters = {}
+ if metadata is not None:
+ for key in metadata.keys():
+ filters[key] = {"$eq": metadata[key]}
embed_text = self.embedding_model.get_embedding(query)
- res = self.index.query(embed_text, top_k=top_k, namespace=namespace, include_metadata=True)
-
- documents = []
+ res = self.index.query(embed_text, filter=filters, top_k=top_k, namespace=namespace, include_metadata=True)
+ search_res = self._get_search_text(res, query)
- for doc in res['matches']:
- documents.append(
- Document(
- text_content=doc.metadata[self.text_field],
- metadata=doc.metadata,
+ documents = self._build_documents(res)
+ return {"documents": documents, "search_res": search_res}
+
+ def get_index_stats(self) -> dict:
+ """
+ Returns:
+ Stats or Information about an index
+ """
+ index_stats = self.index.describe_index_stats()
+ dimensions = index_stats.dimension
+ vector_count = index_stats.total_vector_count
+ return {"dimensions": dimensions, "vector_count": vector_count}
+
+ def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
+ """Upserts embeddings to the given vector store"""
+ try:
+ self.index.upsert(vectors=embeddings['vectors'])
+ except Exception as err:
+ raise err
+
+ def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
+ """Deletes embeddings from the given vector store"""
+ try:
+ self.index.delete(ids=ids)
+ except Exception as err:
+ raise err
+
+ def _build_documents(self, results: List[dict]):
+ try:
+ documents = []
+ for doc in results['matches']:
+ documents.append(
+ Document(
+ text_content=doc['metadata'][self.text_field],
+ metadata=doc['metadata'],
+ )
)
- )
-
- return documents
+ return documents
+ except Exception as err:
+ raise err
+
+ def _get_search_text(self, results: List[dict], query: str):
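+ """Formats the query and the text of each matched chunk into a single prompt-ready string."""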
+ contexts = [item['metadata']['text'] for item in results['matches']]
+ i = 0
+ search_res = f"Query: {query}\n"
+ for context in contexts:
+ search_res += f"Chunk{i}: \n{context}\n"
+ i += 1
+ return search_res
\ No newline at end of file
diff --git a/superagi/vector_store/qdrant.py b/superagi/vector_store/qdrant.py
index 3526b8fcf..e8c196678 100644
--- a/superagi/vector_store/qdrant.py
+++ b/superagi/vector_store/qdrant.py
@@ -17,11 +17,15 @@
MetadataFilter = Union[DictFilter, common_types.Filter]
-def create_qdrant_client(
+def create_qdrant_client(api_key: Optional[str] = None, url: Optional[str] = None, port: Optional[int] = None
) -> QdrantClient:
- qdrant_host_name = get_config("QDRANT_HOST_NAME") or "localhost"
- qdrant_port = get_config("QDRANT_PORT") or 6333
- return QdrantClient(host=qdrant_host_name, port=qdrant_port)
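+ """Uses QDRANT_HOST_NAME/QDRANT_PORT from config when no api_key is given; otherwise connects to a remote instance."""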
+ if api_key is None:
+ qdrant_host_name = get_config("QDRANT_HOST_NAME") or "localhost"
+ qdrant_port = get_config("QDRANT_PORT") or 6333
+ qdrant_client = QdrantClient(host=qdrant_host_name, port=qdrant_port)
+ else:
+ qdrant_client = QdrantClient(api_key=api_key, url=url, port=port)
+ return qdrant_client
class Qdrant(VectorStore):
@@ -35,14 +39,14 @@ class Qdrant(VectorStore):
text_field_payload_key : Name of the field where the corresponding text for point is stored in the collection.
metadata_payload_key : Name of the field where the corresponding metadata for point is stored in the collection.
"""
- TEXT_FIELD_KEY = "text_field"
+ TEXT_FIELD_KEY = "text"
METADATA_KEY = "metadata"
def __init__(
self,
client: QdrantClient,
- embedding_model: Any,
- collection_name: str,
+ embedding_model: Optional[Any] = None,
+ collection_name: str = None,
text_field_payload_key: str = TEXT_FIELD_KEY,
metadata_payload_key: str = METADATA_KEY,
):
@@ -80,37 +84,30 @@ def add_texts(
text_batch = input_texts[i * batch_limit: (i + 1) * batch_limit]
metadata_batch = metadata_list[i * batch_limit: (i + 1) * batch_limit] or None
id_batch = id_list[i * batch_limit: (i + 1) * batch_limit]
-
- self.client.upsert(
- collection_name=self.collection_name,
- points=models.Batch.construct(
- ids=id_batch,
- vectors=self.__get_embeddings(text_batch),
- payloads=self.__build_payloads(
- text_batch,
- metadata_batch,
- self.text_field_payload_key,
- self.metadata_payload_key,
- ),
- ),
+ vectors = self.__get_embeddings(text_batch)
+ payloads = self.__build_payloads(
+ text_batch,
+ metadata_batch,
+ self.text_field_payload_key,
+ self.metadata_payload_key,
)
-
+ self.add_embeddings_to_vector_db({"ids": id_batch, "vectors": vectors, "payloads": payloads})
collected_ids.extend(id_batch)
return collected_ids
-
+
def get_matching_text(
self,
+ text: str = None,
embedding: List[float] = None,
k: int = 4,
- text: str = None,
- filter: Optional[MetadataFilter] = None,
+ metadata: Optional[dict] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
- ) -> List[Tuple[Document, float]]:
+ ) -> Dict:
"""
Return docs most similar to query using specified search type.
@@ -134,7 +131,18 @@ def get_matching_text(
if text is not None:
embedding = self.__get_embeddings(text)[0]
- results = self.client.search(
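+ # Translate the metadata dict into a Qdrant must-match filter; without metadata the search runs unfiltered.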
+ filter = None
+ if metadata is not None:
+ filter_conditions = []
+ for key, value in metadata.items():
+ filter_conditions.append({"key": key, "match": {"value": value}})
+ filter = models.Filter(
+ must = filter_conditions
+ )
+ try:
+ results = self.client.search(
collection_name=self.collection_name,
query_vector=embedding,
query_filter=filter,
@@ -146,10 +154,51 @@ def get_matching_text(
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
- )
+ )
+ except Exception as err:
+ raise err
+ search_res = self._get_search_res(results, text)
+ documents = self.__build_documents(results)
- return self.__build_documents(results)
+ return {"documents": documents, "search_res": search_res}
+
+ def get_index_stats(self) -> dict:
+ """
+ Returns:
+ Stats or Information about a collection
+ """
+ collection_info = self.client.get_collection(collection_name=self.collection_name)
+ dimensions = collection_info.config.params.vectors.size
+ vector_count = collection_info.vectors_count
+ return {"dimensions": dimensions, "vector_count": vector_count}
+
+ def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
+ """Upserts embeddings to the given vector store"""
+ try:
+ self.client.upsert(
+ collection_name=self.collection_name,
+ points=models.Batch(
+ ids=embeddings["ids"],
+ vectors=embeddings["vectors"],
+ payloads=embeddings["payload"]
+ ),
+ )
+ except Exception as err:
+ raise err
+
+ def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
+ """Deletes embeddings from the given vector store"""
+ try:
+ self.client.delete(
+ collection_name=self.collection_name,
+ points_selector = models.PointIdsList(
+ points = ids
+ ),
+ )
+ except Exception as err:
+ raise err
+
def __get_embeddings(
self,
texts: Iterable[str]
@@ -162,9 +211,9 @@ def __get_embeddings(
query_vectors.append(query_vector)
else:
raise ValueError("Embedding model is not set")
-
+
return query_vectors
-
+
def __build_payloads(
self,
texts: Iterable[str],
@@ -192,7 +241,7 @@ def __build_payloads(
)
return payloads
-
+
def __build_documents(
self,
results: List[Dict]
@@ -203,12 +252,12 @@ def __build_documents(
documents.append(
Document(
text_content=result.payload.get(self.text_field_payload_key),
- metadata=result.payload.get(self.metadata_payload_key) or {},
+ metadata=(result.payload.get(self.metadata_payload_key)) or {},
)
)
return documents
-
+
@classmethod
def create_collection(cls,
client: QdrantClient,
@@ -217,7 +266,7 @@ def create_collection(cls,
):
"""
Create a new collection in Qdrant if it does not exist.
-
+
Args:
client : The Qdrant client.
collection_name: The name of the collection to create.
@@ -227,4 +276,13 @@ def create_collection(cls,
client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=size, distance=Distance.COSINE),
- )
\ No newline at end of file
+ )
+
+ def _get_search_res(self, results, text):
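+ """Formats the query and each hit's text payload into a single prompt-ready string."""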
+ contexts = [res.payload for res in results]
+ i = 0
+ search_res = f"Query: {text}\n"
+ for context in contexts:
+ search_res += f"Chunk{i}: \n{context['text']}\n"
+ i += 1
+ return search_res
\ No newline at end of file
diff --git a/superagi/vector_store/vector_factory.py b/superagi/vector_store/vector_factory.py
index d330aee6b..4d9dbd157 100644
--- a/superagi/vector_store/vector_factory.py
+++ b/superagi/vector_store/vector_factory.py
@@ -73,4 +73,24 @@ def get_vector_storage(cls, vector_store: VectorStoreType, index_name, embedding
Qdrant.create_collection(client, index_name, len(sample_embedding))
return qdrant.Qdrant(client, embedding_model, index_name)
- raise ValueError(f"Vector store {vector_store} not supported")
\ No newline at end of file
+ raise ValueError(f"Vector store {vector_store} not supported")
+
+ @classmethod
+ def build_vector_storage(cls, vector_store: VectorStoreType, index_name, embedding_model = None, **creds):
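+ """Connects to an existing index/collection with the given credentials and returns the matching VectorStore."""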
+ if isinstance(vector_store, str):
+ vector_store = VectorStoreType.get_vector_store_type(vector_store)
+
+ if vector_store == VectorStoreType.PINECONE:
+ try:
+ pinecone.init(api_key = creds["api_key"], environment = creds["environment"])
+ index = pinecone.Index(index_name)
+ return Pinecone(index, embedding_model)
+ except UnauthorizedException:
+ raise ValueError("PineCone API key not found")
+
+ if vector_store == VectorStoreType.QDRANT:
+ try:
+ client = qdrant.create_qdrant_client(creds["api_key"], creds["url"], creds["port"])
+ return qdrant.Qdrant(client, embedding_model, index_name)
+ except Exception:
+ raise ValueError("Qdrant connection failed; check the API key, url and port")
\ No newline at end of file
diff --git a/superagi/vector_store/weaviate.py b/superagi/vector_store/weaviate.py
index 0b1d92bfa..e9fdc236b 100644
--- a/superagi/vector_store/weaviate.py
+++ b/superagi/vector_store/weaviate.py
@@ -104,3 +104,12 @@ def _get_metadata_fields(self) -> List[str]:
property_names.remove(self.text_field)
return property_names
+
+ def get_index_stats(self) -> dict:
+ pass
+
+ def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
+ pass
+
+ def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
+ pass
\ No newline at end of file
diff --git a/tests/integration_tests/vector_embeddings/test_pinecone.py b/tests/integration_tests/vector_embeddings/test_pinecone.py
new file mode 100644
index 000000000..302cfd2ed
--- /dev/null
+++ b/tests/integration_tests/vector_embeddings/test_pinecone.py
@@ -0,0 +1,27 @@
+import unittest
+from superagi.vector_embeddings.pinecone import Pinecone
+
+
+class TestPinecone(unittest.TestCase):
+
+ def setUp(self):
+ self.uuid = ["id1", "id2"]
+ self.embeds = ["embed1", "embed2"]
+ self.metadata = ["metadata1", "metadata2"]
+ self.pinecone_instance = Pinecone(self.uuid, self.embeds, self.metadata)
+
+ def test_init(self):
+ self.assertEqual(self.pinecone_instance.uuid, self.uuid)
+ self.assertEqual(self.pinecone_instance.embeds, self.embeds)
+ self.assertEqual(self.pinecone_instance.metadata, self.metadata)
+
+ def test_get_vector_embeddings_from_chunks(self):
+ expected = {
+ 'vectors': list(zip(self.uuid, self.embeds, self.metadata))
+ }
+ result = self.pinecone_instance.get_vector_embeddings_from_chunks()
+ self.assertEqual(result, expected)
+
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/tests/integration_tests/vector_embeddings/test_qdrant.py b/tests/integration_tests/vector_embeddings/test_qdrant.py
new file mode 100644
index 000000000..46da86c64
--- /dev/null
+++ b/tests/integration_tests/vector_embeddings/test_qdrant.py
@@ -0,0 +1,30 @@
+import unittest
+
+from superagi.vector_embeddings.qdrant import Qdrant
+
+class TestQdrant(unittest.TestCase):
+
+ def setUp(self):
+ self.uuid = ['1234', '5678']
+ self.embeds = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+ self.metadata = [{'key1': 'value1'}, {'key2': 'value2'}]
+
+ self.qdrant_obj = Qdrant(self.uuid, self.embeds, self.metadata)
+
+ def test_init(self):
+ self.assertEqual(self.qdrant_obj.uuid, self.uuid)
+ self.assertEqual(self.qdrant_obj.embeds, self.embeds)
+ self.assertEqual(self.qdrant_obj.metadata, self.metadata)
+
+ def test_get_vector_embeddings_from_chunks(self):
+ expected = {
+ 'ids': self.uuid,
+ 'payload': self.metadata,
+ 'vectors': self.embeds,
+ }
+ result = self.qdrant_obj.get_vector_embeddings_from_chunks()
+
+ self.assertEqual(result, expected)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/controllers/test_agent_template.py b/tests/unit_tests/controllers/test_agent_template.py
new file mode 100644
index 000000000..050e9e216
--- /dev/null
+++ b/tests/unit_tests/controllers/test_agent_template.py
@@ -0,0 +1,127 @@
+from unittest.mock import patch, MagicMock
+from superagi.models.agent_template import AgentTemplate
+from superagi.models.agent_template_config import AgentTemplateConfig
+from fastapi.testclient import TestClient
+from main import app
+
+client = TestClient(app)
+
+@patch('superagi.controllers.agent_template.db')
+@patch('superagi.helper.auth.db')
+@patch('superagi.helper.auth.get_user_organisation')
+def test_edit_agent_template_success(mock_get_user_org, mock_auth_db, mock_db):
+ # Create a mock agent template
+ mock_agent_template = AgentTemplate(id=1, name="Test Agent Template", description="Test Description")
+
+ # Create a mock edited agent configuration
+ mock_updated_agent_configs = {
+ "name": "Updated Agent Template",
+ "description": "Updated Description",
+ "agent_configs": {
+ "goal": ["Create a simple pacman game for me.", "Write all files properly."],
+ "instruction": ["write spec","write code","improve the code","write test"],
+ "agent_type": "Don't Maintain Task Queue",
+ "constraints": ["If you are unsure how you previously did something or want to recall past events, thinking about similar events will help you remember.","Ensure the tool and args are as per current plan and reasoning","Exclusively use the tools listed under \"TOOLS\"","REMEMBER to format your response as JSON, using double quotes (\"\") around keys and string values, and commas (,) to separate items in arrays and objects. IMPORTANTLY, to use a JSON object as a string in another JSON object, you need to escape the double quotes."],
+ "tools": ["Read Email", "Send Email", "Write File"],
+ "exit": "No exit criterion",
+ "iteration_interval": 500,
+ "model": "gpt-4",
+ "max_iterations": 25,
+ "permission_type": "God Mode",
+ "LTM_DB": "Pinecone"
+ }
+ }
+
+ # Mocking the user organisation
+ mock_get_user_org.return_value = MagicMock(id=1)
+
+ # Create a session mock
+ session_mock = MagicMock()
+ mock_db.session = session_mock
+ mock_db.session.query.return_value.filter.return_value.first.return_value = mock_agent_template
+ mock_db.session.commit.return_value = None
+ mock_db.session.add.return_value = None
+ mock_db.session.flush.return_value = None
+
+ mock_agent_template_config = AgentTemplateConfig(agent_template_id = 1, key="goal", value=["Create a simple pacman game for me.", "Write all files properly."])
+
+
+ # Call the endpoint
+ response = client.put("agent_templates/update_agent_template/1", json=mock_updated_agent_configs)
+
+ assert response.status_code == 200
+
+ # Verify changes in the mock agent template
+ assert mock_agent_template.name == "Updated Agent Template"
+ assert mock_agent_template.description == "Updated Description"
+ assert mock_agent_template_config.key == "goal"
+ assert mock_agent_template_config.value == ["Create a simple pacman game for me.", "Write all files properly."]
+
+
+ session_mock.commit.assert_called()
+ session_mock.flush.assert_called()
+
+
+@patch('superagi.controllers.agent_template.db')
+@patch('superagi.helper.auth.db')
+@patch('superagi.helper.auth.get_user_organisation')
+def test_edit_agent_template_failure(mock_get_user_org, mock_auth_db, mock_db):
+ # Setup: The user organisation exists, but the agent template does not exist.
+ mock_get_user_org.return_value = MagicMock(id=1)
+
+ # Create a session mock
+ session_mock = MagicMock()
+ mock_db.session = session_mock
+ mock_db.session.query.return_value.filter.return_value.first.return_value = None
+
+ # Call the endpoint
+ response = client.put("agent_templates/update_agent_template/1", json={})
+
+ # Verify: The response status code should be 404, indicating that the agent template was not found.
+ assert response.status_code == 404
+ assert response.json() == {"detail": "Agent Template not found"}
+
+ # Verify: The database commit method should not have been called because the agent template was not found.
+ session_mock.commit.assert_not_called()
+ session_mock.flush.assert_not_called()
+
+
+@patch('superagi.controllers.agent_template.db')
+@patch('superagi.helper.auth.db')
+@patch('superagi.helper.auth.get_user_organisation')
+def test_edit_agent_template_with_new_config_success(mock_get_user_org, mock_auth_db, mock_db):
+ # Create a mock agent template
+ mock_agent_template = AgentTemplate(id=1, name="Test Agent Template", description="Test Description")
+
+ # Create a mock edited agent configuration
+ mock_updated_agent_configs = {
+ "name": "Updated Agent Template",
+ "description": "Updated Description",
+ "agent_configs": {
+ "new_config_key": "New config value" # This is a new config
+ }
+ }
+
+ # Mocking the user organisation
+ mock_get_user_org.return_value = MagicMock(id=1)
+
+ # Create a session mock
+ session_mock = MagicMock()
+ mock_db.session = session_mock
+ mock_db.session.query.return_value.filter.return_value.first.return_value = mock_agent_template
+ mock_db.session.commit.return_value = None
+ mock_db.session.add.return_value = None
+ mock_db.session.flush.return_value = None
+
+ # Call the endpoint
+ response = client.put("agent_templates/update_agent_template/1", json=mock_updated_agent_configs)
+
+ assert response.status_code == 200
+
+ # Verify changes in the mock agent template
+ assert mock_agent_template.name == "Updated Agent Template"
+ assert mock_agent_template.description == "Updated Description"
+
+ session_mock.commit.assert_called()
+ session_mock.flush.assert_called()
\ No newline at end of file
diff --git a/tests/unit_tests/controllers/test_toolkit.py b/tests/unit_tests/controllers/test_toolkit.py
index 87640ba05..0694c7666 100644
--- a/tests/unit_tests/controllers/test_toolkit.py
+++ b/tests/unit_tests/controllers/test_toolkit.py
@@ -1,4 +1,4 @@
-from unittest.mock import patch
+from unittest.mock import patch, call
import pytest
from fastapi.testclient import TestClient
@@ -59,6 +59,44 @@ def mocks():
return user_organisation, user_toolkits, tools, toolkit_1, toolkit_2, tool_1, tool_2, tool_3
+@pytest.fixture
+def mock_toolkit_details():
+ # Mock toolkit details data for testing
+ toolkit_details = {
+ "name": "toolkit_1",
+ "description": "Test Toolkit",
+ "tool_code_link": "https://example.com/toolkit_1",
+ "show_toolkit": None,
+ "tools": [
+ {
+ "name": "tool_1",
+ "description": "Test Tool 1",
+ "folder_name": "test_folder_1",
+ "class_name": "TestTool1",
+ "file_name": "test_tool_1.py"
+ },
+ {
+ "name": "tool_2",
+ "description": "Test Tool 2",
+ "folder_name": "test_folder_2",
+ "class_name": "TestTool2",
+ "file_name": "test_tool_2.py"
+ }
+ ],
+ "configs": [
+ {
+ "key": "config_key_1",
+ "value": "config_value_1"
+ },
+ {
+ "key": "config_key_2",
+ "value": "config_value_2"
+ }
+ ]
+ }
+ return toolkit_details
+
+
def test_handle_marketplace_operations_list(mocks):
# Unpack the fixture data
user_organisation, user_toolkits, tools, toolkit_1, toolkit_2, tool_1, tool_2, tool_3 = mocks
@@ -68,7 +106,6 @@ def test_handle_marketplace_operations_list(mocks):
patch('superagi.controllers.toolkit.db') as mock_db, \
patch('superagi.models.toolkit.Toolkit.fetch_marketplace_list') as mock_fetch_marketplace_list, \
patch('superagi.helper.auth.db') as mock_auth_db:
-
# Set up mock data
mock_db.session.query.return_value.filter.return_value.all.side_effect = [user_toolkits]
mock_fetch_marketplace_list.return_value = [toolkit_1.to_dict(), toolkit_2.to_dict()]
@@ -96,3 +133,29 @@ def test_handle_marketplace_operations_list(mocks):
"is_installed": True
}
]
+
+
+def test_install_toolkit_from_marketplace(mock_toolkit_details):
+ # Mock the database session and query functions
+ with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
+ patch('superagi.models.toolkit.Toolkit.fetch_marketplace_detail') as mock_fetch_marketplace_detail, \
+ patch('superagi.models.toolkit.Toolkit.add_or_update') as mock_add_or_update, \
+ patch('superagi.models.tool.Tool.add_or_update') as mock_tool_add_or_update, \
+ patch('superagi.controllers.toolkit.db') as mock_db, \
+ patch('superagi.helper.auth.db') as mock_auth_db, \
+ patch('superagi.models.tool_config.ToolConfig.add_or_update') as mock_tool_config_add_or_update:
+ # Set up mock data and behavior
+ mock_get_user_org.return_value = Organisation(id=1)
+ mock_fetch_marketplace_detail.return_value = mock_toolkit_details
+ mock_add_or_update.return_value = Toolkit(id=1, name=mock_toolkit_details['name'],
+ description=mock_toolkit_details['description'])
+
+ # Call the function
+ response = client.get("/toolkits/get/install/toolkit_1")
+
+ # Assertions
+ assert response.status_code == 200
+ assert response.json() == {"message": "ToolKit installed successfully"}
+
+ # Verify the function calls
+ mock_fetch_marketplace_detail.assert_called_once_with(search_str="details", toolkit_name="toolkit_1")
diff --git a/tests/unit_tests/helper/test_resource_helper.py b/tests/unit_tests/helper/test_resource_helper.py
index 4cc9a5f65..d1e9eb42a 100644
--- a/tests/unit_tests/helper/test_resource_helper.py
+++ b/tests/unit_tests/helper/test_resource_helper.py
@@ -1,8 +1,9 @@
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
from superagi.helper.resource_helper import ResourceHelper
from superagi.models.agent import Agent
from superagi.models.agent_execution import AgentExecution
+from superagi.models.resource import Resource
def test_make_written_file_resource(mocker):
@@ -10,12 +11,28 @@ def test_make_written_file_resource(mocker):
mocker.patch('os.makedirs', return_value=None)
mocker.patch('os.path.getsize', return_value=1000)
mocker.patch('os.path.splitext', return_value=("", ".txt"))
- mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['FILE','/','/','FILE'])
+ mocker.patch('superagi.helper.resource_helper.get_config', side_effect=['FILE', '/', '/', 'FILE'])
mock_agent = Agent(id=1, name='TestAgent')
mock_agent_execution = AgentExecution(id=1, name='TestExecution')
+ session = MagicMock()
with patch('superagi.helper.resource_helper.logger') as logger_mock:
- result = ResourceHelper.make_written_file_resource('test.txt', mock_agent, mock_agent_execution)
+ session.query.return_value.filter_by.return_value.first.return_value = None
+ # Create a Resource object
+ resource = Resource(
+ name='test.txt',
+ path='/test.txt',
+ storage_type='FILE',
+ size=1000,
+ type='application/txt',
+ channel='OUTPUT',
+ agent_id=1,
+ agent_execution_id=1
+ )
+
+ # Mock the session.add() method to return the created Resource object
+ session.add.return_value = resource
+ result = ResourceHelper.make_written_file_resource('test.txt', mock_agent, mock_agent_execution, session)
assert result.name == 'test.txt'
assert result.path == '/test.txt'
diff --git a/tests/unit_tests/helper/test_s3_helper.py b/tests/unit_tests/helper/test_s3_helper.py
new file mode 100644
index 000000000..82540a4d6
--- /dev/null
+++ b/tests/unit_tests/helper/test_s3_helper.py
@@ -0,0 +1,77 @@
+import json
+import pytest
+from unittest.mock import MagicMock, patch
+from botocore.exceptions import NoCredentialsError
+from fastapi import HTTPException
+from superagi.helper.s3_helper import S3Helper
+
+@pytest.fixture()
+def s3helper_object():
+ return S3Helper()
+
+def test__get_s3_client(s3helper_object):
+ with patch('superagi.helper.s3_helper.get_config', return_value='test') as mock_get_config:
+ s3_client = s3helper_object._S3Helper__get_s3_client()
+ mock_get_config.assert_any_call('AWS_ACCESS_KEY_ID')
+ mock_get_config.assert_any_call('AWS_SECRET_ACCESS_KEY')
+
+@pytest.mark.parametrize('have_creds, raises', [(True, False), (False, True)])
+def test_upload_file(s3helper_object, have_creds, raises):
+ s3helper_object.s3.upload_fileobj = MagicMock()
+ s3helper_object.s3.upload_fileobj.side_effect = NoCredentialsError() if not have_creds else None
+
+ if raises:
+ with pytest.raises(HTTPException):
+ s3helper_object.upload_file('file', 'path')
+ else:
+ s3helper_object.upload_file('file', 'path')
+
+@pytest.mark.parametrize('have_creds, raises', [(True, False), (False, True)])
+def test_get_json_file(s3helper_object, have_creds, raises):
+
+ # Mock 'get_object' method from s3 client
+ s3helper_object.s3.get_object = MagicMock()
+
+ # Mocked JSON contents with their 'Body' key as per real response
+ mock_json_file = { 'Body': MagicMock() }
+ mock_json_file['Body'].read = MagicMock(return_value=bytes(json.dumps("content_of_json"), 'utf-8'))
+
+ # Case when we do have credentials but 'get_object' raises an error
+ if not raises:
+ s3helper_object.s3.get_object.return_value = mock_json_file
+ else:
+ s3helper_object.s3.get_object.side_effect = NoCredentialsError()
+
+ # Mocking a path to the file
+ mock_path = "mock_path"
+
+ if raises:
+ with pytest.raises(HTTPException):
+ s3helper_object.get_json_file(mock_path)
+ else:
+ content = s3helper_object.get_json_file(mock_path)
+
+ # Assert that 'get_object' was called with our mocked path
+ s3helper_object.s3.get_object.assert_called_with(Bucket=s3helper_object.bucket_name, Key=mock_path)
+
+ assert content == "content_of_json" # Assert we got our mocked JSON content back
+
+def test_check_file_exists_in_s3(s3helper_object):
+ s3helper_object.s3.list_objects_v2 = MagicMock(return_value={})
+ assert s3helper_object.check_file_exists_in_s3('path') is False
+
+ s3helper_object.s3.list_objects_v2 = MagicMock(return_value={'Contents':[]})
+ assert s3helper_object.check_file_exists_in_s3('path') is True
+
+@pytest.mark.parametrize('http_status, expected_result, raises', [(200, 'file_content', False), (500, None, True)])
+def test_read_from_s3(s3helper_object, http_status, expected_result, raises):
+ s3helper_object.s3.get_object = MagicMock(
+ return_value={'ResponseMetadata': {'HTTPStatusCode': http_status},
+ 'Body': MagicMock(read=lambda: bytes(expected_result, 'utf-8'))}
+ )
+
+ if raises:
+ with pytest.raises(Exception):
+ s3helper_object.read_from_s3('path')
+ else:
+ assert s3helper_object.read_from_s3('path') == expected_result
\ No newline at end of file
diff --git a/tests/unit_tests/models/test_knowledge_configs.py b/tests/unit_tests/models/test_knowledge_configs.py
new file mode 100644
index 000000000..df77a86b0
--- /dev/null
+++ b/tests/unit_tests/models/test_knowledge_configs.py
@@ -0,0 +1,47 @@
+import unittest
+from unittest.mock import Mock, patch, MagicMock
+from sqlalchemy.orm.session import Session
+
+from superagi.models.knowledge_configs import KnowledgeConfigs
+
+class TestKnowledgeConfigs(unittest.TestCase):
+
+ def setUp(self):
+ self.session = Mock(spec=Session)
+ self.knowledge_id = 1
+ self.test_configs = {'key1': 'value1', 'key2': 'value2'}
+
+ @patch('requests.get')
+ def test_fetch_knowledge_config_details_marketplace(self, mock_get):
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = [{'key': 'key1', 'value': 'value1'}, {'key': 'key2', 'value': 'value2'}]
+ mock_get.return_value = mock_response
+
+ configs = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(self.knowledge_id)
+ self.assertEqual(configs, self.test_configs)
+
+ def test_add_update_knowledge_config(self):
+ KnowledgeConfigs.add_update_knowledge_config(self.session, self.knowledge_id, self.test_configs)
+ self.session.add.assert_called()
+ self.session.commit.assert_called()
+
+ def test_get_knowledge_config_from_knowledge_id(self):
+ test_obj = Mock()
+ test_obj.key = "key1"
+ test_obj.value = "value1"
+ self.session.query.return_value.filter.return_value.all.return_value = [test_obj]
+ configs = KnowledgeConfigs.get_knowledge_config_from_knowledge_id(self.session, self.knowledge_id)
+ self.assertEqual(configs, {"key1": "value1"})
+
+ def test_delete_knowledge_config(self):
+ KnowledgeConfigs.delete_knowledge_config(self.session, self.knowledge_id)
+ self.session.query.assert_called()
+ self.session.commit.assert_called()
+
+ def tearDown(self):
+ pass
+
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/models/test_marketplace_stats.py b/tests/unit_tests/models/test_marketplace_stats.py
new file mode 100644
index 000000000..a70a66b7f
--- /dev/null
+++ b/tests/unit_tests/models/test_marketplace_stats.py
@@ -0,0 +1,44 @@
+import unittest
+from unittest.mock import patch, MagicMock
+from sqlalchemy.orm import Session
+from superagi.models.marketplace_stats import MarketPlaceStats
+
+class TestMarketPlaceStats(unittest.TestCase):
+
+ @patch('requests.get')
+ def test_get_knowledge_installation_number(self, mock_get):
+ test_json = {'download_count':123}
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = test_json
+ mock_get.return_value = mock_response
+
+ result = MarketPlaceStats.get_knowledge_installation_number(1)
+ self.assertEqual(result, test_json)
+
+ @patch('requests.get')
+ def test_get_knowledge_installation_number_status_not_200(self, mock_get):
+ mock_response = MagicMock()
+ mock_response.status_code = 404
+ mock_get.return_value = mock_response
+
+ result = MarketPlaceStats.get_knowledge_installation_number(1)
+ self.assertEqual(result, [])
+
+ @patch('sqlalchemy.orm.Session')
+ def test_update_knowledge_install_number_existing(self, mock_session):
+ instance = MagicMock()
+ instance.value = '5'
+ mock_query = MagicMock()
+ mock_query.filter.return_value.first.return_value = instance
+ mock_session.query.return_value = mock_query
+
+ MarketPlaceStats.update_knowledge_install_number(mock_session, 1, 10)
+
+ self.assertEqual(instance.value, "10")
+
+ mock_query.filter.assert_called()
+ mock_session.commit.assert_called()
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/unit_tests/models/test_vector_db_configs.py b/tests/unit_tests/models/test_vector_db_configs.py
new file mode 100644
index 000000000..3bb345cd2
--- /dev/null
+++ b/tests/unit_tests/models/test_vector_db_configs.py
@@ -0,0 +1,33 @@
+import unittest
+from unittest.mock import Mock, patch
+from superagi.models.vector_db_configs import VectordbConfigs
+
+class TestVectordbConfigs(unittest.TestCase):
+ def setUp(self):
+ self.session_mock = Mock()
+ self.vector_db_id_mock = 1
+ self.db_creds_mock = {"key1": "value1", "key2": "value2"}
+
+ @patch('superagi.models.vector_db_configs.VectordbConfigs')
+ def test_get_vector_db_config_from_db_id(self, model_mock):
+ vectordb_mock = Mock()
+ vectordb_mock.key = "key1"
+ vectordb_mock.value = "value1"
+ self.session_mock.query().filter().all.return_value = [vectordb_mock]
+ result = VectordbConfigs.get_vector_db_config_from_db_id(self.session_mock, self.vector_db_id_mock)
+ self.assertEqual(result, {"key1": "value1"})
+
+ @patch('superagi.models.vector_db_configs.VectordbConfigs')
+ def test_add_vector_db_config(self, model_mock):
+ VectordbConfigs.add_vector_db_config(self.session_mock, self.vector_db_id_mock, self.db_creds_mock)
+ self.assertEqual(self.session_mock.add.call_count, len(self.db_creds_mock))
+ self.assertTrue(self.session_mock.commit.called)
+
+ @patch('superagi.models.vector_db_configs.VectordbConfigs')
+ def test_delete_vector_db_configs(self, model_mock):
+ VectordbConfigs.delete_vector_db_configs(self.session_mock, self.vector_db_id_mock)
+ self.assertTrue(self.session_mock.query(model_mock).filter(model_mock.vector_db_id == self.vector_db_id_mock).delete.called)
+ self.assertTrue(self.session_mock.commit.called)
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/models/test_vector_db_indices.py b/tests/unit_tests/models/test_vector_db_indices.py
new file mode 100644
index 000000000..162d91b94
--- /dev/null
+++ b/tests/unit_tests/models/test_vector_db_indices.py
@@ -0,0 +1,39 @@
+import unittest
+from unittest.mock import Mock, MagicMock, call
+from superagi.models.vector_db_indices import VectordbIndices
+
+class TestVectordbIndices(unittest.TestCase):
+ def setUp(self):
+ self.mock_session = Mock()
+ self.query_mock = self.mock_session.query.return_value
+ self.filter_mock = self.query_mock.filter.return_value
+
+ def test_get_vector_index_from_id(self):
+ VectordbIndices.get_vector_index_from_id(self.mock_session, 1)
+ self.mock_session.query.assert_called_with(VectordbIndices)
+ self.filter_mock.first.assert_called_once()
+
+ def test_get_vector_indices_from_vectordb(self):
+ VectordbIndices.get_vector_indices_from_vectordb(self.mock_session, 1)
+ self.mock_session.query.assert_called_with(VectordbIndices)
+ self.filter_mock.all.assert_called_once()
+
+ def test_delete_vector_db_index(self):
+ VectordbIndices.delete_vector_db_index(self.mock_session, 1)
+ self.mock_session.query.assert_called_with(VectordbIndices)
+ self.filter_mock.delete.assert_called_once()
+ self.mock_session.commit.assert_called_once()
+
+ def test_add_vector_index(self):
+ VectordbIndices.add_vector_index(self.mock_session, 'test', 1, 100, 'active')
+ self.mock_session.add.assert_called_once()
+ self.mock_session.commit.assert_called_once()
+
+ def test_update_vector_index_state(self):
+ VectordbIndices.update_vector_index_state(self.mock_session, 1, 'inactive')
+ self.mock_session.query.assert_called_with(VectordbIndices)
+ self.filter_mock.first.assert_called_once()
+ self.mock_session.commit.assert_called_once()
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/models/test_vector_dbs.py b/tests/unit_tests/models/test_vector_dbs.py
new file mode 100644
index 000000000..ff8a6b0d0
--- /dev/null
+++ b/tests/unit_tests/models/test_vector_dbs.py
@@ -0,0 +1,46 @@
+import unittest
+from unittest.mock import Mock, patch
+from superagi.models.vector_dbs import Vectordbs
+
+class TestVectordbs(unittest.TestCase):
+ def setUp(self):
+ # Create a mock sql session
+ self.mock_session = Mock()
+ # Create an object of Vectordbs for testing
+ self.test_vector_db = Vectordbs(name='test_db', db_type='test_db_type', organisation_id=1)
+
+ @patch('requests.get')
+ def test_fetch_marketplace_list(self, mock_get):
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = [{'name': 'test_db'}]
+ mock_get.return_value = mock_response
+
+ # Assert that fetch_marketplace_list() returns the correct value
+ self.assertListEqual(Vectordbs.fetch_marketplace_list(), [{'name': 'test_db'}])
+
+ def test_get_vector_db_from_id(self):
+ self.mock_session.query.return_value.filter.return_value.first.return_value = self.test_vector_db
+ returned_db = Vectordbs.get_vector_db_from_id(self.mock_session, 1)
+ # Assert that the returned db is the same as the set up test_vector_db
+ self.assertEqual(returned_db, self.test_vector_db)
+
+ def test_get_vector_db_from_organisation(self):
+ self.mock_session.query.return_value.filter.return_value.all.return_value = [self.test_vector_db]
+ returned_db_list = Vectordbs.get_vector_db_from_organisation(self.mock_session, Mock(id=1))
+ # Assert that returned list of dbs contains the test_vector_db
+ self.assertIn(self.test_vector_db, returned_db_list)
+
+ def test_add_vector_db(self):
+ # Assert that new db name matches the created db
+ new_db = Vectordbs.add_vector_db(self.mock_session, 'test_db', 'test_db_type', Mock(id=1))
+ self.assertEqual(new_db.name, 'test_db')
+
+ def test_delete_vector_db(self):
+ Vectordbs.delete_vector_db(self.mock_session, 1)
+ # Assert that the session's delete method was called with the correct arguments
+ self.mock_session.query.assert_called_once_with(Vectordbs)
+ self.mock_session.query.return_value.filter.return_value.delete.assert_called_once()
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/resource_manager/test_resource_manager.py b/tests/unit_tests/resource_manager/test_file_manager.py
similarity index 100%
rename from tests/unit_tests/resource_manager/test_resource_manager.py
rename to tests/unit_tests/resource_manager/test_file_manager.py
diff --git a/tests/unit_tests/test_tool_manager.py b/tests/unit_tests/test_tool_manager.py
index e59a874ac..064c21642 100644
--- a/tests/unit_tests/test_tool_manager.py
+++ b/tests/unit_tests/test_tool_manager.py
@@ -1,22 +1,36 @@
+import json
import os
import shutil
import tempfile
+from unittest.mock import Mock, patch
import pytest
-from unittest.mock import Mock, patch, mock_open, MagicMock
-from superagi.tool_manager import parse_github_url, download_tool, load_tools_config, download_and_extract_tools
+
+from superagi.tool_manager import parse_github_url, download_tool, load_tools_config, download_and_extract_tools, \
+ update_tools_json
+
+
+@pytest.fixture
+def tools_json_path():
+ # Create a temporary directory and return the path to the tools.json file
+ with tempfile.TemporaryDirectory() as temp_dir:
+ yield os.path.join(temp_dir, "tools.json")
+
def test_parse_github_url():
url = 'https://github.com/owner/repo'
assert parse_github_url(url) == 'owner/repo/main'
+
def setup_function():
os.makedirs('target_folder', exist_ok=True)
+
# Teardown function to remove the directory
def teardown_function():
shutil.rmtree('target_folder')
+
@patch('requests.get')
@patch('zipfile.ZipFile')
def test_download_tool(mock_zip, mock_get):
@@ -31,7 +45,6 @@ def test_download_tool(mock_zip, mock_get):
mock_zip.assert_called_once_with('target_folder/tool.zip', 'r')
-
@patch('json.load')
def test_load_tools_config(mock_json_load):
mock_json_load.return_value = {"tools": {"tool1": "url1", "tool2": "url2"}}
@@ -47,5 +60,41 @@ def test_download_and_extract_tools(mock_load_tools_config, mock_download_tool):
download_and_extract_tools()
mock_load_tools_config.assert_called_once()
- mock_download_tool.assert_any_call('url1', os.path.join('superagi', 'tools', 'tool1'))
- mock_download_tool.assert_any_call('url2', os.path.join('superagi', 'tools', 'tool2'))
+ mock_download_tool.assert_any_call('url1', os.path.join('superagi', 'tools', 'external_tools', 'tool1'))
+ mock_download_tool.assert_any_call('url2', os.path.join('superagi', 'tools', 'external_tools', 'tool2'))
+
+
+def test_update_tools_json(tools_json_path):
+ # Create an initial tools.json file with some data
+ initial_data = {
+ "tools": {
+ "tool1": "link1",
+ "tool2": "link2"
+ }
+ }
+ with open(tools_json_path, "w") as file:
+ json.dump(initial_data, file)
+
+ # Define the folder links to be updated
+ folder_links = {
+ "tool3": "link3",
+ "tool4": "link4"
+ }
+
+ # Call the function to update the tools.json file
+ update_tools_json(tools_json_path, folder_links)
+
+ # Read the updated tools.json file
+ with open(tools_json_path, "r") as file:
+ updated_data = json.load(file)
+
+ # Assert that the data was updated correctly
+ expected_data = {
+ "tools": {
+ "tool1": "link1",
+ "tool2": "link2",
+ "tool3": "link3",
+ "tool4": "link4"
+ }
+ }
+ assert updated_data == expected_data
diff --git a/tests/unit_tests/tools/code/test_improve_code.py b/tests/unit_tests/tools/code/test_improve_code.py
new file mode 100644
index 000000000..8deae28c5
--- /dev/null
+++ b/tests/unit_tests/tools/code/test_improve_code.py
@@ -0,0 +1,52 @@
+import pytest
+from unittest.mock import Mock, MagicMock
+from superagi.tools.code.improve_code import ImproveCodeTool
+
+@pytest.fixture
+def mock_improve_code_tool():
+ improve_code_tool = ImproveCodeTool()
+ improve_code_tool.resource_manager = Mock()
+ improve_code_tool.llm = Mock()
+ return improve_code_tool
+
+def test_execute(mock_improve_code_tool):
+ mock_improve_code_tool.resource_manager.get_files.return_value = ['test1', 'test2']
+ mock_improve_code_tool.resource_manager.read_file.return_value = "test file content"
+ mock_improve_code_tool.llm.chat_completion.return_value = {
+ "response":
+ {
+ "choices":
+ [
+ {
+ "message":
+ {
+ "content": "```\nimproved code\n```"
+ }
+ }
+ ]
+ }
+ }
+ mock_improve_code_tool.resource_manager.write_file.return_value = "file saved successfully"
+
+ assert mock_improve_code_tool._execute() == "All codes improved and saved successfully in: test1 test2"
+
+def test_execute_with_error(mock_improve_code_tool):
+ mock_improve_code_tool.resource_manager.get_files.return_value = ['test1']
+ mock_improve_code_tool.resource_manager.read_file.return_value = "test file content"
+ mock_improve_code_tool.llm.chat_completion.return_value = {
+ "response":
+ {
+ "choices":
+ [
+ {
+ "message":
+ {
+ "content": "```\nimproved code\n```"
+ }
+ }
+ ]
+ }
+ }
+ mock_improve_code_tool.resource_manager.write_file.return_value = "Error: Could not save file"
+
+ assert mock_improve_code_tool._execute() == "Error: Could not save file"
diff --git a/tests/unit_tests/tools/image_generation/test_dalle_image_gen.py b/tests/unit_tests/tools/image_generation/test_dalle_image_gen.py
index 54ed5fee3..de2399015 100644
--- a/tests/unit_tests/tools/image_generation/test_dalle_image_gen.py
+++ b/tests/unit_tests/tools/image_generation/test_dalle_image_gen.py
@@ -29,6 +29,6 @@ def test_execute_dalle_image_gen_tool(mock_requests, mock_dalle):
# Assert
assert result == "Images downloaded successfully"
mock_dalle.assert_called_once_with(api_key="test_api_key", number_of_results=num)
- mock_dalle_instance.generate_image.assert_called_once_with(prompt, size, num)
+ mock_dalle_instance.generate_image.assert_called_once_with(prompt, size)
tool.resource_manager.write_binary_file.assert_any_call("image1.png", b"test_image_data")
tool.resource_manager.write_binary_file.assert_any_call("image2.png", b"test_image_data")
\ No newline at end of file
diff --git a/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py b/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py
index aa48b70b2..3f34cec6f 100644
--- a/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py
+++ b/tests/unit_tests/tools/image_generation/test_stable_diffusion_image_gen.py
@@ -27,7 +27,13 @@ def create_sample_image_base64():
def stable_diffusion_tool():
with patch('superagi.tools.image_generation.stable_diffusion_image_gen.requests.post') as post_mock, \
patch(
- 'superagi.tools.image_generation.stable_diffusion_image_gen.FileManager') as resource_manager_mock:
+ 'superagi.tools.image_generation.stable_diffusion_image_gen.FileManager') as resource_manager_mock, \
+ patch(
+ 'superagi.tools.image_generation.stable_diffusion_image_gen.ResourceHelper') as resource_helper_mock, \
+ patch(
+ 'superagi.tools.image_generation.stable_diffusion_image_gen.Agent') as agent_mock, \
+ patch(
+ 'superagi.tools.image_generation.stable_diffusion_image_gen.AgentExecution') as agent_execution_mock:
# Create a mock response object
response_mock = Mock()
@@ -39,17 +45,23 @@ def stable_diffusion_tool():
resource_manager_mock.write_binary_file.return_value = None
+ # Mock Agent and AgentExecution to return dummy values
+ agent_mock.get_agent_from_id.return_value = Mock()
+ agent_execution_mock.get_agent_execution_from_id.return_value = Mock()
+
yield
+
def test_execute(stable_diffusion_tool):
tool = StableDiffusionImageGenTool()
tool.resource_manager = Mock()
- tool.toolkit_config.get_tool_config = mock_get_tool_config
-
-
- result = tool._execute('prompt', ['img1.png', 'img2.png'])
-
- assert result == 'Images downloaded and saved successfully'
+ tool.agent_id = 123 # Use a dummy agent_id for testing purposes
+ tool.toolkit_config.get_tool_config = lambda key: 'fake_api_key' if key == 'STABILITY_API_KEY' else 'engine_id_1'
+ prompt = 'Test prompt'
+ image_names = ['img1.png', 'img2.png']
+ expected_result = 'Images downloaded and saved successfully'
+ result = tool._execute(prompt, image_names)
+ assert result.startswith(expected_result)
tool.resource_manager.write_binary_file.assert_called()
def test_call_stable_diffusion(stable_diffusion_tool):
diff --git a/tests/unit_tests/tools/instagram_tool/__init__.py b/tests/unit_tests/tools/instagram_tool/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit_tests/tools/instagram_tool/test_instagram_tool.py b/tests/unit_tests/tools/instagram_tool/test_instagram_tool.py
new file mode 100644
index 000000000..cfeeacfc9
--- /dev/null
+++ b/tests/unit_tests/tools/instagram_tool/test_instagram_tool.py
@@ -0,0 +1,97 @@
+import pytest
+from superagi.tools.instagram_tool.instagram import InstagramTool
+from unittest.mock import MagicMock, patch
+
+# Create a fixture for the InstagramTool instance
+@pytest.fixture
+def instagram_tool():
+ return InstagramTool()
+
+def test_execute_missing_meta_user_access_token(instagram_tool):
+ # Test for the case when META_USER_ACCESS_TOKEN is missing
+
+ # Mock the get_tool_config method to return None for META_USER_ACCESS_TOKEN
+ instagram_tool.toolkit_config.get_tool_config = MagicMock(return_value=None)
+
+ # Call the _execute method
+ result = instagram_tool._execute("A beautiful sunset")
+
+ # Verify the output
+ assert result == "Error: Missing meta user access token."
+
+def test_execute_missing_facebook_page_id(instagram_tool):
+ # Test for the case when FACEBOOK_PAGE_ID is missing
+
+ # Mock the get_tool_config method to return None for FACEBOOK_PAGE_ID
+ instagram_tool.toolkit_config.get_tool_config = MagicMock(side_effect=lambda key: "your_meta_user_access_token" if key == "META_USER_ACCESS_TOKEN" else None)
+
+ # Call the _execute method
+ result = instagram_tool._execute("A beautiful sunset")
+
+ # Verify the output
+ assert result == "Error: Missing facebook page id."
+
+def test_get_file_path_from_image_generation_tool(instagram_tool):
+ # Test for the get_file_path_from_image_generation_tool method
+ # Mock the tool_response_manager to return a response
+ instagram_tool.tool_response_manager = MagicMock()
+ instagram_tool.tool_response_manager.get_last_response.return_value = "['/path/to/image.jpg']"
+ file_path = instagram_tool.get_file_path_from_image_generation_tool()
+ assert file_path == "resources/path/to/image.jpg"
+
+
+def test_get_img_public_url(instagram_tool):
+ # Test for the get_img_public_url method
+ # Mock the S3 client and its put_object method
+ s3_client_mock = MagicMock()
+ s3_client_mock.get_object.return_value = {"Body": MagicMock(read=lambda: b"image_content")}
+ with patch.object(InstagramTool, 'create_s3_client', return_value=s3_client_mock):
+ file_path = "path/to/image.jpg"
+ content = b"image_content"
+ image_url = instagram_tool.get_img_public_url(s3_client_mock, file_path, content)
+ assert image_url.startswith("https://")
+ assert file_path.split("/")[-1] in image_url
+
+def test_get_req_insta_id(instagram_tool):
+ # Test for the get_req_insta_id method
+ # Mock the requests.get method
+ response_mock = MagicMock()
+ response_mock.status_code = 200
+ response_mock.json.return_value = {"instagram_business_account": {"id": "account_id"}}
+ with patch("requests.get", return_value=response_mock):
+ root_api_url = "https://graph.facebook.com/v17.0/"
+ facebook_page_id = "page_id"
+ meta_user_access_token = "access_token"
+ response = instagram_tool.get_req_insta_id(root_api_url, facebook_page_id, meta_user_access_token)
+ assert response.status_code == 200
+ assert response.json()["instagram_business_account"]["id"] == "account_id"
+
+def test_post_media_container_id(instagram_tool):
+ # Test for the post_media_container_id method
+ # Mock the requests.post method
+ response_mock = MagicMock()
+ response_mock.status_code = 200
+ response_mock.json.return_value = {"id": "container_id"}
+ with patch("requests.post", return_value=response_mock):
+ root_api_url = "https://graph.facebook.com/v17.0/"
+ insta_business_account_id = "account_id"
+ image_url = "https://example.com/image.jpg"
+ encoded_caption = "encoded_caption"
+ meta_user_access_token = "access_token"
+ response = instagram_tool.post_media_container_id(root_api_url, insta_business_account_id, image_url, encoded_caption, meta_user_access_token)
+ assert response.status_code == 200
+ assert response.json()["id"] == "container_id"
+
+def test_post_media(instagram_tool):
+ # Test for the post_media method
+ # Mock the requests.post method
+ response_mock = MagicMock()
+ response_mock.status_code = 200
+ with patch("requests.post", return_value=response_mock):
+ root_api_url = "https://graph.facebook.com/v17.0/"
+ insta_business_account_id = "account_id"
+ container_ID = "container_id"
+ meta_user_access_token = "access_token"
+ response = instagram_tool.post_media(root_api_url, insta_business_account_id, container_ID, meta_user_access_token)
+ assert response.status_code == 200
+
diff --git a/tests/unit_tests/tools/instagram_tool/test_instagram_toolkit.py b/tests/unit_tests/tools/instagram_tool/test_instagram_toolkit.py
new file mode 100644
index 000000000..18f73f32c
--- /dev/null
+++ b/tests/unit_tests/tools/instagram_tool/test_instagram_toolkit.py
@@ -0,0 +1,40 @@
+import pytest
+from superagi.tools.instagram_tool.instagram import InstagramTool
+from superagi.tools.instagram_tool.instagram_toolkit import InstagramToolkit
+
+class TestInstagramToolKit:
+ def setup_method(self):
+ """
+ Set up the test fixture.
+
+ This method is called before each test method is executed to prepare the test environment.
+
+ Returns:
+ None
+ """
+ self.toolkit = InstagramToolkit()
+
+ def test_get_tools(self):
+ """
+ Test the `get_tools` method of the `InstagramToolkit` class.
+
+ It should return a list of tools, containing one instance of `InstagramTool`.
+
+ Returns:
+ None
+ """
+ tools = self.toolkit.get_tools()
+ assert len(tools) == 1
+ assert isinstance(tools[0], InstagramTool)
+
+ def test_get_env_keys(self):
+ """
+ Test the `get_env_keys` method of the `InstagramToolkit` class.
+
+ It should return the two environment keys required by the toolkit.
+
+ Returns:
+ None
+ """
+ env_keys = self.toolkit.get_env_keys()
+ assert len(env_keys) == 2
diff --git a/tests/unit_tests/tools/knowledge_tool/test_knowledge_search.py b/tests/unit_tests/tools/knowledge_tool/test_knowledge_search.py
new file mode 100644
index 000000000..ec92248ce
--- /dev/null
+++ b/tests/unit_tests/tools/knowledge_tool/test_knowledge_search.py
@@ -0,0 +1,31 @@
+import unittest
+from unittest.mock import Mock, patch
+from superagi.tools.knowledge_search.knowledge_search import KnowledgeSearchTool
+from pydantic.main import BaseModel
+
+class TestKnowledgeSearchTool(unittest.TestCase):
+ def setUp(self):
+ self.tool = KnowledgeSearchTool()
+ self.tool.toolkit_config = Mock(session=Mock())
+ self.tool.agent_id = 1
+
+ @patch('superagi.models.knowledges.Knowledges.get_knowledge_from_id')
+ @patch('superagi.models.agent_config.AgentConfiguration')
+ @patch('superagi.models.toolkit.Toolkit')
+ @patch('superagi.models.vector_db_indices.VectordbIndices.get_vector_index_from_id')
+ @patch('superagi.models.vector_dbs.Vectordbs.get_vector_db_from_id')
+ @patch('superagi.models.vector_db_configs.VectordbConfigs.get_vector_db_config_from_db_id')
+ @patch('superagi.models.configuration.Configuration.fetch_configuration')
+ @patch('superagi.jobs.agent_executor.AgentExecutor.get_embedding')
+ def test_execute(self, mock_get_embedding, mock_fetch_configuration, mock_get_vector_db_config_from_db_id, mock_get_vector_db_from_id, mock_get_vector_index_from_id, mock_Toolkit, mock_AgentConfiguration, mock_get_knowledge_from_id):
+ mock_get_embedding.return_value = None
+ mock_AgentConfiguration.filter.first.return_value = Mock(value=None)
+ mock_get_knowledge_from_id.return_value = None
+ result = self.tool._execute(query="test")
+ self.assertEqual(result, "Selected Knowledge not found")
+
+ # other test methods
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/vector_embeddings/test_vector_embedding_factory.py b/tests/unit_tests/vector_embeddings/test_vector_embedding_factory.py
new file mode 100644
index 000000000..de447c9c6
--- /dev/null
+++ b/tests/unit_tests/vector_embeddings/test_vector_embedding_factory.py
@@ -0,0 +1,28 @@
+import unittest
+from unittest.mock import patch
+from superagi.vector_embeddings.vector_embedding_factory import VectorEmbeddingFactory
+
+class TestVectorEmbeddingFactory(unittest.TestCase):
+
+ @patch("superagi.vector_embeddings.pinecone.Pinecone.__init__", return_value=None)
+ @patch("superagi.vector_embeddings.qdrant.Qdrant.__init__", return_value=None)
+ def test_build_vector_storage(self, mock_qdrant, mock_pinecone):
+ test_data = {
+ "1": {"id": 1, "embeds": [1,2,3], "text": "test", "chunk": "chunk", "knowledge_name": "knowledge"},
+ "2": {"id": 2, "embeds": [4,5,6], "text": "test2", "chunk": "chunk2", "knowledge_name": "knowledge2"},
+ }
+
+ vector_storage = VectorEmbeddingFactory.build_vector_storage('Pinecone', test_data)
+
+ mock_pinecone.assert_called_once_with(
+ [1,2], [[1,2,3],[4,5,6]], [{"text": "test", "chunk": "chunk", "knowledge_name": "knowledge"}, {"text": "test2", "chunk": "chunk2", "knowledge_name": "knowledge2"}]
+ )
+
+ vector_storage = VectorEmbeddingFactory.build_vector_storage('Qdrant', test_data)
+
+ mock_qdrant.assert_called_once_with(
+ [1,2], [[1,2,3],[4,5,6]], [{"text": "test", "chunk": "chunk", "knowledge_name": "knowledge"}, {"text": "test2", "chunk": "chunk2", "knowledge_name": "knowledge2"}]
+ )
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
diff --git a/workspace/input/testing.txt b/workspace/input/testing.txt
deleted file mode 100644
index fe857725e..000000000
--- a/workspace/input/testing.txt
+++ /dev/null
@@ -1 +0,0 @@
-"Hello world"