Skip to content

Commit

Permalink
refactor: refactor pagination implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dovahcrow committed Jun 1, 2020
1 parent e729213 commit b1ced65
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 121 deletions.
159 changes: 79 additions & 80 deletions dataprep/data_connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@

from ..errors import UnreachableError
from .config_manager import config_directory, ensure_config
from .errors import RequestError
from .errors import RequestError, UniversalParameterOverridden
from .implicit_database import ImplicitDatabase, ImplicitTable


INFO_TEMPLATE = Template(
"""{% for tb in tbs.keys() %}
Table {{dbname}}.{{tb}}
Expand Down Expand Up @@ -85,14 +84,19 @@ def __init__(
self._auth = _auth or {}
self._jenv = Environment(undefined=StrictUndefined)

def _fetch(
def _fetch( # pylint: disable=too-many-locals,too-many-branches
self,
table: ImplicitTable,
_auth: Optional[Dict[str, Any]],
_count: Optional[int],
_cursor: Optional[int],
*,
_count: Optional[int] = None,
_cursor: Optional[int] = None,
_auth: Optional[Dict[str, Any]] = None,
kwargs: Dict[str, Any],
) -> Response:
assert (_count is None) == (
_cursor is None
), "_cursor and _count should both be None or not None"

method = table.method
url = table.url
req_data: Dict[str, Dict[str, Any]] = {
Expand All @@ -104,6 +108,7 @@ def _fetch(

if table.authorization is not None:
table.authorization.build(req_data, _auth or self._auth)

for key in ["headers", "params", "cookies"]:
if getattr(table, key) is not None:
instantiated_fields = getattr(table, key).populate(
Expand All @@ -121,29 +126,24 @@ def _fetch(
else:
raise UnreachableError

if table.pag_params is not None:
if table.pag_params.type == "null":
resp: Response = self._session.send( # type: ignore
Request(
method=method,
url=url,
headers=req_data["headers"],
params=req_data["params"],
json=req_data.get("json"),
data=req_data.get("data"),
cookies=req_data["cookies"],
).prepare()
)
return resp

if table.pag_params is not None and _count is not None:
pag_type = table.pag_params.type
count_key = table.pag_params.count_key
cursor_key = ""
if pag_type == "cursor":
assert table.pag_params.cursor_key is not None
cursor_key = table.pag_params.cursor_key
if pag_type == "limit":
elif pag_type == "limit":
assert table.pag_params.anchor_key is not None
cursor_key = table.pag_params.anchor_key
else:
raise UnreachableError()

if count_key in req_data["params"]:
raise UniversalParameterOverridden(count_key, "_count")
req_data["params"][count_key] = _count

if cursor_key in req_data["params"]:
raise UniversalParameterOverridden(cursor_key, "_cursor")
req_data["params"][cursor_key] = _cursor

resp: Response = self._session.send( # type: ignore
Expand All @@ -163,7 +163,7 @@ def _fetch(

return resp

def query(
def query( # pylint: disable=too-many-locals
self,
table: str,
_auth: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -192,62 +192,61 @@ def query(

itable = self._impdb.tables[table]

if itable.pag_params is not None:
if itable.pag_params.type == "null":
resp = self._fetch(
table=itable, _auth=_auth, _count=-1, _cursor=-1, kwargs=where
)
df = itable.from_response(resp)
return df

max_count = int(itable.pag_params.max_count)
df = pd.DataFrame()
last_id = 0
pag_type = itable.pag_params.type

if _count is None:
_count = max_count
resp = self._fetch(
table=itable, _auth=_auth, _count=_count, _cursor=0, kwargs=where
)
df = itable.from_response(resp)

else:
cnt_to_fetch = 0
count = _count or 1
n_page = math.ceil(count / max_count)
remain = count % max_count
for i in range(n_page):
remain = remain if remain > 0 else max_count
cnt_to_fetch = max_count if i < n_page - 1 else remain
if pag_type == "cursor":
resp = self._fetch(
table=itable,
_auth=_auth,
_count=cnt_to_fetch,
_cursor=last_id - 1,
kwargs=where,
)
elif pag_type == "limit":
resp = self._fetch(
table=itable,
_auth=_auth,
_count=cnt_to_fetch,
_cursor=i * max_count,
kwargs=where,
)
else:
raise NotImplementedError
df_ = itable.from_response(resp)
if pag_type == "cursor":
last_id = (
int(df_[itable.pag_params.cursor_id][len(df_) - 1]) - 1
)
if i == 0:
df = df_.copy()
else:
df = pd.concat([df, df_], axis=0)
df.reset_index(drop=True, inplace=True)
if itable.pag_params is None:
resp = self._fetch(table=itable, _auth=_auth, kwargs=where)
df = itable.from_response(resp)
return df

# Pagination is not None
max_count = itable.pag_params.max_count
dfs = []
last_id = 0
pag_type = itable.pag_params.type

if _count is None:
# User doesn't specify _count
resp = self._fetch(table=itable, _auth=_auth, kwargs=where)
df = itable.from_response(resp)
else:
cnt_to_fetch = 0
count = _count or 1
n_page = math.ceil(count / max_count)
remain = count % max_count
for i in range(n_page):
remain = remain if remain > 0 else max_count
cnt_to_fetch = max_count if i < n_page - 1 else remain

if pag_type == "cursor":
resp = self._fetch(
table=itable,
_auth=_auth,
_count=cnt_to_fetch,
_cursor=last_id - 1,
kwargs=where,
)
elif pag_type == "limit":
resp = self._fetch(
table=itable,
_auth=_auth,
_count=cnt_to_fetch,
_cursor=i * max_count,
kwargs=where,
)
else:
raise NotImplementedError
df_ = itable.from_response(resp)

if len(df_) == 0:
# The API returns empty for this page, maybe we've reached the end
break

if pag_type == "cursor":
last_id = int(df_[itable.pag_params.cursor_id][len(df_) - 1]) - 1

dfs.append(df_)

df = pd.concat(dfs, axis=0)
df.reset_index(drop=True, inplace=True)

return df

Expand Down
17 changes: 17 additions & 0 deletions dataprep/data_connector/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,20 @@ def __init__(self, status_code: int, message: str) -> None:

def __str__(self) -> str:
return f"RequestError: status={self.status_code}, message={self.message}"


class UniversalParameterOverridden(Exception):
    """
    Raised when a request parameter is overridden by a universal parameter.

    Attributes
    ----------
    param
        Name of the request parameter that would be overridden.
    uparam
        Name of the universal parameter that overrides it.
    """

    param: str
    uparam: str

    def __init__(self, param: str, uparam: str) -> None:
        # Forward both values to ``Exception`` so that ``args`` is populated.
        # ``pickle`` and ``copy`` rebuild an exception by calling
        # ``cls(*self.args)``; with empty ``args`` that call raises a
        # ``TypeError`` because this ``__init__`` requires two arguments.
        super().__init__(param, uparam)
        self.param = param
        self.uparam = uparam

    def __str__(self) -> str:
        return f"the parameter {self.param} is overridden by {self.uparam}"
49 changes: 18 additions & 31 deletions dataprep/data_connector/implicit_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,20 @@ class Pagination:
"""

type: str
anchor_key: str
count_key: str
cursor_id: str
cursor_key: str
max_count: str
max_count: int
anchor_key: Optional[str]
cursor_id: Optional[str]
cursor_key: Optional[str]

def __init__(self, pdef: Dict[str, Any]) -> None:

self.type = pdef["type"]
self.max_count = pdef["max_count"]
self.count_key = pdef["count_key"]
self.anchor_key = pdef.get("anchor_key")
self.cursor_id = pdef.get("cursor_id")
self.cursor_key = pdef.get("cursor_key")


class ImplicitTable: # pylint: disable=too-many-instance-attributes
Expand All @@ -67,7 +76,7 @@ class ImplicitTable: # pylint: disable=too-many-instance-attributes
body_ctype: str
body: Optional[Fields] = None
cookies: Optional[Fields] = None
pag_params: Optional[Pagination]
pag_params: Optional[Pagination] = None

# Response related
ctype: str
Expand All @@ -81,6 +90,7 @@ def __init__(self, name: str, config: Dict[str, Any]) -> None:
) # This will throw errors if validate failed
self.name = name
self.config = config

