Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_spmodel_class to replace getattr for Specify models #5351

Open
wants to merge 2 commits into
base: production
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions specifyweb/specify/tree_ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from django.db.models import Count

from specifyweb.businessrules.exceptions import TreeBusinessRuleException
from specifyweb.specify.utils import get_spmodel_class
from . import tree_extras
from . import models as spmodels
from sys import maxsize

import logging
Expand Down Expand Up @@ -117,7 +117,7 @@ def get_tree_item_model(tree_rank_model_name):
tree_item_model_name = TREE_RANK_TO_ITEM_MAP.get(tree_rank_model_name.title(), None)
if not tree_item_model_name:
return None
return getattr(spmodels, tree_item_model_name, None)
return get_spmodel_class(tree_item_model_name)

def tree_rank_count(tree_rank_model_name, tree_rank_id) -> int:
tree_item_model = get_tree_item_model(tree_rank_model_name)
Expand Down Expand Up @@ -186,14 +186,14 @@ def set_rank_id(new_rank):

# Get tree def item model
tree_def_item_model_name = (tree + 'treedefitem').lower().title()
tree_def_item_model = getattr(spmodels, tree_def_item_model_name)
tree_def_item_model = get_spmodel_class(tree_def_item_model_name)

# Handle case where the parent rank is not given, and it is not the first rank added.
# This is happening in the UI workflow of Treeview->Treedef->Treedefitems->Add
if (
new_rank.parent is None
and new_rank.rankid is None
and getattr(spmodels, new_rank.specify_model.django_name).objects.filter(treedef=tree_def).count() > 1
and get_spmodel_class(new_rank.specify_model.django_name).objects.filter(treedef=tree_def).count() > 1
):
new_rank.parent = tree_def_item_model.objects.filter(treedef=tree_def).order_by("rankid").last()
parent_rank_name = new_rank.parent.name
Expand Down Expand Up @@ -308,7 +308,7 @@ def verify_rank_parent_chain_integrity(rank, rank_operation: RankOperation):
"""
tree_def = rank.treedef
tree_def_item_model_name = rank.specify_model.name.lower().title()
tree_def_item_model = getattr(spmodels, tree_def_item_model_name)
tree_def_item_model = get_spmodel_class(tree_def_item_model_name)

# Get all the ranks and their parent ranks
rank_id_to_parent_dict = {item.id: item.parent.id if item.parent is not None else None
Expand Down
9 changes: 5 additions & 4 deletions specifyweb/specify/tree_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Tuple, List
from django.db.models import Q, Count, Model
import specifyweb.specify.models as spmodels
from specifyweb.specify.models import Collection
from specifyweb.specify.datamodel import datamodel
from specifyweb.specify.utils import get_spmodel_class

lookup = lambda tree: (tree.lower() + 'treedef')

SPECIFY_TREES = {"taxon", "storage", "geography", "geologictimeperiod", "lithostrat", 'tectonicunit'}

def get_search_filters(collection: spmodels.Collection, tree: str):
def get_search_filters(collection: Collection, tree: str):
tree_name = tree.lower()
if tree_name == 'storage':
return Q(institution=collection.discipline.division.institution)
Expand All @@ -22,7 +23,7 @@ def get_search_filters(collection: spmodels.Collection, tree: str):
discipline_query |= Q(id=tree_at_discipline.id)
return discipline_query

def get_treedefs(collection: spmodels.Collection, tree_name: str) -> List[Tuple[int, int]]:
def get_treedefs(collection: Collection, tree_name: str) -> List[Tuple[int, int]]:
# Get the appropriate TreeDef based on the Collection and tree_name

# Mimic the old behavior of limiting the query to the first item for trees other than taxon.
Expand All @@ -32,7 +33,7 @@ def get_treedefs(collection: spmodels.Collection, tree_name: str) -> List[Tuple

lookup_tree = lookup(tree_name)
tree_table = datamodel.get_table_strict(lookup_tree)
tree_model: Model = getattr(spmodels, tree_table.django_name)
tree_model: Model = get_spmodel_class(tree_table.django_name)

# Get all the treedefids, and the count of item in each, corresponding to our search predicates
search_query = _limit(
Expand Down
23 changes: 23 additions & 0 deletions specifyweb/specify/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import logging
from specifyweb.accounts import models as acccounts_models
from specifyweb.attachment_gw import models as attachment_gw_models
from specifyweb.businessrules import models as businessrules_models
Expand All @@ -10,6 +11,8 @@
from specifyweb.specify import models as spmodels
from django.conf import settings

logger = logging.getLogger(__name__)

APP_MODELS = [spmodels, acccounts_models, attachment_gw_models, businessrules_models, context_models,
notifications_models, permissions_models, interactions_models, workbench_models]

Expand All @@ -18,3 +21,23 @@ def get_app_model(model_name: str):
if hasattr(app, model_name):
return getattr(app, model_name)
return None

def get_spmodel_class(model_name: str):
try:
return getattr(spmodels, model_name.capitalize())
except AttributeError:
pass
# Iterate over all attributes in the models module
for attr_name in dir(spmodels):
# Check if the attribute name matches the model name case-insensitively
if attr_name.lower() == model_name.lower():
return getattr(spmodels, attr_name)
raise AttributeError(f"Model '{model_name}' not found in models module.")

def log_sqlalchemy_query(query):
from sqlalchemy.dialects import mysql
compiled_query = query.statement.compile(dialect=mysql.dialect(), compile_kwargs={"literal_binds": True})
raw_sql = str(compiled_query)
logger.debug(raw_sql)
# Run in the storred_queries.execute file, in the execute function, right before the return statement, line 546
# from specifyweb.specify.utils import log_sqlalchemy_query; log_sqlalchemy_query(query)
24 changes: 13 additions & 11 deletions specifyweb/specify/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from specifyweb.celery_tasks import app, CELERY_TASK_STATE
from specifyweb.specify.record_merging import record_merge_fx, record_merge_task, resolve_record_merge_response
from specifyweb.specify.update_locality import localityupdate_parse_success, localityupdate_parse_error, parse_locality_set as _parse_locality_set, upload_locality_set as _upload_locality_set, create_localityupdate_recordset, update_locality_task, parse_locality_task, LocalityUpdateStatus
from . import api, models as spmodels
from specifyweb.specify.utils import get_spmodel_class
from specifyweb.specify import api
from specifyweb.specify.models import Agent, Collection, Division, Specifyuser
from .specify_jar import specify_jar


Expand Down Expand Up @@ -190,7 +192,7 @@ def set_password(request, userid):
"""
check_permission_targets(None, request.specify_user.id, [
SetPasswordPT.update])
user = spmodels.Specifyuser.objects.get(pk=userid)
user = Specifyuser.objects.get(pk=userid)
user.set_password(request.POST['password'])
user.save()
return http.HttpResponse('', status=204)
Expand Down Expand Up @@ -304,23 +306,23 @@ class SetUserAgentsPT(PermissionTarget):
@require_POST
def set_user_agents(request, userid: int):
"Sets the agents to represent the user in different disciplines."
user = spmodels.Specifyuser.objects.get(pk=userid)
user = Specifyuser.objects.get(pk=userid)
new_agentids = json.loads(request.body)
cursor = connection.cursor()

