diff --git a/dataprep/data_connector/connector.py b/dataprep/data_connector/connector.py index 269e8f1c8..d5cffb44d 100644 --- a/dataprep/data_connector/connector.py +++ b/dataprep/data_connector/connector.py @@ -2,7 +2,7 @@ This module contains the Connector class. Every data fetching action should begin with instantiating this Connector class. """ - +import math from pathlib import Path from typing import Any, Dict, List, Optional @@ -12,10 +12,9 @@ from ..errors import UnreachableError from .config_manager import config_directory, ensure_config -from .errors import RequestError +from .errors import RequestError, UniversalParameterOverridden from .implicit_database import ImplicitDatabase, ImplicitTable - INFO_TEMPLATE = Template( """{% for tb in tbs.keys() %} Table {{dbname}}.{{tb}} @@ -57,15 +56,15 @@ class Connector: _impdb: ImplicitDatabase _vars: Dict[str, Any] - _auth_params: Dict[str, Any] + _auth: Dict[str, Any] _session: Session _jenv: Environment def __init__( self, config_path: str, - auth_params: Optional[Dict[str, Any]] = None, - **kwargs: Any, + _auth: Optional[Dict[str, Any]] = None, + **kwargs: Dict[str, Any], ) -> None: self._session = Session() if ( @@ -82,15 +81,22 @@ def __init__( self._impdb = ImplicitDatabase(path) self._vars = kwargs - self._auth_params = auth_params or {} + self._auth = _auth or {} self._jenv = Environment(undefined=StrictUndefined) - def _fetch( + def _fetch( # pylint: disable=too-many-locals,too-many-branches self, table: ImplicitTable, - auth_params: Optional[Dict[str, Any]], + *, + _count: Optional[int] = None, + _cursor: Optional[int] = None, + _auth: Optional[Dict[str, Any]] = None, kwargs: Dict[str, Any], ) -> Response: + assert (_count is None) == ( + _cursor is None + ), "_cursor and _count should both be None or not None" + method = table.method url = table.url req_data: Dict[str, Dict[str, Any]] = { @@ -98,10 +104,10 @@ def _fetch( "params": {}, "cookies": {}, } - merged_vars = {**self._vars, **kwargs} + if table.authorization is not None: - table.authorization.build(req_data, auth_params or self._auth_params) + table.authorization.build(req_data, _auth or self._auth) for key in ["headers", "params", "cookies"]: if getattr(table, key) is not None: @@ -109,6 +115,7 @@ def _fetch( self._jenv, merged_vars ) req_data[key].update(**instantiated_fields) + if table.body is not None: # TODO: do we support binary body? instantiated_fields = table.body.populate(self._jenv, merged_vars) @@ -119,6 +126,26 @@ def _fetch( else: raise UnreachableError + if table.pag_params is not None and _count is not None: + pag_type = table.pag_params.type + count_key = table.pag_params.count_key + if pag_type == "cursor": + assert table.pag_params.cursor_key is not None + cursor_key = table.pag_params.cursor_key + elif pag_type == "limit": + assert table.pag_params.anchor_key is not None + cursor_key = table.pag_params.anchor_key + else: + raise UnreachableError() + + if count_key in req_data["params"]: + raise UniversalParameterOverridden(count_key, "_count") + req_data["params"][count_key] = _count + + if cursor_key in req_data["params"]: + raise UniversalParameterOverridden(cursor_key, "_cursor") + req_data["params"][cursor_key] = _cursor + resp: Response = self._session.send( # type: ignore Request( method=method, @@ -136,11 +163,97 @@ def _fetch( return resp + def query( # pylint: disable=too-many-locals + self, + table: str, + _auth: Optional[Dict[str, Any]] = None, + _count: Optional[int] = None, + **where: Any, + ) -> pd.DataFrame: + """ + Query the API to get a table. + Parameters + ---------- + table : str + The table name. + _auth : Optional[Dict[str, Any]] = None + The parameters for authentication. Usually the authentication parameters + should be defined when instantiating the Connector. In case some tables have different + authentication options, a different authentication parameter can be defined here. + This parameter will override the one from Connector if passed. + _count: Optional[int] = None + count of returned records. + **where: Any + The additional parameters required for the query. + """ + assert ( + table in self._impdb.tables + ), f"No such table {table} in {self._impdb.name}" + + itable = self._impdb.tables[table] + + if itable.pag_params is None: + resp = self._fetch(table=itable, _auth=_auth, kwargs=where) + df = itable.from_response(resp) + return df + + # Pagination is not None + max_count = itable.pag_params.max_count + dfs = [] + last_id = 0 + pag_type = itable.pag_params.type + + if _count is None: + # User doesn't specify _count + resp = self._fetch(table=itable, _auth=_auth, kwargs=where) + df = itable.from_response(resp) + else: + cnt_to_fetch = 0 + count = _count or 1 + n_page = math.ceil(count / max_count) + remain = count % max_count + for i in range(n_page): + remain = remain if remain > 0 else max_count + cnt_to_fetch = max_count if i < n_page - 1 else remain + + if pag_type == "cursor": + resp = self._fetch( + table=itable, + _auth=_auth, + _count=cnt_to_fetch, + _cursor=last_id - 1, + kwargs=where, + ) + elif pag_type == "limit": + resp = self._fetch( + table=itable, + _auth=_auth, + _count=cnt_to_fetch, + _cursor=i * max_count, + kwargs=where, + ) + else: + raise NotImplementedError + df_ = itable.from_response(resp) + + if len(df_) == 0: + # The API returns empty for this page, maybe we've reached the end + break + + if pag_type == "cursor": + last_id = int(df_[itable.pag_params.cursor_id][len(df_) - 1]) - 1 + + dfs.append(df_) + + df = pd.concat(dfs, axis=0) + df.reset_index(drop=True, inplace=True) + + return df + @property def table_names(self) -> List[str]: """ Return all the names of the available tables in a list. - Note ---- We abstract each website as a database containing several tables. @@ -184,17 +297,14 @@ def show_schema(self, table_name: str) -> pd.DataFrame: """ This method shows the schema of the table that will be returned, so that the user knows what information to expect. - Parameters ---------- table_name The table name. - Returns ------- pd.DataFrame The returned data's schema. - Note ---- The schema is defined in the configuration file. @@ -210,40 +320,3 @@ def show_schema(self, table_name: str) -> pd.DataFrame: new_schema_dict["column_name"].append(k) new_schema_dict["data_type"].append(schema[k]["type"]) return pd.DataFrame.from_dict(new_schema_dict) - - def query( - self, table: str, auth_params: Optional[Dict[str, Any]] = None, **where: Any, - ) -> pd.DataFrame: - """ - Use this method to query the API and get the returned table. - - Example - ------- - >>> df = dc.query('businesses', term="korean", location="vancouver) - - Parameters - ---------- - table - The table name. - auth_params - The parameters for authentication. Usually the authentication parameters - should be defined when instantiating the Connector. In case some tables have different - authentication options, a different authentication parameter can be defined here. - This parameter will override the one from Connector if passed. - where - The additional parameters required for the query. - - Returns - ------- - pd.DataFrame - A DataFrame that contains the data returned by the website API. - """ - assert ( - table in self._impdb.tables - ), f"No such table {table} in {self._impdb.name}" - - itable = self._impdb.tables[table] - - resp = self._fetch(itable, auth_params, where) - - return itable.from_response(resp) diff --git a/dataprep/data_connector/errors.py b/dataprep/data_connector/errors.py index 408a1089d..c773ef573 100644 --- a/dataprep/data_connector/errors.py +++ b/dataprep/data_connector/errors.py @@ -32,3 +32,20 @@ def __init__(self, status_code: int, message: str) -> None: def __str__(self) -> str: return f"RequestError: status={self.status_code}, message={self.message}" + + +class UniversalParameterOverridden(Exception): + """ + The parameter is overrided by the universal parameter + """ + + param: str + uparam: str + + def __init__(self, param: str, uparam: str) -> None: + super().__init__() + self.param = param + self.uparam = uparam + + def __str__(self) -> str: + return f"the parameter {self.param} is overridden by {self.uparam}" diff --git a/dataprep/data_connector/implicit_database.py b/dataprep/data_connector/implicit_database.py index 049787f36..347db4b7c 100644 --- a/dataprep/data_connector/implicit_database.py +++ b/dataprep/data_connector/implicit_database.py @@ -37,6 +37,28 @@ class SchemaField(NamedTuple): description: Optional[str] +class Pagination: + """ + Schema of Pagination field + """ + + type: str + count_key: str + max_count: int + anchor_key: Optional[str] + cursor_id: Optional[str] + cursor_key: Optional[str] + + def __init__(self, pdef: Dict[str, Any]) -> None: + + self.type = pdef["type"] + self.max_count = pdef["max_count"] + self.count_key = pdef["count_key"] + self.anchor_key = pdef.get("anchor_key") + self.cursor_id = pdef.get("cursor_id") + self.cursor_key = pdef.get("cursor_key") + + class ImplicitTable: # pylint: disable=too-many-instance-attributes """ ImplicitTable class abstracts the request and the response to a Restful API, @@ -54,6 +76,7 @@ class ImplicitTable: # pylint: disable=too-many-instance-attributes body_ctype: str body: Optional[Fields] = None cookies: Optional[Fields] = None + pag_params: Optional[Pagination] = None # Response related ctype: str @@ -85,9 +108,13 @@ def __init__(self, name: str, config: Dict[str, Any]) -> None: raise NotImplementedError self.authorization = Authorization(auth_type=auth_type, params=auth_params) + if "pagination" in request_def: + self.pag_params = Pagination(request_def["pagination"]) + for key in ["headers", "params", "cookies"]: if key in request_def: setattr(self, key, Fields(request_def[key])) + if "body" in request_def: body_def = request_def["body"] self.body_ctype = body_def["ctype"] diff --git a/dataprep/data_connector/schema.json b/dataprep/data_connector/schema.json index 27368ba67..4ce149532 100644 --- a/dataprep/data_connector/schema.json +++ b/dataprep/data_connector/schema.json @@ -61,6 +61,39 @@ "params": { "$ref": "#/definitions/fields" }, + "pagination": { + "$id": "#/properties/request/properties/pagination", + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "max_count": { + "type": "integer" + }, + "anchor_key": { + "type": "string", + "optional": true + }, + "count_key": { + "type": "string" + }, + "cursor_id": { + "type": "string", + "optional": true + }, + "cursor_key": { + "type": "string", + "optional": true + } + }, + "required": [ + "count_key", + "type", + "max_count" + ], + "additionalProperties": false + }, "body": { "$id": "#/properties/request/properties/body", "type": "object", diff --git a/dataprep/tests/data_connector/test_integration.py b/dataprep/tests/data_connector/test_integration.py index dc0955fc5..b7aff83ef 100644 --- a/dataprep/tests/data_connector/test_integration.py +++ b/dataprep/tests/data_connector/test_integration.py @@ -4,7 +4,7 @@ def test_data_connector() -> None: token = environ["DATAPREP_DATA_CONNECTOR_YELP_TOKEN"] - dc = Connector("yelp", auth_params={"access_token": token}) + dc = Connector("yelp", _auth={"access_token": token}) df = dc.query("businesses", term="ramen", location="vancouver") assert len(df) > 0 @@ -14,3 +14,11 @@ def test_data_connector() -> None: schema = dc.show_schema("businesses") assert len(schema) > 0 + + df = dc.query("businesses", _count=120, term="ramen", location="vancouver") + + assert len(df) == 120 + + df = dc.query("businesses", _count=10000, term="ramen", location="vancouver") + + assert len(df) < 1000