request_def = config["request"]

self.method = request_def["method"]
Expand All @@ -99,36 +109,12 @@ def __init__(self, name: str, config: Dict[str, Any]) -> None:
self.authorization = Authorization(auth_type=auth_type, params=auth_params)

if "pagination" in request_def:
self.pag_params = Pagination()
self.pag_params.type = request_def["pagination"]["type"]
self.pag_params.max_count = request_def["pagination"]["max_count"]
self.pag_params.count_key = (
request_def["pagination"]["count_key"]
if "count_key" in request_def["pagination"]
else ""
)
self.pag_params.anchor_key = (
request_def["pagination"]["anchor_key"]
if "anchor_key" in request_def["pagination"]
else ""
)
self.pag_params.cursor_id = (
request_def["pagination"]["cursor_id"]
if "cursor_id" in request_def["pagination"]
else ""
)
self.pag_params.cursor_key = (
request_def["pagination"]["cursor_key"]
if "cursor_key" in request_def["pagination"]
else ""
)
else:
self.pag_params = Pagination()
self.pag_params.type = "null"
self.pag_params = Pagination(request_def["pagination"])

for key in ["headers", "params", "cookies"]:
if key in request_def:
setattr(self, key, Fields(request_def[key]))

if "body" in request_def:
body_def = request_def["body"]
self.body_ctype = body_def["ctype"]
Expand Down Expand Up @@ -271,6 +257,7 @@ def __init__(self, config_path: Union[str, Path]) -> None:
if table_config_path.suffix != ".json":
# ifnote non json file
continue

with open(table_config_path) as f:
table_config = jload(f)

Expand Down
22 changes: 13 additions & 9 deletions dataprep/data_connector/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,32 @@
},
"pagination": {
"$id": "#/properties/request/properties/pagination",
"type":"object",
"type": "object",
"properties": {
"type": { "type": "string" },
"max_count": { "type": "string" },
"anchor_key":{
"type": "string",
"optional": true
"type": {
"type": "string"
},
"max_count": {
"type": "integer"
},
"count_key":{
"anchor_key": {
"type": "string",
"optional": true
},
"cursor_id":{
"count_key": {
"type": "string"
},
"cursor_id": {
"type": "string",
"optional": true
},
"cursor_key":{
"cursor_key": {
"type": "string",
"optional": true
}
},
"required": [
"count_key",
"type",
"max_count"
],
Expand Down
6 changes: 5 additions & 1 deletion dataprep/tests/data_connector/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,8 @@ def test_data_connector() -> None:

df = dc.query("businesses", _count=120, term="ramen", location="vancouver")

assert len(df) > 0
assert len(df) == 120

df = dc.query("businesses", _count=10000, term="ramen", location="vancouver")

assert len(df) < 1000

0 comments on commit b1ced65

Please sign in to comment.