Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

blackify #133

Merged
merged 1 commit into from
Nov 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions datatables/clean_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,29 @@ def clean_regex(regex):

# these characters are escaped (all except alternation | and escape \)
# see http://www.regular-expressions.info/refquick.html
escape_chars = '[^$.?*+(){}'
escape_chars = "[^$.?*+(){}"

# remove any escape chars
ret_regex = ret_regex.replace('\\', '')
ret_regex = ret_regex.replace("\\", "")

# escape any characters which are used by regex
# could probably concoct something incomprehensible using re.sub() but
# prefer to write clear code with this loop
# note expectation that no characters have already been escaped
for c in escape_chars:
ret_regex = ret_regex.replace(c, '\\' + c)
ret_regex = ret_regex.replace(c, "\\" + c)

# remove any double alternations until these don't exist any more
while True:
old_regex = ret_regex
ret_regex = ret_regex.replace('||', '|')
ret_regex = ret_regex.replace("||", "|")
if old_regex == ret_regex:
break

# if last char is alternation | remove it because this
# will cause operational error
# this can happen as user is typing in global search box
while len(ret_regex) >= 1 and ret_regex[-1] == '|':
while len(ret_regex) >= 1 and ret_regex[-1] == "|":
ret_regex = ret_regex[:-1]

# and back to the caller
Expand Down
43 changes: 23 additions & 20 deletions datatables/column_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@

from datatables.search_methods import SEARCH_METHODS

NULLS_ORDER = ['nullsfirst', 'nullslast']
NULLS_ORDER = ["nullsfirst", "nullslast"]

ColumnTuple = namedtuple('ColumnDT', [
'sqla_expr',
'column_name',
'mData',
'search_method',
'nulls_order',
'global_search',
])
ColumnTuple = namedtuple(
"ColumnDT",
[
"sqla_expr",
"column_name",
"mData",
"search_method",
"nulls_order",
"global_search",
],
)


class ColumnDT(ColumnTuple):
Expand Down Expand Up @@ -57,24 +60,24 @@ class ColumnDT(ColumnTuple):
"""

def __new__(
cls,
sqla_expr,
column_name=None,
mData=None,
search_method='string_contains',
nulls_order=None,
global_search=True,
cls,
sqla_expr,
column_name=None,
mData=None,
search_method="string_contains",
nulls_order=None,
global_search=True,
):
"""Set default values due to namedtuple immutability."""
if nulls_order and nulls_order not in NULLS_ORDER:
raise ValueError(
'{} is not an allowed value for nulls_order.'.format(
nulls_order))
"{} is not an allowed value for nulls_order.".format(nulls_order)
)

if search_method not in SEARCH_METHODS:
raise ValueError(
'{} is not an allowed value for search_method.'.format(
search_method))
"{} is not an allowed value for search_method.".format(search_method)
)

return super(ColumnDT, cls).__new__(
cls,
Expand Down
119 changes: 59 additions & 60 deletions datatables/datatables.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ class DataTables:
def __init__(self, request, query, columns, allow_regex_searches=False):
"""Initialize object and run the query."""
self.params = dict(request)
if 'sEcho' in self.params:
raise ValueError(
'Legacy datatables not supported, upgrade to >=1.10')
if "sEcho" in self.params:
raise ValueError("Legacy datatables not supported, upgrade to >=1.10")
self.query = query
self.columns = columns
self.results = None
Expand All @@ -51,40 +50,49 @@ def __init__(self, request, query, columns, allow_regex_searches=False):
def output_result(self):
"""Output results in the format needed by DataTables."""
output = {}
output['draw'] = str(int(self.params.get('draw', 1)))
output['recordsTotal'] = str(self.cardinality)
output['recordsFiltered'] = str(self.cardinality_filtered)
output["draw"] = str(int(self.params.get("draw", 1)))
output["recordsTotal"] = str(self.cardinality)
output["recordsFiltered"] = str(self.cardinality_filtered)
if self.error:
output['error'] = self.error
output["error"] = self.error
return output

output['data'] = self.results
output["data"] = self.results
for k, v in self.yadcf_params:
output[k] = v
return output

def _query_with_all_filters_except_one(self, query, exclude):
return query.filter(*[
e for i, e in enumerate(self.filter_expressions)
if e is not None and i is not exclude
])
return query.filter(
*[
e
for i, e in enumerate(self.filter_expressions)
if e is not None and i is not exclude
]
)

def _set_yadcf_data(self, query):
# determine values for yadcf filters
for i, col in enumerate(self.columns):
if col.search_method in 'yadcf_range_number_slider':
if col.search_method in "yadcf_range_number_slider":
v = query.add_columns(
func.min(col.sqla_expr), func.max(col.sqla_expr)).one()
self.yadcf_params.append(('yadcf_data_{:d}'.format(i),
(math.floor(v[0]), math.ceil(v[1]))))
func.min(col.sqla_expr), func.max(col.sqla_expr)
).one()
self.yadcf_params.append(
("yadcf_data_{:d}".format(i), (math.floor(v[0]), math.ceil(v[1])))
)
if col.search_method in [
'yadcf_select', 'yadcf_multi_select', 'yadcf_autocomplete'
"yadcf_select",
"yadcf_multi_select",
"yadcf_autocomplete",
]:
filtered = self._query_with_all_filters_except_one(
query=query, exclude=i)
query=query, exclude=i
)
v = filtered.add_columns(col.sqla_expr).distinct().all()
self.yadcf_params.append(('yadcf_data_{:d}'.format(i),
[r[0] for r in v]))
self.yadcf_params.append(
("yadcf_data_{:d}".format(i), [r[0] for r in v])
)

def run(self):
"""Launch filtering, sorting and paging to output results."""
Expand All @@ -99,38 +107,33 @@ def run(self):
self._set_yadcf_data(query)

# apply filters
query = query.filter(
*[e for e in self.filter_expressions if e is not None])
query = query.filter(*[e for e in self.filter_expressions if e is not None])

self.cardinality_filtered = query.add_columns(
self.columns[0].sqla_expr).count()
self.cardinality_filtered = query.add_columns(self.columns[0].sqla_expr).count()

# apply sorts
query = query.order_by(
*[e for e in self.sort_expressions if e is not None])
query = query.order_by(*[e for e in self.sort_expressions if e is not None])

# add paging options
length = int(self.params.get('length'))
length = int(self.params.get("length"))
if length >= 0:
query = query.limit(length)
elif length == -1:
pass
else:
raise (ValueError(
'Length should be a positive integer or -1 to disable'))
query = query.offset(int(self.params.get('start')))
raise (ValueError("Length should be a positive integer or -1 to disable"))
query = query.offset(int(self.params.get("start")))

# add columns to query
query = query.add_columns(*[c.sqla_expr for c in self.columns])

# fetch the result of the queries
column_names = [
col.mData if col.mData else str(i)
for i, col in enumerate(self.columns)
col.mData if col.mData else str(i) for i, col in enumerate(self.columns)
]
self.results = [
{k: v for k, v in zip(column_names, row)} for row in query.all()
]
self.results = [{k: v
for k, v in zip(column_names, row)}
for row in query.all()]

def _set_column_filter_expressions(self):
"""Construct the query: filtering.
Expand All @@ -140,35 +143,32 @@ def _set_column_filter_expressions(self):
# per columns filters:
for i in range(len(self.columns)):
filter_expr = None
value = self.params.get('columns[{:d}][search][value]'.format(i),
'')
value = self.params.get("columns[{:d}][search][value]".format(i), "")
if value:
search_func = SEARCH_METHODS[self.columns[i].search_method]
filter_expr = search_func(self.columns[i].sqla_expr, value)
self.filter_expressions.append(filter_expr)

