diff --git a/servers/cogvlm/cogvlm.py b/servers/cogvlm/cogvlm.py index c10c2b3..b1efded 100644 --- a/servers/cogvlm/cogvlm.py +++ b/servers/cogvlm/cogvlm.py @@ -38,10 +38,7 @@ from swarms_cloud.calculate_pricing import calculate_pricing, count_tokens from swarms_cloud.auth_with_swarms_cloud import fetch_api_key_info from swarms_cloud.log_api_request_to_supabase import log_to_supabase, ModelAPILogEntry - -# from exa import calculate_workers -# import torch.distributed as dist - +from exa.structs.parallelize_models_gpus import prepare_model_for_ddp_inference # Load environment variables from .env file load_dotenv() @@ -92,8 +89,9 @@ torch_dtype=torch_type, low_cpu_mem_usage=True, quantization_config=bnb_config, -).eval() +)#.eval() +model = prepare_model_for_ddp_inference(model) # Torch type if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: