diff --git a/config.py b/config.py
new file mode 100644
index 0000000..3b03be1
--- /dev/null
+++ b/config.py
@@ -0,0 +1,18 @@
+import os
+def post_fork(server, worker):
+    server.log.info("Worker spawned (pid: %s)", worker.pid)
+    cuda_device_count = int(os.getenv("APP_CUDA_DEVICE_COUNT", -1))
+
+    if cuda_device_count > 0:
+        # set variables for cuda resource allocation
+        # Needs to be done before loading models
+        # The number of devices to use should be set via
+        # APP_CUDA_DEVICE_COUNT in env_app and the docker compose
+        # file should allocate cards to the container
+        cudaid = worker.age % cuda_device_count
+        worker.log.info("Setting cuda device " + str(cudaid))
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(cudaid)
+    else:
+        worker.log.info("APP_CUDA_DEVICE_COUNT device variables not set")
+
+
diff --git a/medcat_service/nlp_processor/medcat_processor.py b/medcat_service/nlp_processor/medcat_processor.py
index 111cee0..3ad8849 100644
--- a/medcat_service/nlp_processor/medcat_processor.py
+++ b/medcat_service/nlp_processor/medcat_processor.py
@@ -61,6 +61,7 @@ def __init__(self):
 
         self.app_model = os.getenv("APP_MODEL_NAME", "unknown")
         self.entity_output_mode = os.getenv("ANNOTATIONS_ENTITY_OUTPUT_MODE", "dict").lower()
+
         self.cat = self._create_cat()
         self.cat.train = os.getenv("APP_TRAINING_MODE", False)
 
@@ -70,11 +71,13 @@
         # this is available to constrain torch threads when there
         # isn't a GPU
         # You probably want to set to 1
+        # Not sure what happens if torch is using a cuda device
         if self.torch_threads > 0:
             import torch
             torch.set_num_threads(self.torch_threads)
             self.log.info("Torch threads set to " + str(self.torch_threads))
 
+        self.log.info("MedCAT processor is ready")
 
 
     def get_app_info(self):
diff --git a/start-service-prod.sh b/start-service-prod.sh
index f0c3468..ef8fbd9 100644
--- a/start-service-prod.sh
+++ b/start-service-prod.sh
@@ -34,5 +34,6 @@ SERVER_ACCESS_LOG_FORMAT="%(t)s [ACCESSS] %(h)s \"%(r)s\" %(s)s \"%(f)s\" \"%(a)
 #
 echo "Starting up Flask app using gunicorn server ..."
 gunicorn --bind $SERVER_HOST:$SERVER_PORT --workers=$SERVER_WORKERS --threads=$SERVER_THREADS --timeout=$SERVER_WORKER_TIMEOUT \
-    --access-logformat="$SERVER_ACCESS_LOG_FORMAT" --access-logfile=- --log-file=- --log-level info \
+    --access-logformat="$SERVER_ACCESS_LOG_FORMAT" --access-logfile=- --log-file=- --log-level info \
+    --config /cat/config.py \
     wsgi
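
Note on the device assignment in config.py: gunicorn runs the post_fork hook once per worker, and worker.age is the arbiter's spawn counter (it starts at 1 and keeps increasing as workers are respawned), so cudaid = worker.age % cuda_device_count round-robins workers over the GPUs before any model is loaded. Below is a minimal illustrative sketch of that arithmetic only; assign_devices() and the 4-worker/2-GPU numbers are hypothetical and not part of the patch.

# Illustrative only -- simulates the "worker.age % cuda_device_count"
# assignment done in config.py's post_fork hook.  assign_devices() and
# the example numbers are made up for this sketch.

def assign_devices(worker_ages, cuda_device_count):
    """Return {worker age: CUDA device id} for the given spawn order."""
    return {age: age % cuda_device_count for age in worker_ages}

if __name__ == "__main__":
    # 4 workers on a 2-GPU host (APP_CUDA_DEVICE_COUNT=2):
    # ages 1..4 map to {1: 1, 2: 0, 3: 1, 4: 0}, i.e. two workers per card.
    print(assign_devices(worker_ages=range(1, 5), cuda_device_count=2))

Because CUDA_VISIBLE_DEVICES is exported in post_fork, before the worker constructs MedCatProcessor and loads the model, each worker should only ever see its assigned card.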