Skip to content

Commit

Permalink
Merge pull request #173 from sfu-db/feat/pagination
Browse files Browse the repository at this point in the history
Implement pagination for data_connector
  • Loading branch information
dovahcrow authored Jun 1, 2020
2 parents a06c93b + b1ced65 commit 3d2e8b7
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 53 deletions.
177 changes: 125 additions & 52 deletions dataprep/data_connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
This module contains the Connector class.
Every data fetching action should begin with instantiating this Connector class.
"""

import math
from pathlib import Path
from typing import Any, Dict, List, Optional

Expand All @@ -12,10 +12,9 @@

from ..errors import UnreachableError
from .config_manager import config_directory, ensure_config
from .errors import RequestError
from .errors import RequestError, UniversalParameterOverridden
from .implicit_database import ImplicitDatabase, ImplicitTable


INFO_TEMPLATE = Template(
"""{% for tb in tbs.keys() %}
Table {{dbname}}.{{tb}}
Expand Down Expand Up @@ -57,15 +56,15 @@ class Connector:

_impdb: ImplicitDatabase
_vars: Dict[str, Any]
_auth_params: Dict[str, Any]
_auth: Dict[str, Any]
_session: Session
_jenv: Environment

def __init__(
self,
config_path: str,
auth_params: Optional[Dict[str, Any]] = None,
**kwargs: Any,
_auth: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> None:
self._session = Session()
if (
Expand All @@ -82,33 +81,41 @@ def __init__(
self._impdb = ImplicitDatabase(path)

self._vars = kwargs
self._auth_params = auth_params or {}
self._auth = _auth or {}
self._jenv = Environment(undefined=StrictUndefined)

def _fetch(
def _fetch( # pylint: disable=too-many-locals,too-many-branches
self,
table: ImplicitTable,
auth_params: Optional[Dict[str, Any]],
*,
_count: Optional[int] = None,
_cursor: Optional[int] = None,
_auth: Optional[Dict[str, Any]] = None,
kwargs: Dict[str, Any],
) -> Response:
assert (_count is None) == (
_cursor is None
), "_cursor and _count should both be None or not None"

method = table.method
url = table.url
req_data: Dict[str, Dict[str, Any]] = {
"headers": {},
"params": {},
"cookies": {},
}

merged_vars = {**self._vars, **kwargs}

if table.authorization is not None:
table.authorization.build(req_data, auth_params or self._auth_params)
table.authorization.build(req_data, _auth or self._auth)

for key in ["headers", "params", "cookies"]:
if getattr(table, key) is not None:
instantiated_fields = getattr(table, key).populate(
self._jenv, merged_vars
)
req_data[key].update(**instantiated_fields)

if table.body is not None:
# TODO: do we support binary body?
instantiated_fields = table.body.populate(self._jenv, merged_vars)
Expand All @@ -119,6 +126,26 @@ def _fetch(
else:
raise UnreachableError

if table.pag_params is not None and _count is not None:
pag_type = table.pag_params.type
count_key = table.pag_params.count_key
if pag_type == "cursor":
assert table.pag_params.cursor_key is not None
cursor_key = table.pag_params.cursor_key
elif pag_type == "limit":
assert table.pag_params.anchor_key is not None
cursor_key = table.pag_params.anchor_key
else:
raise UnreachableError()

if count_key in req_data["params"]:
raise UniversalParameterOverridden(count_key, "_count")
req_data["params"][count_key] = _count

if cursor_key in req_data["params"]:
raise UniversalParameterOverridden(cursor_key, "_cursor")
req_data["params"][cursor_key] = _cursor

resp: Response = self._session.send( # type: ignore
Request(
method=method,
Expand All @@ -136,11 +163,97 @@ def _fetch(

return resp

def query(  # pylint: disable=too-many-locals
    self,
    table: str,
    _auth: Optional[Dict[str, Any]] = None,
    _count: Optional[int] = None,
    **where: Any,
) -> pd.DataFrame:
    """
    Query the API to get a table.

    Parameters
    ----------
    table : str
        The table name.
    _auth : Optional[Dict[str, Any]] = None
        The parameters for authentication. Usually the authentication parameters
        should be defined when instantiating the Connector. In case some tables have different
        authentication options, a different authentication parameter can be defined here.
        This parameter will override the one from Connector if passed.
    _count : Optional[int] = None
        The number of records to return. It is only honoured when the table
        defines pagination; the records are then fetched page by page with at
        most ``max_count`` records per request. When omitted, a single
        request is issued without any pagination parameters.
    **where : Any
        The additional parameters required for the query.

    Returns
    -------
    pd.DataFrame
        The fetched records. When several pages are fetched they are
        concatenated and re-indexed from 0. If no page yields any record,
        an empty DataFrame (without columns) is returned.
    """
    assert (
        table in self._impdb.tables
    ), f"No such table {table} in {self._impdb.name}"

    itable = self._impdb.tables[table]
    pag_params = itable.pag_params

    if pag_params is None or _count is None:
        # Either the table has no pagination or the user did not ask for a
        # specific number of records: one plain request is enough.
        resp = self._fetch(table=itable, _auth=_auth, kwargs=where)
        return itable.from_response(resp)

    max_count = pag_params.max_count
    pag_type = pag_params.type

    # NOTE: a `_count` of 0 is treated as 1 (behaviour kept from the
    # original implementation).
    count = _count or 1
    n_page = math.ceil(count / max_count)
    # Every page is full except possibly the last one; when `count` is an
    # exact multiple of `max_count` the last page is full as well.
    last_page_count = count % max_count or max_count

    dfs = []
    last_id = 0
    for i in range(n_page):
        cnt_to_fetch = max_count if i < n_page - 1 else last_page_count

        if pag_type == "cursor":
            cursor = last_id - 1
        elif pag_type == "limit":
            cursor = i * max_count
        else:
            raise NotImplementedError

        resp = self._fetch(
            table=itable,
            _auth=_auth,
            _count=cnt_to_fetch,
            _cursor=cursor,
            kwargs=where,
        )
        df_ = itable.from_response(resp)

        if len(df_) == 0:
            # The API returns empty for this page, maybe we've reached the end
            break

        if pag_type == "cursor":
            # Read the last row by position so the lookup does not depend on
            # the index labels of the page.
            last_id = int(df_[pag_params.cursor_id].iloc[-1]) - 1

        dfs.append(df_)

    if not dfs:
        # pd.concat raises ValueError on an empty list; an empty result is
        # the meaningful answer when the API has nothing to return.
        return pd.DataFrame()

    df = pd.concat(dfs, axis=0)
    df.reset_index(drop=True, inplace=True)

    return df

@property
def table_names(self) -> List[str]:
"""
Return all the names of the available tables in a list.
Note
----
We abstract each website as a database containing several tables.
Expand Down Expand Up @@ -184,17 +297,14 @@ def show_schema(self, table_name: str) -> pd.DataFrame:
"""
This method shows the schema of the table that will be returned,
so that the user knows what information to expect.
Parameters
----------
table_name
The table name.
Returns
-------
pd.DataFrame
The returned data's schema.
Note
----
The schema is defined in the configuration file.
Expand All @@ -210,40 +320,3 @@ def show_schema(self, table_name: str) -> pd.DataFrame:
new_schema_dict["column_name"].append(k)
new_schema_dict["data_type"].append(schema[k]["type"])
return pd.DataFrame.from_dict(new_schema_dict)

def query(
    self, table: str, auth_params: Optional[Dict[str, Any]] = None, **where: Any,
) -> pd.DataFrame:
    """
    Query the API and return the resulting table.

    Example
    -------
    >>> df = dc.query('businesses', term="korean", location="vancouver")

    Parameters
    ----------
    table
        The table name.
    auth_params
        The parameters for authentication. They are normally given when the
        Connector is instantiated; pass them here when a table needs
        different authentication options. When passed, they override the
        ones from the Connector.
    where
        The additional parameters required for the query.

    Returns
    -------
    pd.DataFrame
        A DataFrame that contains the data returned by the website API.
    """
    known_tables = self._impdb.tables
    assert table in known_tables, f"No such table {table} in {self._impdb.name}"

    itable = known_tables[table]
    return itable.from_response(self._fetch(itable, auth_params, where))
17 changes: 17 additions & 0 deletions dataprep/data_connector/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,20 @@ def __init__(self, status_code: int, message: str) -> None:

def __str__(self) -> str:
return f"RequestError: status={self.status_code}, message={self.message}"


class UniversalParameterOverridden(Exception):
    """
    Raised when a request parameter collides with a universal parameter
    (e.g. ``_count``) and would therefore be overridden by it.

    Attributes
    ----------
    param
        The name of the parameter that is being overridden.
    uparam
        The name of the universal parameter that overrides it.
    """

    param: str
    uparam: str

    def __init__(self, param: str, uparam: str) -> None:
        # Forward the arguments to Exception so that ``args`` is populated:
        # this keeps ``repr`` informative and lets ``copy``/``pickle``
        # rebuild the exception by calling the constructor with ``args``.
        super().__init__(param, uparam)
        self.param = param
        self.uparam = uparam

    def __str__(self) -> str:
        return f"the parameter {self.param} is overridden by {self.uparam}"
27 changes: 27 additions & 0 deletions dataprep/data_connector/implicit_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ class SchemaField(NamedTuple):
description: Optional[str]


class Pagination:
    """
    Pagination settings of a table, as read from the ``pagination`` entry of
    its request definition.
    """

    type: str
    count_key: str
    max_count: int
    anchor_key: Optional[str]
    cursor_id: Optional[str]
    cursor_key: Optional[str]

    def __init__(self, pdef: Dict[str, Any]) -> None:
        # Mandatory entries: a missing one raises KeyError.
        for required in ("type", "max_count", "count_key"):
            setattr(self, required, pdef[required])

        # Optional entries: a missing one is stored as None.
        for optional in ("anchor_key", "cursor_id", "cursor_key"):
            setattr(self, optional, pdef.get(optional))


class ImplicitTable: # pylint: disable=too-many-instance-attributes
"""
ImplicitTable class abstracts the request and the response to a Restful API,
Expand All @@ -54,6 +76,7 @@ class ImplicitTable: # pylint: disable=too-many-instance-attributes
body_ctype: str
body: Optional[Fields] = None
cookies: Optional[Fields] = None
pag_params: Optional[Pagination] = None

# Response related
ctype: str
Expand Down Expand Up @@ -85,9 +108,13 @@ def __init__(self, name: str, config: Dict[str, Any]) -> None:
raise NotImplementedError
self.authorization = Authorization(auth_type=auth_type, params=auth_params)

if "pagination" in request_def:
self.pag_params = Pagination(request_def["pagination"])

for key in ["headers", "params", "cookies"]:
if key in request_def:
setattr(self, key, Fields(request_def[key]))

if "body" in request_def:
body_def = request_def["body"]
self.body_ctype = body_def["ctype"]
Expand Down
33 changes: 33 additions & 0 deletions dataprep/data_connector/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,39 @@
"params": {
"$ref": "#/definitions/fields"
},
"pagination": {
"$id": "#/properties/request/properties/pagination",
"type": "object",
"properties": {
"type": {
"type": "string"
},
"max_count": {
"type": "integer"
},
"anchor_key": {
"type": "string",
"optional": true
},
"count_key": {
"type": "string"
},
"cursor_id": {
"type": "string",
"optional": true
},
"cursor_key": {
"type": "string",
"optional": true
}
},
"required": [
"count_key",
"type",
"max_count"
],
"additionalProperties": false
},
"body": {
"$id": "#/properties/request/properties/body",
"type": "object",
Expand Down
Loading

0 comments on commit 3d2e8b7

Please sign in to comment.