Skip to content

Commit

Permalink
[feat]: Adapt GCPCluster to be able to initialize the cluster by pass…
Browse files Browse the repository at this point in the history
…ing credentials without obtaining from the environment
  • Loading branch information
rubenbblazquez committed Aug 28, 2024
1 parent 5babf64 commit 64dfc3c
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions dask_cloudprovider/gcp/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json

import sqlite3
from typing import Optional, Any

import dask
from dask.utils import tmpfile
Expand Down Expand Up @@ -106,7 +107,6 @@ def __init__(
self.instance_labels = _instance_labels

self.general_zone = "-".join(self.zone.split("-")[:2]) # us-east1-c -> us-east1

self.service_account = service_account or self.config.get("service_account")

def create_gcp_config(self):
Expand Down Expand Up @@ -494,6 +494,8 @@ class GCPCluster(VMCluster):
service_account: str
Service account that all VMs will run under.
Defaults to the default Compute Engine service account for your GCP project.
service_account_credentials: Optional[dict[str, Any]]
Service account credentials to create the compute engine Vms
Examples
--------
Expand Down Expand Up @@ -587,9 +589,10 @@ def __init__(
debug=False,
instance_labels=None,
service_account=None,
service_account_credentials: Optional[dict[str, Any]]=None,
**kwargs,
):
self.compute = GCPCompute()
self.compute = GCPCompute(service_account_credentials)

self.config = dask.config.get("cloudprovider.gcp", {})
self.auto_shutdown = (
Expand Down Expand Up @@ -641,9 +644,17 @@ def __init__(


class GCPCompute:
"""Wrapper for the ``googleapiclient`` compute object."""
"""
Wrapper for the ``googleapiclient`` compute object.
Attributes
----------
service_account_credentials: Optional[dict]
Service account credentials to create the compute engine Vms
"""

def __init__(self):
def __init__(self, service_account_credentials: Optional[dict[str, Any]] = None):
self.service_account_credentials = service_account_credentials or {}
self._compute = self.refresh_client()

def refresh_client(self):
Expand All @@ -654,6 +665,13 @@ def refresh_client(self):
os.environ["GOOGLE_APPLICATION_CREDENTIALS"],
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
elif self.service_account_credentials:
import google.oauth2.service_account # google-auth

creds = google.oauth2.service_account.Credentials.from_service_account_info(
self.service_account_credentials,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
import google.auth.credentials # google-auth

Expand Down

0 comments on commit 64dfc3c

Please sign in to comment.