def _set_global_filter_expression(self):
# global search filter
global_search = self.params.get('search[value]', '')
if global_search == '':
global_search = self.params.get("search[value]", "")
if global_search == "":
return

if (self.allow_regex_searches
and self.params.get('search[regex]') == 'true'):
if self.allow_regex_searches and self.params.get("search[regex]") == "true":
op = self._get_regex_operator()
val = clean_regex(global_search)

def filter_for(col):
return col.sqla_expr.op(op)(val)

else:
val = '%' + global_search + '%'
val = "%" + global_search + "%"

def filter_for(col):
return col.sqla_expr.cast(Text).ilike(val)

global_filter = [
filter_for(col) for col in self.columns if col.global_search
]
global_filter = [filter_for(col) for col in self.columns if col.global_search]

self.filter_expressions.append(or_(*global_filter))

Expand All @@ -179,38 +179,37 @@ def _set_sort_expressions(self):
"""
sort_expressions = []
i = 0
while self.params.get('order[{:d}][column]'.format(i), False):
column_nr = int(self.params.get('order[{:d}][column]'.format(i)))
while self.params.get("order[{:d}][column]".format(i), False):
column_nr = int(self.params.get("order[{:d}][column]".format(i)))
column = self.columns[column_nr]
direction = self.params.get('order[{:d}][dir]'.format(i))
direction = self.params.get("order[{:d}][dir]".format(i))
sort_expr = column.sqla_expr
if direction == 'asc':
if direction == "asc":
sort_expr = sort_expr.asc()
elif direction == 'desc':
elif direction == "desc":
sort_expr = sort_expr.desc()
else:
raise ValueError(
'Invalid order direction: {}'.format(direction))
raise ValueError("Invalid order direction: {}".format(direction))
if column.nulls_order:
if column.nulls_order == 'nullsfirst':
if column.nulls_order == "nullsfirst":
sort_expr = sort_expr.nullsfirst()
elif column.nulls_order == 'nullslast':
elif column.nulls_order == "nullslast":
sort_expr = sort_expr.nullslast()
else:
raise ValueError(
'Invalid order direction: {}'.format(direction))
raise ValueError("Invalid order direction: {}".format(direction))

sort_expressions.append(sort_expr)
i += 1
self.sort_expressions = sort_expressions

def _get_regex_operator(self):
if isinstance(self.query.session.bind.dialect, postgresql.dialect):
return '~'
return "~"
elif isinstance(self.query.session.bind.dialect, mysql.dialect):
return 'REGEXP'
return "REGEXP"
elif isinstance(self.query.session.bind.dialect, sqlite.dialect):
return 'REGEXP'
return "REGEXP"
else:
raise NotImplementedError(
'Regex searches are not implemented for this dialect')
"Regex searches are not implemented for this dialect"
)
Loading