Skip to content

Commit

Permalink
[CLEANUP]
Browse files: browse the repository at this point in the history
  • Loading branch information
Kye committed Mar 20, 2024
1 parent 74eb7c4 commit fa5bec5
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 24 deletions.
34 changes: 13 additions & 21 deletions servers/cogvlm/cogvlm.py
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
from dotenv import load_dotenv
import base64
import os
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -39,6 +40,10 @@
)
from swarms_cloud.utils.count_cores_for_workers import calculate_workers

# Load environment variables from .env file
load_dotenv()


# Environment variables
MODEL_PATH = os.environ.get("COGVLM_MODEL_PATH", "THUDM/cogvlm-chat-hf")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", "lmsys/vicuna-7b-v1.5")
Expand Down Expand Up @@ -96,11 +101,9 @@ async def load_model():

model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
# load_in_4bit=True,
trust_remote_code=True,
torch_dtype=torch_type,
low_cpu_mem_usage=True,
# attn_implementation="flash_attention_2",
quantization_config=bnb_config,
).eval()

Expand Down Expand Up @@ -138,19 +141,10 @@ async def list_models():
async def create_chat_completion(
request: ChatCompletionRequest, token: str = Depends(authenticate_user)
):
# global model, tokenizer

if len(request.messages) < 1 or request.messages[-1].role == "assistant":
raise HTTPException(status_code=400, detail="Invalid request")

# Calculate pricing
out = calculate_pricing(
texts=[
message.content for message in request.messages if message.role == "user"
],
tokenizer=tokenizer,
rate_per_million=15.0,
)

print(f"Request: {request}")

gen_params = dict(
Expand Down Expand Up @@ -179,17 +173,17 @@ async def create_chat_completion(

# Log the entry to supabase
entry = ModelAPILogEntry(
user_id=fetch_api_key_info(token),
user_id= await fetch_api_key_info(token),
model_id="41a2869c-5f8d-403f-83bb-1f06c56bad47",
input_tokens=count_tokens(request.messsages, tokenizer, request.model),
output_tokens=count_tokens(response["text"], tokenizer, request.model),
all_cost=calculate_pricing(
input_tokens= await count_tokens(request.messsages, tokenizer, request.model),
output_tokens= await count_tokens(response["text"], tokenizer, request.model),
all_cost= await calculate_pricing(
texts=[message.content], tokenizer=tokenizer, rate_per_million=15.0
),
input_cost=calculate_pricing(
input_cost = await calculate_pricing(
texts=[message.content], tokenizer=tokenizer, rate_per_million=15.0
),
output_cost=calculate_pricing(
output_cost= await calculate_pricing(
texts=response["text"], tokenizer=tokenizer, rate_per_million=15.0
)
* 5,
Expand All @@ -203,7 +197,7 @@ async def create_chat_completion(
)

# Log the entry to supabase
log_to_supabase(entry=entry)
await log_to_supabase(entry=entry)

# ChatCompletionResponseChoice
logger.debug(f"==== message ====\n{message}")
Expand Down Expand Up @@ -233,8 +227,6 @@ async def predict(model_id: str, params: dict):
This is particularly useful for real-time, continuous interactions with the model.
"""

global model, tokenizer

choice_data = ChatCompletionResponseStreamChoice(
index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
)
Expand Down
4 changes: 2 additions & 2 deletions swarms_cloud/auth_with_swarms_cloud.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -72,7 +72,7 @@ async def fetch_api_key_info(token: str, supabase: Client = supabase_client_init
return None


async def authenticate_user(
def authenticate_user(
credentials: HTTPAuthorizationCredentials = Depends(http_bearer),
):
"""
Expand All @@ -91,7 +91,7 @@ async def authenticate_user(
if not is_token_valid(token, supabase_client_init):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
detail="Invalid token. Please authenticate with a valid token at https://swarms.world/dashboard",
headers={"WWW-Authenticate": "Bearer"},
)
return token
2 changes: 1 addition & 1 deletion swarms_cloud/utils/count_cores_for_workers.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
import multiprocessing


def calculate_workers():
def calculate_workers() -> int:
"""
Calculates the number of workers based on the number of CPU cores.
Expand Down

0 comments on commit fa5bec5

Please sign in to comment.