diff --git a/vizier/__init__.py b/vizier/__init__.py index f190dd32b..6ceaf4db6 100644 --- a/vizier/__init__.py +++ b/vizier/__init__.py @@ -23,4 +23,4 @@ sys.path.append(PROTO_ROOT) -__version__ = "0.1.19" +__version__ = "0.1.20" diff --git a/vizier/_src/service/sql_datastore.py b/vizier/_src/service/sql_datastore.py index cac07b9fd..094282d59 100644 --- a/vizier/_src/service/sql_datastore.py +++ b/vizier/_src/service/sql_datastore.py @@ -40,7 +40,7 @@ class SQLDataStore(datastore.DataStore): """SQL Datastore.""" - def __init__(self, engine): + def __init__(self, engine: sqla.engine.Engine): self._engine = engine self._connection = self._engine.connect() self._root_metadata = sqla.MetaData() @@ -104,12 +104,16 @@ def create_study(self, study: study_pb2.Study) -> resources.StudyResource: with self._lock: try: self._connection.execute(owner_query) + self._connection.commit() except sqla.exc.IntegrityError: logging.info('Owner with name %s currently exists.', owner_name) + self._connection.rollback() try: self._connection.execute(study_query) + self._connection.commit() return study_resource except sqla.exc.IntegrityError as e: + self._connection.rollback() raise AlreadyExistsError( 'Study with name %s already exists.' % study.name ) from e @@ -148,6 +152,7 @@ def update_study(self, study: study_pb2.Study) -> resources.StudyResource: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Study %s does not exist.' % study.name) self._connection.execute(uq) + self._connection.commit() return study_resource def delete_study(self, study_name: str) -> None: @@ -172,6 +177,7 @@ def delete_study(self, study_name: str) -> None: raise NotFoundError('Study %s does not exist.' % study_name) self._connection.execute(dsq) self._connection.execute(dtq) + self._connection.commit() def list_studies(self, owner_name: str) -> List[study_pb2.Study]: owner_id = resources.OwnerResource.from_name(owner_name).owner_id @@ -205,8 +211,10 @@ def create_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: with self._lock: try: self._connection.execute(query) + self._connection.commit() return trial_resource except sqla.exc.IntegrityError as e: + self._connection.rollback() raise AlreadyExistsError( 'Trial with name %s already exists.' % trial.name ) from e @@ -246,6 +254,7 @@ def update_trial(self, trial: study_pb2.Trial) -> resources.TrialResource: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Trial %s does not exist.' % trial.name) self._connection.execute(uq) + self._connection.commit() return trial_resource @@ -283,6 +292,7 @@ def delete_trial(self, trial_name: str) -> None: if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Trial %s does not exist.' % trial_name) self._connection.execute(dq) + self._connection.commit() def max_trial_id(self, study_name: str) -> int: study_resource = resources.StudyResource.from_name(study_name) @@ -323,8 +333,10 @@ def create_suggestion_operation( try: with self._lock: self._connection.execute(query) + self._connection.commit() return resource except sqla.exc.IntegrityError as e: + self._connection.rollback() raise AlreadyExistsError( 'Suggest Op with name %s already exists.' % operation.name ) from e @@ -375,6 +387,7 @@ def update_suggestion_operation( if not self._connection.execute(eq).fetchone()[0]: raise NotFoundError('Suggest op %s does not exist.' % operation.name) self._connection.execute(uq) + self._connection.commit() return resource def list_suggestion_operations( @@ -464,8 +477,10 @@ def create_early_stopping_operation( try: with self._lock: self._connection.execute(query) + self._connection.commit() return resource except sqla.exc.IntegrityError as e: + self._connection.rollback() raise AlreadyExistsError( 'Early stopping op with name %s already exists.' % operation.name ) from e @@ -521,6 +536,7 @@ def update_early_stopping_operation( 'Early stopping op %s does not exist.' % operation.name ) self._connection.execute(uq) + self._connection.commit() return resource def update_metadata( @@ -552,6 +568,7 @@ def update_metadata( usq = usq.where(self._studies_table.c.study_name == study_name) usq = usq.values(serialized_study=original_study.SerializeToString()) self._connection.execute(usq) + self._connection.commit() # Split the trial-related metadata by Trial. split_metadata = collections.defaultdict(list) @@ -578,3 +595,4 @@ def update_metadata( utq = utq.where(self._trials_table.c.trial_name == trial_name) utq = utq.values(serialized_trial=original_trial.SerializeToString()) self._connection.execute(utq) + self._connection.commit() diff --git a/vizier/_src/service/sql_datastore_test.py b/vizier/_src/service/sql_datastore_test.py index 317a8a68e..b793f38a0 100644 --- a/vizier/_src/service/sql_datastore_test.py +++ b/vizier/_src/service/sql_datastore_test.py @@ -15,9 +15,9 @@ from __future__ import annotations """Tests for sql_datastore.""" + import os import sqlalchemy as sqla - from vizier._src.service import constants from vizier._src.service import datastore_test_lib from vizier._src.service import sql_datastore @@ -46,7 +46,9 @@ def setUp(self): ) ) - engine = sqla.create_engine(constants.SQL_MEMORY_URL, echo=True) + engine = sqla.create_engine( + constants.SQL_MEMORY_URL, echo=True, future=True + ) self.datastore = sql_datastore.SQLDataStore(engine) super().setUp() diff --git a/vizier/_src/service/vizier_service.py b/vizier/_src/service/vizier_service.py index d65fa122b..92a9a710e 100644 --- a/vizier/_src/service/vizier_service.py +++ b/vizier/_src/service/vizier_service.py @@ -104,8 +104,9 @@ def __init__( else: engine = sqla.create_engine( database_url, - echo=False, # Set True to log transactions for debugging. connect_args={'check_same_thread': False}, + echo=False, # Set True to log transactions for debugging. + future=True, # Backward compatibility with sqlalchemy 1.4. poolclass=sqla.pool.StaticPool, ) self.datastore = sql_datastore.SQLDataStore(engine)