with transaction.atomic():
# clear user's existing agents
spmodels.Agent.objects.filter(
Agent.objects.filter(
specifyuser_id=userid).update(specifyuser_id=None)

# check if any of the agents to be assigned are used by other users
in_use = spmodels.Agent.objects.select_for_update().filter(
in_use = Agent.objects.select_for_update().filter(
pk__in=new_agentids, specifyuser_id__isnull=False)
if in_use:
raise AgentInUseException([a.id for a in in_use])

# assign the new agents
spmodels.Agent.objects.filter(
Agent.objects.filter(
pk__in=new_agentids).update(specifyuser_id=userid)

# check for multiple agents assigned to the user
Expand All @@ -339,7 +341,7 @@ def set_user_agents(request, userid: int):
raise MultipleAgentsException(multiple)

# get the list of collections the agents belong to.
collections = spmodels.Collection.objects.filter(
collections = Collection.objects.filter(
discipline__division__members__specifyuser_id=userid).values_list('id', flat=True)

# check permissions for setting user agents in those collections.
Expand All @@ -356,7 +358,7 @@ def check_collection_access_against_agents(userid: int) -> None:
from specifyweb.context.views import users_collections_for_sp6, users_collections_for_sp7

# get the list of collections the agents belong to.
collections = spmodels.Collection.objects.filter(
collections = Collection.objects.filter(
discipline__division__members__specifyuser_id=userid).values_list('id', flat=True)

# make sure every collection the user is permitted to access has an assigned user.
Expand All @@ -373,7 +375,7 @@ def check_collection_access_against_agents(userid: int) -> None:
if collection.id not in collections
]
if missing_for_6 or missing_for_7:
all_divisions = spmodels.Division.objects.filter(
all_divisions = Division.objects.filter(
disciplines__collections__id__in=[
cid for cid, _ in sp6_collections] + [c.id for c in sp7_collections]
).values_list('id', flat=True).distinct()
Expand Down Expand Up @@ -426,7 +428,7 @@ def set_admin_status(request, userid):
"""
check_permission_targets(
None, request.specify_user.id, [Sp6AdminPT.update])
user = spmodels.Specifyuser.objects.get(pk=userid)
user = Specifyuser.objects.get(pk=userid)
if request.POST['admin_status'] == 'true':
user.set_admin()
return http.HttpResponse('true', content_type='text/plain')
Expand Down Expand Up @@ -499,7 +501,7 @@ def record_merge(
"""Replaces all the foreign keys referencing the old record IDs
with the new record ID, and deletes the old records.
"""
record_version = getattr(spmodels, model_name.title()).objects.get(
record_version = get_spmodel_class(model_name.title()).objects.get(
id=new_model_id).version
get_version = request.GET.get('version', record_version)
version = get_version if isinstance(get_version, int) else 0
Expand Down
9 changes: 5 additions & 4 deletions specifyweb/stored_queries/query_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from sqlalchemy import orm, sql

import specifyweb.specify.models as spmodels
from specifyweb.specify.datamodel import datamodel
from specifyweb.specify.tree_utils import get_treedefs

from specifyweb.specify.utils import get_spmodel_class
from specifyweb.stored_queries import models

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,7 +61,7 @@ def handle_tree_field(self, node, table, tree_rank, tree_field):
query = query._replace(join_cache=query.join_cache.copy())
query.join_cache[(table, 'TreeRanks')] = (ancestors, treedefs)

item_model = getattr(spmodels, table.django_name + "treedefitem")
item_model = get_spmodel_class(table.django_name + "treedefitem")

# TODO: optimize out the ranks that appear? cache them
treedefs_with_ranks: List[Tuple[int, int]] = [tup for tup in [
Expand Down Expand Up @@ -105,7 +106,7 @@ def tables_in_path(self, table, join_path):
if not field.is_relationship:
break

tables.append(spmodels.datamodel.get_table(field.relatedModelName, strict=True))
tables.append(datamodel.get_table(field.relatedModelName, strict=True))
return tables

def build_join(self, table, model, join_path):
Expand All @@ -119,7 +120,7 @@ def build_join(self, table, model, join_path):

if not field.is_relationship:
break
next_table = spmodels.datamodel.get_table(field.relatedModelName, strict=True)
next_table = datamodel.get_table(field.relatedModelName, strict=True)
logger.debug("joining: %r to %r via %r", table, next_table, field)
if (model, field.name) in query.join_cache:
aliased = query.join_cache[(model, field.name)]
Expand Down
Loading