diff --git a/clkhash/__init__.py b/clkhash/__init__.py index a403ed0f..c79d2efb 100644 --- a/clkhash/__init__.py +++ b/clkhash/__init__.py @@ -7,8 +7,8 @@ from . import randomnames try: - __version__ = pkg_resources.get_distribution('clkhash').version + __version__ = pkg_resources.get_distribution("clkhash").version except pkg_resources.DistributionNotFound: __version__ = "development" -__author__ = 'N1 Analytics' +__author__ = "N1 Analytics" diff --git a/clkhash/backports.py b/clkhash/backports.py index fb673ecb..603655b8 100644 --- a/clkhash/backports.py +++ b/clkhash/backports.py @@ -23,27 +23,33 @@ def __int_from_bytes(bytes, byteorder, signed=False): """ if signed: raise NotImplementedError( - "Signed integers are not currently supported in this " - "backport.") + "Signed integers are not currently supported in this backport." + ) - if byteorder == 'big': + if byteorder == "big": pass - elif byteorder == 'little': + elif byteorder == "little": bytes = bytes[::-1] else: raise ValueError("byteorder must be either 'little' or 'big'") - hex_str = codecs.encode(bytes, 'hex') # type: ignore + hex_str = codecs.encode(bytes, "hex") # type: ignore return int(hex_str, 16) # Make this cast since Python 2 doesn't have syntax for default # named arguments. Hence, must cast so Mypy thinks it matches the # original function. - int_from_bytes = cast(Callable[[Arg(Sequence[int], 'bytes'), - Arg(str, 'byteorder'), - DefaultNamedArg(bool, 'signed')], - int], - __int_from_bytes) + int_from_bytes = cast( + Callable[ + [ + Arg(Sequence[int], "bytes"), + Arg(str, "byteorder"), + DefaultNamedArg(bool, "signed"), + ], + int, + ], + __int_from_bytes, + ) def re_compile_full(pattern, flags=0): @@ -65,11 +71,11 @@ def re_compile_full(pattern, flags=0): # A pattern of type bytes doesn't make sense in Python 3. assert type(pattern) is not bytes or str is bytes - return re.compile('(?:{})\Z'.format(pattern), flags=flags) + return re.compile("(?:{})\Z".format(pattern), flags=flags) def _utf_8_encoder(unicode_csv_data): - return (line.encode('utf-8') for line in unicode_csv_data) + return (line.encode("utf-8") for line in unicode_csv_data) def _p2_unicode_reader(unicode_csv_data, dialect=csv.excel, **kwargs): @@ -92,9 +98,10 @@ def _p2_unicode_reader(unicode_csv_data, dialect=csv.excel, **kwargs): csv_reader = csv.reader(utf8_csv_data, dialect=dialect, **kwargs) # Decode UTF-8 back to Unicode, cell by cell: - return ([unicode(cell, 'utf-8') for cell in row] for row in csv_reader) + return ([unicode(cell, "utf-8") for cell in row] for row in csv_reader) -unicode_reader = (_p2_unicode_reader # Python 2 with hacky workarounds. - if sys.version_info < (3,0) - else csv.reader) # Py3 with native Unicode support. +unicode_reader = ( + _p2_unicode_reader # Python 2 with hacky workarounds. + if sys.version_info < (3, 0) + else csv.reader # Py3 with native Unicode support. +)
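Aside (editor's illustration, not part of the diff): the int_from_bytes backport above mirrors Python 3's int.from_bytes, which the double-hashing code in clkhash/bloomfilter.py uses to turn an HMAC digest into a bit index. A minimal Python 3 sketch follows; the key and the filter length l are made-up values for illustration only.

    import hashlib
    import hmac

    l = 1024  # assumed filter length (a power of two), illustrative only
    key = b"example hmac key"  # hypothetical key, not taken from the diff

    # Interpret the HMAC-SHA1 digest as a big-endian integer and reduce it
    # modulo l; the backported int_from_bytes provides this call on Python 2.
    digest = hmac.new(key, "fo".encode("utf-8"), hashlib.sha1).digest()
    index = int.from_bytes(digest, "big") % l
    print(index)  # a bit position in [0, l)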
diff --git a/clkhash/benchmark.py b/clkhash/benchmark.py index 4403b458..701a73c0 100644 --- a/clkhash/benchmark.py +++ b/clkhash/benchmark.py @@ -17,17 +17,17 @@ def compute_hash_speed(n, quiet=False): os_fd, tmpfile_name = tempfile.mkstemp(text=True) schema = NameList.SCHEMA - header_row = ','.join([f.identifier for f in schema.fields]) + header_row = ",".join([f.identifier for f in schema.fields]) - with open(tmpfile_name, 'wt') as f: + with open(tmpfile_name, "wt") as f: f.write(header_row) - f.write('\n') + f.write("\n") for person in namelist.names: - print(','.join([str(field) for field in person]), file=f) + print(",".join([str(field) for field in person]), file=f) - with open(tmpfile_name, 'rt') as f: + with open(tmpfile_name, "rt") as f: start = timer() - generate_clk_from_csv(f, ('key1', 'key2'), schema, progress_bar=not quiet) + generate_clk_from_csv(f, ("key1", "key2"), schema, progress_bar=not quiet) end = timer() os.close(os_fd) @@ -35,10 +35,14 @@ elapsed_time = end - start if not quiet: - print("{:6d} hashes in {:.6f} seconds. {:.2f} KH/s".format(n, elapsed_time, n/(1000*elapsed_time))) + print( + "{:6d} hashes in {:.6f} seconds. {:.2f} KH/s".format( + n, elapsed_time, n / (1000 * elapsed_time) + ) + ) return n / elapsed_time -if __name__ == '__main__': +if __name__ == "__main__": for n in [100, 1000, 10000, 50000, 100000]: - compute_hash_speed(n, quiet=n<=10000) + compute_hash_speed(n, quiet=n <= 10000) diff --git a/clkhash/bloomfilter.py b/clkhash/bloomfilter.py index b88f9cdb..58493d97 100644 --- a/clkhash/bloomfilter.py +++ b/clkhash/bloomfilter.py @@ -26,16 +26,17 @@ except ImportError: # We are in Python older than 3.6. from pyblake2 import blake2b # type: ignore - # Ignore because otherwise Mypy raises errors, thinking that - # blake2b is already defined. +# Ignore because otherwise Mypy raises errors, thinking that +# blake2b is already defined. -def double_hash_encode_ngrams(ngrams, # type: Iterable[str] - keys, # type: Sequence[bytes] - k, # type: int - l, # type: int - encoding # type: str - ): +def double_hash_encode_ngrams( + ngrams, # type: Iterable[str] + keys, # type: Sequence[bytes] + k, # type: int + l, # type: int + encoding, # type: str +): # type: (...) -> bitarray """ Computes the double hash encoding of the provided ngrams with the given keys. @@ -55,20 +56,25 @@ bf = bitarray(l) bf.setall(False) for m in ngrams: - sha1hm = int(hmac.new(key_sha1, m.encode(encoding=encoding), sha1).hexdigest(), 16) % l - md5hm = int(hmac.new(key_md5, m.encode(encoding=encoding), md5).hexdigest(), 16) % l + sha1hm = int( + hmac.new(key_sha1, m.encode(encoding=encoding), sha1).hexdigest(), 16 + ) % l + md5hm = int( + hmac.new(key_md5, m.encode(encoding=encoding), md5).hexdigest(), 16 + ) % l for i in range(k): gi = (sha1hm + i * md5hm) % l bf[gi] = 1 return bf -def double_hash_encode_ngrams_non_singular(ngrams, # type: Iterable[str] - keys, # type: Sequence[bytes] - k, # type: int - l, # type: int - encoding # type: str - ): +def double_hash_encode_ngrams_non_singular( + ngrams, # type: Iterable[str] + keys, # type: Sequence[bytes] + k, # type: int + l, # type: int + encoding, # type: str +): # type: (...) -> bitarray.bitarray """ computes the double hash encoding of the provided n-grams with the given keys. 
@@ -114,14 +120,13 @@ def double_hash_encode_ngrams_non_singular(ngrams, # type: Iterable[str sha1hm_bytes = hmac.new(key_sha1, m_bytes, sha1).digest() md5hm_bytes = hmac.new(key_md5, m_bytes, md5).digest() - sha1hm = int_from_bytes(sha1hm_bytes, 'big') % l - md5hm = int_from_bytes(md5hm_bytes, 'big') % l + sha1hm = int_from_bytes(sha1hm_bytes, "big") % l + md5hm = int_from_bytes(md5hm_bytes, "big") % l i = 0 while md5hm == 0: - md5hm_bytes = hmac.new( - key_md5, m_bytes + chr(i).encode(), md5).digest() - md5hm = int_from_bytes(md5hm_bytes, 'big') % l + md5hm_bytes = hmac.new(key_md5, m_bytes + chr(i).encode(), md5).digest() + md5hm = int_from_bytes(md5hm_bytes, "big") % l i += 1 for i in range(k): @@ -130,12 +135,13 @@ def double_hash_encode_ngrams_non_singular(ngrams, # type: Iterable[str return bf -def blake_encode_ngrams(ngrams, # type: Iterable[str] - keys, # type: Sequence[bytes] - k, # type: int - l, # type: int - encoding # type: str - ): +def blake_encode_ngrams( + ngrams, # type: Iterable[str] + keys, # type: Sequence[bytes] + k, # type: int + l, # type: int + encoding, # type: str +): # type: (...) -> bitarray.bitarray """ Computes the encoding of the provided ngrams using the BLAKE2 hash function. @@ -189,19 +195,29 @@ def blake_encode_ngrams(ngrams, # type: Iterable[str] key, = keys # Unpack. log_l = int(math.log(l, 2)) - if not 2**log_l == l: - raise ValueError('parameter "l" has to be a power of two for the BLAKE2 encoding, but was: {}'.format(l)) + if not 2 ** log_l == l: + raise ValueError( + 'parameter "l" has to be a power of two for the BLAKE2 encoding, but was: {}'.format( + l + ) + ) + bf = bitarray(l) bf.setall(False) if k < 1: return bf - num_macs = (k+31) // 32 + + num_macs = (k + 31) // 32 for m in ngrams: random_shorts = [] # type: List[int] for i in range(num_macs): - hash_bytes = blake2b(m.encode(encoding=encoding), key=key, salt=str(i).encode()).digest() - random_shorts.extend(struct.unpack('32H', hash_bytes)) # interpret hash bytes as 32 unsigned shorts. + hash_bytes = blake2b( + m.encode(encoding=encoding), key=key, salt=str(i).encode() + ).digest() + random_shorts.extend( + struct.unpack("32H", hash_bytes) + ) # interpret hash bytes as 32 unsigned shorts. for i in range(k): idx = random_shorts[i] % l bf[idx] = 1 @@ -231,25 +247,24 @@ def __call__(self, *args): return self.value(*args) @classmethod - def from_properties(cls, - properties # type: GlobalHashingProperties - ): + def from_properties( + cls, + properties, # type: GlobalHashingProperties + ): # type: (...) -> Callable[[Iterable[str], Sequence[bytes], int, int, str], bitarray] - if properties.hash_type == 'doubleHash': + if properties.hash_type == "doubleHash": if properties.hash_prevent_singularity: return cls.DOUBLE_HASH_NON_SINGULAR + else: return cls.DOUBLE_HASH - elif properties.hash_type == 'blakeHash': + + elif properties.hash_type == "blakeHash": return cls.BLAKE_HASH + else: msg = "Unsupported hash type '{}'".format(properties.hash_type) raise ValueError(msg) -def fold_xor(bloomfilter, # type: bitarray - folds # type: int - ): +def fold_xor( + bloomfilter, # type: bitarray + folds, # type: int +): # type: (...) -> bitarray """ Performs XOR folding on a Bloom filter. @@ -263,10 +278,11 @@ def fold_xor(bloomfilter, # type: bitarray """ if len(bloomfilter) % 2 ** folds != 0: - msg = ('The length of the bloom filter is {length}. It is not ' - 'divisible by 2 ** {folds}, so it cannot be folded {folds} ' - 'times.' 
- .format(length=len(bloomfilter), folds=folds)) + msg = ( + "The length of the bloom filter is {length}. It is not " + "divisible by 2 ** {folds}, so it cannot be folded {folds} " + "times.".format(length=len(bloomfilter), folds=folds) + ) raise ValueError(msg) for _ in range(folds): @@ -278,12 +294,13 @@ def fold_xor(bloomfilter, # type: bitarray return bloomfilter -def crypto_bloom_filter(record, # type: Sequence[Text] - tokenizers, # type: List[Callable[[Text], Iterable[Text]]] - field_hashing, # type: List[FieldHashingProperties] - keys, # type: Sequence[Sequence[bytes]] - hash_properties # type: GlobalHashingProperties - ): +def crypto_bloom_filter( + record, # type: Sequence[Text] + tokenizers, # type: List[Callable[[Text], Iterable[Text]]] + field_hashing, # type: List[FieldHashingProperties] + keys, # type: Sequence[Sequence[bytes]] + hash_properties, # type: GlobalHashingProperties +): # type: (...) -> Tuple[bitarray, Text, int] """ Makes a Bloom filter from a record with given tokenizers and lists of keys. @@ -311,23 +328,22 @@ def crypto_bloom_filter(record, # type: Sequence[Text] bloomfilter = bitarray(l) bloomfilter.setall(False) - for (entry, tokenizer, field, key) \ in zip(record, tokenizers, field_hashing, keys): + for (entry, tokenizer, field, key) in zip(record, tokenizers, field_hashing, keys): ngrams = tokenizer(entry) adjusted_k = int(round(field.weight * k)) - bloomfilter |= hash_function( - ngrams, key, adjusted_k, l, field.encoding) + bloomfilter |= hash_function(ngrams, key, adjusted_k, l, field.encoding) bloomfilter = fold_xor(bloomfilter, xor_folds) return bloomfilter, record[0], bloomfilter.count() -def stream_bloom_filters(dataset, # type: Iterable[Sequence[Text]] - keys, # type: Sequence[Sequence[bytes]] - schema # type: Schema - ): +def stream_bloom_filters( + dataset, # type: Iterable[Sequence[Text]] + keys, # type: Sequence[Sequence[bytes]] + schema, # type: Schema +): # type: (...)
-> Iterable[Tuple[bitarray, Text, int]] """ Yield bloom filters @@ -338,14 +354,16 @@ def stream_bloom_filters(dataset, # type: Iterable[Sequence[Text]] :param xor_folds: number of XOR folds to perform :return: Yields bloom filters as 3-tuples """ - tokenizers = [tokenizer.get_tokenizer(field.hashing_properties) - for field in schema.fields] + tokenizers = [ + tokenizer.get_tokenizer(field.hashing_properties) for field in schema.fields + ] field_hashing = [field.hashing_properties for field in schema.fields] hash_properties = schema.hashing_globals - return (crypto_bloom_filter(s, tokenizers, field_hashing, - keys, hash_properties) - for s in dataset) + return ( + crypto_bloom_filter(s, tokenizers, field_hashing, keys, hash_properties) + for s in dataset + ) def serialize_bitarray(ba): @@ -353,4 +371,4 @@ def serialize_bitarray(ba): """Serialize a bitarray (bloomfilter) """ - return base64.b64encode(ba.tobytes()).decode('utf8') + return base64.b64encode(ba.tobytes()).decode("utf8") diff --git a/clkhash/cli.py b/clkhash/cli.py index 9a553ba9..68a11643 100644 --- a/clkhash/cli.py +++ b/clkhash/cli.py @@ -13,17 +13,16 @@ from clkhash import benchmark as bench, clk, randomnames -DEFAULT_SERVICE_URL = 'https://es.data61.xyz' +DEFAULT_SERVICE_URL = "https://es.data61.xyz" -def log(m, color='red'): +def log(m, color="red"): click.echo(click.style(m, fg=color), err=True) @click.group("clkutil") @click.version_option(clkhash.__version__) -@click.option('--verbose', '-v', is_flag=True, - help='Enables verbose mode.') +@click.option("--verbose", "-v", is_flag=True, help="Enables verbose mode.") def cli(verbose=False): """ This command line application allows a user to hash their @@ -43,14 +42,17 @@ def cli(verbose=False): """ - -@cli.command('hash', short_help="generate hashes from local PII data") -@click.argument('input', type=click.File('r')) -@click.argument('keys', nargs=2, type=click.Tuple([str, str])) -@click.argument('schema', type=click.File('r', lazy=True)) -@click.argument('output', type=click.File('w')) -@click.option('-q', '--quiet', default=False, is_flag=True, help="Quiet any progress messaging") -@click.option('--no-header', default=False, is_flag=True, help="Don't skip the first row") +@cli.command("hash", short_help="generate hashes from local PII data") +@click.argument("input", type=click.File("r")) +@click.argument("keys", nargs=2, type=click.Tuple([str, str])) +@click.argument("schema", type=click.File("r", lazy=True)) +@click.argument("output", type=click.File("w")) +@click.option( + "-q", "--quiet", default=False, is_flag=True, help="Quiet any progress messaging" +) +@click.option( + "--no-header", default=False, is_flag=True, help="Don't skip the first row" +) def hash(input, keys, schema, output, quiet, no_header): """Process data to create CLKs @@ -70,17 +72,24 @@ def hash(input, keys, schema, output, quiet, no_header): schema_object = clkhash.schema.Schema.from_json_file(schema_file=schema) clk_data = clk.generate_clk_from_csv( - input, keys, schema_object, - header=not no_header, progress_bar=not quiet) - json.dump({'clks': clk_data}, output) - if hasattr(output, 'name'): + input, keys, schema_object, header=not no_header, progress_bar=not quiet + ) + json.dump({"clks": clk_data}, output) + if hasattr(output, "name"): log("CLK data written to {}".format(output.name)) -@cli.command('status', short_help='Get status of entity service') -@click.option('--server', type=str, default=DEFAULT_SERVICE_URL, help="Server address including protocol") -@click.option('-o','--output', 
type=click.File('w'), default='-') -@click.option('-v', '--verbose', default=False, is_flag=True, help="Script is more talkative") +@cli.command("status", short_help="Get status of entity service") +@click.option( + "--server", + type=str, + default=DEFAULT_SERVICE_URL, + help="Server address including protocol", +) +@click.option("-o", "--output", type=click.File("w"), default="-") +@click.option( + "-v", "--verbose", default=False, is_flag=True, help="Script is more talkative" +) def status(server, output, verbose): """Connect to an entity matching server and check the service status. @@ -91,7 +100,7 @@ def status(server, output, verbose): response = requests.get(server + "/api/v1/status") server_status = response.json() log("Response: {}".format(response.status_code)) - log("Status: {}".format(server_status['status'])) + log("Status: {}".format(server_status["status"])) print(json.dumps(server_status), file=output) @@ -115,14 +124,29 @@ def status(server, output, verbose): """ -@cli.command('create', short_help="create a mapping on the entity service") -@click.option('--type', default='permutation_unencrypted_mask', - help='Alternative protocol/view type of the mapping. Default is unencrypted permutation and mask.') -@click.option('--schema', type=click.File('r'), help="Schema to publicly share with participating parties.") -@click.option('--server', type=str, default=DEFAULT_SERVICE_URL, help="Server address including protocol") -@click.option('-o','--output', type=click.File('w'), default='-') -@click.option('-t','--threshold', type=float, default=0.95) -@click.option('-v', '--verbose', default=False, is_flag=True, help="Script is more talkative") + +@cli.command("create", short_help="create a mapping on the entity service") +@click.option( + "--type", + default="permutation_unencrypted_mask", + help="Alternative protocol/view type of the mapping. Default is unencrypted permutation and mask.", +) +@click.option( + "--schema", + type=click.File("r"), + help="Schema to publicly share with participating parties.", +) +@click.option( + "--server", + type=str, + default=DEFAULT_SERVICE_URL, + help="Server address including protocol", +) +@click.option("-o", "--output", type=click.File("w"), default="-") +@click.option("-t", "--threshold", type=float, default=0.95) +@click.option( + "-v", "--verbose", default=False, is_flag=True, help="Script is more talkative" +) def create(type, schema, server, output, threshold, verbose): """Create a new mapping on an entity matching server. 
@@ -135,14 +159,14 @@ def create(type, schema, server, output, threshold, verbose): log("Entity Matching Server: {}".format(server)) log("Checking server status") - status = requests.get(server + "/api/v1/status").json()['status'] + status = requests.get(server + "/api/v1/status").json()["status"] log("Server Status: {}".format(status)) if schema is not None: schema_object = load_schema(schema) schema_json = json.dumps(schema_object) else: - schema_json = 'NOT PROVIDED' + schema_json = "NOT PROVIDED" log("Schema: {}".format(schema_json)) log("Type: {}".format(type)) @@ -150,11 +174,7 @@ def create(type, schema, server, output, threshold, verbose): log("Creating new mapping") response = requests.post( "{}/api/v1/mappings".format(server), - json={ - 'schema': schema_json, - 'result_type': type, - 'threshold': threshold - } + json={"schema": schema_json, "result_type": type, "threshold": threshold}, ) if response.status_code != 200: @@ -167,13 +187,20 @@ def create(type, schema, server, output, threshold, verbose): print(response.text, file=output) -@cli.command('upload', short_help='upload hashes to entity service') -@click.argument('input', type=click.File('r')) -@click.option('--mapping', help='Server identifier of the mapping') -@click.option('--apikey', help='Authentication API key for the server.') -@click.option('--server', type=str, default=DEFAULT_SERVICE_URL, help="Server address including protocol") -@click.option('-o','--output', type=click.File('w'), default='-') -@click.option('-v', '--verbose', default=False, is_flag=True, help="Script is more talkative") +@cli.command("upload", short_help="upload hashes to entity service") +@click.argument("input", type=click.File("r")) +@click.option("--mapping", help="Server identifier of the mapping") +@click.option("--apikey", help="Authentication API key for the server.") +@click.option( + "--server", + type=str, + default=DEFAULT_SERVICE_URL, + help="Server address including protocol", +) +@click.option("-o", "--output", type=click.File("w"), default="-") +@click.option( + "-v", "--verbose", default=False, is_flag=True, help="Script is more talkative" +) def upload(input, mapping, apikey, server, output, verbose): """Upload CLK data to entity matching server. 
@@ -188,36 +215,39 @@ def upload(input, mapping, apikey, server, output, verbose): log("Mapping ID: {}".format(mapping)) log("Checking server status") - status = requests.get(server + "/api/v1/status").json()['status'] + status = requests.get(server + "/api/v1/status").json()["status"] log("Status: {}".format(status)) log("Uploading CLK data to the server") response = requests.put( - '{}/api/v1/mappings/{}'.format(server, mapping), + "{}/api/v1/mappings/{}".format(server, mapping), data=input, - headers={ - "Authorization": apikey, - 'content-type': 'application/json' - } + headers={"Authorization": apikey, "content-type": "application/json"}, ) if verbose: log(response.text) - log("When the other party has uploaded their CLKS, you should be able to watch for results") - + log( + "When the other party has uploaded their CLKS, you should be able to watch for results" + ) print(response.text, file=output) - -@cli.command('results', short_help="fetch results from entity service") -@click.option('--mapping', - help='Server identifier of the mapping') -@click.option('--apikey', help='Authentication API key for the server.') -@click.option('-w', '--watch', help='Follow/wait until results are available', is_flag=True) -@click.option('--server', type=str, default=DEFAULT_SERVICE_URL, help="Server address including protocol") -@click.option('-o','--output', type=click.File('w'), default='-') +@cli.command("results", short_help="fetch results from entity service") +@click.option("--mapping", help="Server identifier of the mapping") +@click.option("--apikey", help="Authentication API key for the server.") +@click.option( + "-w", "--watch", help="Follow/wait until results are available", is_flag=True +) +@click.option( + "--server", + type=str, + default=DEFAULT_SERVICE_URL, + help="Server address including protocol", +) +@click.option("-o", "--output", type=click.File("w"), default="-") def results(mapping, apikey, watch, server, output): """ Check to see if results are available for a particular mapping @@ -230,13 +260,13 @@ def results(mapping, apikey, watch, server, output): """ log("Checking server status") - status = requests.get(server + "/api/v1/status").json()['status'] + status = requests.get(server + "/api/v1/status").json()["status"] log("Status: {}".format(status)) def get_result(): return requests.get( - '{}/api/v1/mappings/{}'.format(server, mapping), - headers={"Authorization": apikey} + "{}/api/v1/mappings/{}".format(server, mapping), + headers={"Authorization": apikey}, ) response = get_result() @@ -255,15 +285,15 @@ def get_result(): log(response.text) -@cli.command('benchmark', short_help='carry out a local benchmark') +@cli.command("benchmark", short_help="carry out a local benchmark") def benchmark(): bench.compute_hash_speed(10000) -@cli.command('generate', short_help='generate random pii data for testing') -@click.argument('size', type=int, default=100) -@click.argument('output', type=click.File('w')) -@click.option('--schema', '-s', type=click.File('r'), default=None) +@cli.command("generate", short_help="generate random pii data for testing") +@click.argument("size", type=int, default=100) +@click.argument("output", type=click.File("w")) +@click.option("--schema", "-s", type=click.File("r"), default=None) def generate(size, output, schema): """Generate fake PII data for testing""" pii_data = randomnames.NameList(size) @@ -272,21 +302,22 @@ def generate(size, output, schema): raise NotImplementedError randomnames.save_csv( - pii_data.names, - [f.identifier for f in 
pii_data.SCHEMA.fields], - output) + pii_data.names, [f.identifier for f in pii_data.SCHEMA.fields], output + ) -@cli.command('generate-default-schema', - short_help='get the default schema used in generated random PII') -@click.argument('output', type=click.Path(writable=True, - readable=False, - resolve_path=True)) +@cli.command( + "generate-default-schema", + short_help="get the default schema used in generated random PII", +) +@click.argument( + "output", type=click.Path(writable=True, readable=False, resolve_path=True) +) def generate_default_schema(output): """Get default schema for fake PII""" - original_path = os.path.join(os.path.dirname(__file__), - 'data', - 'randomnames-schema.json') + original_path = os.path.join( + os.path.dirname(__file__), "data", "randomnames-schema.json" + ) shutil.copyfile(original_path, output) diff --git a/clkhash/clk.py b/clkhash/clk.py index aeb5c2f1..085d464a 100644 --- a/clkhash/clk.py +++ b/clkhash/clk.py @@ -5,8 +5,9 @@ import concurrent.futures import logging import time -from typing import (AnyStr, Callable, Iterable, List, Optional, - Sequence, TextIO, Tuple, TypeVar) +from typing import ( + AnyStr, Callable, Iterable, List, Optional, Sequence, TextIO, Tuple, TypeVar +) from tqdm import tqdm @@ -18,15 +19,16 @@ from clkhash.validate_data import validate_data, validate_header -log = logging.getLogger('clkhash.clk') +log = logging.getLogger("clkhash.clk") CHUNK_SIZE = 1000 -def hash_and_serialize_chunk(chunk_pii_data, # type: Sequence[Sequence[str]] - keys, # type: Sequence[Sequence[bytes]] - schema # type: Schema - ): +def hash_and_serialize_chunk( + chunk_pii_data, + keys, + schema, # type: Sequence[Sequence[str]] # type: Sequence[Sequence[bytes]] # type: Schema +): # type: (...) -> Tuple[List[str], Sequence[int]] """ Generate Bloom filters (ie hash) from chunks of PII then serialize @@ -46,13 +48,14 @@ def hash_and_serialize_chunk(chunk_pii_data, # type: Sequence[Sequence[str]] return clk_data, clk_popcounts -def generate_clk_from_csv(input_f, # type: TextIO - keys, # type: Tuple[AnyStr, AnyStr] - schema, # type: Schema - validate=True, # type: bool - header=True, # type: bool - progress_bar=True # type: bool - ): +def generate_clk_from_csv( + input_f, # type: TextIO + keys, # type: Tuple[AnyStr, AnyStr] + schema, # type: Schema + validate=True, # type: bool + header=True, # type: bool + progress_bar=True, # type: bool +): # type: (...) -> List[str] log.info("Hashing data") @@ -72,40 +75,43 @@ def generate_clk_from_csv(input_f, # type: TextIO if len(line) == len(schema.fields): pii_data.append(tuple([element.strip() for element in line])) else: - raise ValueError("Line had unexpected number of elements. " - "Expected {} but there was {}".format( - len(schema.fields), len(line))) + raise ValueError( + "Line had unexpected number of elements. 
" + "Expected {} but there was {}".format(len(schema.fields), len(line)) + ) if progress_bar: stats = OnlineMeanVariance() - with tqdm(desc="generating CLKs", total=len(pii_data), unit='clk', unit_scale=True, - postfix={'mean': stats.mean(), 'std': stats.std()}) as pbar: + with tqdm( + desc="generating CLKs", + total=len(pii_data), + unit="clk", + unit_scale=True, + postfix={"mean": stats.mean(), "std": stats.std()}, + ) as pbar: + def callback(tics, clk_stats): stats.update(clk_stats) pbar.set_postfix(mean=stats.mean(), std=stats.std(), refresh=False) pbar.update(tics) - results = generate_clks(pii_data, - schema, - keys, - validate=validate, - callback=callback) + results = generate_clks( + pii_data, schema, keys, validate=validate, callback=callback + ) else: - results = generate_clks(pii_data, - schema, - keys, - validate=validate) + results = generate_clks(pii_data, schema, keys, validate=validate) log.info("Hashing took {:.2f} seconds".format(time.time() - start_time)) return results -def generate_clks(pii_data, # type: Sequence[Sequence[str]] - schema, # type: Schema - keys, # type: Tuple[AnyStr, AnyStr] - validate=True, # type: bool - callback=None # type: Optional[Callable[[int, Sequence[int]], None]] - ): +def generate_clks( + pii_data, # type: Sequence[Sequence[str]] + schema, # type: Schema + keys, # type: Tuple[AnyStr, AnyStr] + validate=True, # type: bool + callback=None, # type: Optional[Callable[[int, Sequence[int]], None]] +): # type: (...) -> List[str] # generate two keys for each identifier @@ -116,7 +122,8 @@ def generate_clks(pii_data, # type: Sequence[Sequence[str]] salt=schema.hashing_globals.kdf_salt, info=schema.hashing_globals.kdf_info, kdf=schema.hashing_globals.kdf_type, - hash_algo=schema.hashing_globals.kdf_hash) + hash_algo=schema.hashing_globals.kdf_hash, + ) if validate: validate_data(schema.fields, pii_data) @@ -129,11 +136,11 @@ def generate_clks(pii_data, # type: Sequence[Sequence[str]] # Compute Bloom filter from the chunks and then serialise it with concurrent.futures.ProcessPoolExecutor() as executor: for chunk in chunks(pii_data, chunk_size): - future = executor.submit( - hash_and_serialize_chunk, - chunk, key_lists, schema,) + future = executor.submit(hash_and_serialize_chunk, chunk, key_lists, schema) if callback is not None: - future.add_done_callback(lambda f: callback(len(f.result()[0]), f.result()[1])) + future.add_done_callback( + lambda f: callback(len(f.result()[0]), f.result()[1]) + ) futures.append(future) results = [] @@ -144,7 +151,7 @@ def generate_clks(pii_data, # type: Sequence[Sequence[str]] return results -T = TypeVar('T') # Declare generic type variable +T = TypeVar("T") # Declare generic type variable def chunks(seq, chunk_size): diff --git a/clkhash/field_formats.py b/clkhash/field_formats.py index 77588417..1ae5e663 100644 --- a/clkhash/field_formats.py +++ b/clkhash/field_formats.py @@ -47,32 +47,33 @@ class FieldHashingProperties(object): :ivar float weight: Controls the weight of the field in the Bloom filter. """ - _DEFAULT_ENCODING = 'utf-8' + _DEFAULT_ENCODING = "utf-8" _DEFAULT_POSITIONAL = False _DEFAULT_WEIGHT = 1 - def __init__(self, - ngram, # type: int - encoding=_DEFAULT_ENCODING, # type: str - weight=_DEFAULT_WEIGHT, # type: Union[int, float] - positional=_DEFAULT_POSITIONAL # type: bool - ): + def __init__( + self, + ngram, # type: int + encoding=_DEFAULT_ENCODING, # type: str + weight=_DEFAULT_WEIGHT, # type: Union[int, float] + positional=_DEFAULT_POSITIONAL, # type: bool + ): # type: (...) 
-> None """ Make a :class:`FieldHashingProperties` object, setting it attributes to values specified in keyword arguments. """ if ngram not in range(3): - msg = 'ngram is {} but is expected to be 0, 1, or 2.' + msg = "ngram is {} but is expected to be 0, 1, or 2." raise ValueError(msg.format(ngram)) try: - ''.encode(encoding) + "".encode(encoding) except LookupError as e: - msg = '{} is not a valid Python encoding.' + msg = "{} is not a valid Python encoding." raise_from(ValueError(msg.format(encoding)), e) if weight < 0: - msg = 'weight should be non-negative but is {}.' + msg = "weight should be non-negative but is {}." raise ValueError(msg.format(weight)) self.ngram = ngram @@ -94,11 +95,12 @@ def from_json_dict(cls, json_dict): :return: A :class:`FieldHashingProperties` instance. """ return cls( - ngram=json_dict['ngram'], + ngram=json_dict["ngram"], positional=json_dict.get( - 'positional', FieldHashingProperties._DEFAULT_POSITIONAL), - weight=json_dict.get( - 'weight', FieldHashingProperties._DEFAULT_WEIGHT)) + "positional", FieldHashingProperties._DEFAULT_POSITIONAL + ), + weight=json_dict.get("weight", FieldHashingProperties._DEFAULT_WEIGHT), + ) @add_metaclass(abc.ABCMeta) @@ -112,11 +114,13 @@ class FieldSpec(object): :ivar FieldHashingProperties hashing_properties: The properties for hashing. """ - def __init__(self, - identifier, # type: str - hashing_properties, # type: FieldHashingProperties - description=None # type: Optional[str] - ): + + def __init__( + self, + identifier, # type: str + hashing_properties, # type: FieldHashingProperties + description=None, # type: Optional[str] + ): # type: (...) -> None """ Make a FieldSpec object, setting it attributes to values specified in keyword arguments. @@ -138,10 +142,11 @@ def from_json_dict(cls, field_dict): dictionary contains invalid values. Exactly what that means is decided by the subclasses. """ - identifier = field_dict['identifier'] - description = field_dict['format'].get('description') + identifier = field_dict["identifier"] + description = field_dict["format"].get("description") hashing_properties = FieldHashingProperties.from_json_dict( - field_dict['hashing']) + field_dict["hashing"] + ) result = cls.__new__(cls) # type: ignore result.identifier = identifier @@ -167,8 +172,11 @@ def validate(self, str_in): try: str_in.encode(encoding=self.hashing_properties.encoding) except UnicodeEncodeError as e: - msg = ("Expected entry that can be encoded in {}. Read '{}'." - .format(self.hashing_properties.encoding, str_in)) + msg = ( + "Expected entry that can be encoded in {}. Read '{}'.".format( + self.hashing_properties.encoding, str_in + ) + ) raise_from(InvalidEntryError(msg), e) @@ -197,47 +205,55 @@ class StringSpec(FieldSpec): if there is no maximum length. Present only if the specification is not regex-based. 
""" - _DEFAULT_CASE = 'mixed' + _DEFAULT_CASE = "mixed" _DEFAULT_MIN_LENGTH = 0 - _PERMITTED_CASE_STYLES = {'lower', 'upper', 'mixed'} - - def __init__(self, - identifier, # type: str - hashing_properties, # type: FieldHashingProperties - description=None, # type: str - regex=None, # type: Optional[str] - case=_DEFAULT_CASE, # type: str - min_length=_DEFAULT_MIN_LENGTH, # type: Optional[int] - max_length=None # type: Optional[int] - ): + _PERMITTED_CASE_STYLES = {"lower", "upper", "mixed"} + + def __init__( + self, + identifier, # type: str + hashing_properties, # type: FieldHashingProperties + description=None, # type: str + regex=None, # type: Optional[str] + case=_DEFAULT_CASE, # type: str + min_length=_DEFAULT_MIN_LENGTH, # type: Optional[int] + max_length=None, # type: Optional[int] + ): # type: (...) -> None """ Make a StringSpec object, setting it attributes to values specified in keyword arguments. - """ - super().__init__(identifier=identifier, - description=description, - hashing_properties=hashing_properties) + """ + super().__init__( + identifier=identifier, + description=description, + hashing_properties=hashing_properties, + ) regex_based = regex is not None - if regex_based and (case != self._DEFAULT_CASE - or min_length != self._DEFAULT_MIN_LENGTH - or max_length is not None): - msg = ('regex cannot be passed along with case, min_length, or' - ' max_length.') + if ( + regex_based + and ( + case != self._DEFAULT_CASE + or min_length != self._DEFAULT_MIN_LENGTH + or max_length is not None + ) + ): + msg = ( + "regex cannot be passed along with case, min_length, or" " max_length." + ) raise ValueError(msg) if case not in self._PERMITTED_CASE_STYLES: - msg = ("the case is {}, but should be 'lower', 'upper', or" - "'mixed'") + msg = ("the case is {}, but should be 'lower', 'upper', or" "'mixed'") raise ValueError(msg.format(case)) if regex_based and min_length < 0: - msg = ('min_length must be non-negative, but is {}') + msg = ("min_length must be non-negative, but is {}") raise ValueError(msg.format(min_length)) if regex_based and max_length is not None and max_length <= 0: - msg = ('max_length must be positive, but is {}') + msg = ("max_length must be positive, but is {}") raise ValueError(msg.format(max_length)) if regex_based: @@ -269,14 +285,13 @@ def from_json_dict(cls, json_dict): :raises InvalidSchemaError: When a regular expression is provided but is not a valid pattern. """ - result = cast(StringSpec, # Go away, Mypy. - super().from_json_dict(json_dict)) + result = cast(StringSpec, super().from_json_dict(json_dict)) # Go away, Mypy. 
- format_ = json_dict['format'] - result.hashing_properties.encoding = format_['encoding'] + format_ = json_dict["format"] + result.hashing_properties.encoding = format_["encoding"] - if 'pattern' in format_: - pattern = format_['pattern'] + if "pattern" in format_: + pattern = format_["pattern"] try: result.regex = re_compile_full(pattern) except (SyntaxError, re.error) as e: @@ -285,9 +300,9 @@ def from_json_dict(cls, json_dict): result.regex_based = True else: - result.case = format_.get('case', StringSpec._DEFAULT_CASE) - result.min_length = format_.get('minLength') - result.max_length = format_.get('maxLength') + result.case = format_.get("case", StringSpec._DEFAULT_CASE) + result.min_length = format_.get("minLength") + result.max_length = format_.get("maxLength") result.regex_based = False return result @@ -315,34 +330,40 @@ def validate(self, str_in): match = self.regex.match(str_in) if match is None: raise InvalidEntryError( - 'Expected entry that conforms to regular expression ' - "'{}'. Read '{}'.".format(self.regex.pattern, str_in)) + "Expected entry that conforms to regular expression " + "'{}'. Read '{}'.".format(self.regex.pattern, str_in) + ) else: str_len = len(str_in) if self.min_length is not None and str_len < self.min_length: raise InvalidEntryError( - 'Expected string length of at least {}. Read string of ' - 'length {}.'.format(self.min_length, str_len)) + "Expected string length of at least {}. Read string of " + "length {}.".format(self.min_length, str_len) + ) if self.max_length is not None and str_len > self.max_length: raise InvalidEntryError( - 'Expected string length of at most {}. Read string of ' - 'length {}.'.format(self.max_length, str_len)) + "Expected string length of at most {}. Read string of " + "length {}.".format(self.max_length, str_len) + ) - if self.case == 'upper': + if self.case == "upper": if str_in.upper() != str_in: raise InvalidEntryError( - 'Expected upper case string. Read {}.'.format(str_in)) - elif self.case == 'lower': + "Expected upper case string. Read {}.".format(str_in) + ) + + elif self.case == "lower": if str_in.lower() != str_in: raise InvalidEntryError( - 'Expected lower case string. Read {}.'.format(str_in)) - elif self.case == 'mixed': + "Expected lower case string. Read {}.".format(str_in) + ) + + elif self.case == "mixed": pass else: - raise ValueError( - 'Invalid case property {}.'.format(self.case)) + raise ValueError("Invalid case property {}.".format(self.case)) class IntegerSpec(FieldSpec): @@ -356,21 +377,24 @@ class IntegerSpec(FieldSpec): _DEFAULT_MINIMUM = 0 - def __init__(self, - identifier, # type: str - hashing_properties, # type: FieldHashingProperties - description=None, # type: str - minimum=_DEFAULT_MINIMUM, # int - maximum=None, # Optional[int] - **kwargs # Dict[str, Any] - ): + def __init__( + self, + identifier, # type: str + hashing_properties, # type: FieldHashingProperties + description=None, # type: str + minimum=_DEFAULT_MINIMUM, # int + maximum=None, # Optional[int] + **kwargs # Dict[str, Any] + ): # type: (...) -> None """ Make a IntegerSpec object, setting it attributes to values specified in keyword arguments. """ - super().__init__(identifier=identifier, - description=description, - hashing_properties=hashing_properties) + super().__init__( + identifier=identifier, + description=description, + hashing_properties=hashing_properties, + ) self.minimum = minimum self.maximum = maximum @@ -388,12 +412,11 @@ def from_json_dict(cls, json_dict): :param dict json_dict: The properties dictionary. 
""" - result = cast(IntegerSpec, # For Mypy. - super().from_json_dict(json_dict)) + result = cast(IntegerSpec, super().from_json_dict(json_dict)) # For Mypy. - format_ = json_dict['format'] - result.minimum = format_.get('minimum', cls._DEFAULT_MINIMUM) - result.maximum = format_.get('maximum') + format_ = json_dict["format"] + result.minimum = format_.get("minimum", cls._DEFAULT_MINIMUM) + result.maximum = format_.get("maximum") return result @@ -416,17 +439,23 @@ def validate(self, str_in): try: value = int(str_in, base=10) except ValueError as e: - msg = 'Invalid integer. Read {}.'.format(str_in) + msg = "Invalid integer. Read {}.".format(str_in) raise_from(InvalidEntryError(msg), e) if value < self.minimum: - msg = ('Expected integer value of at least {}. Read {}.' - .format(self.minimum, value)) + msg = ( + "Expected integer value of at least {}. Read {}.".format( + self.minimum, value + ) + ) raise InvalidEntryError(msg) if self.maximum is not None and value > self.maximum: - msg = ('Expected integer value of at most {}. Read {}.' - .format(self.maximum, value)) + msg = ( + "Expected integer value of at most {}. Read {}.".format( + self.maximum, value + ) + ) raise InvalidEntryError(msg) @@ -439,26 +468,29 @@ class DateSpec(FieldSpec): :ivar str format: The format of the date. """ - _PERMITTED_FORMATS = {'rfc3339'} - _RFC3339_REGEX = re_compile_full(r'\d\d\d\d-\d\d-\d\d') - _RFC3339_FORMAT = '%Y-%m-%d' - - def __init__(self, - identifier, # type: str - hashing_properties, # type: FieldHashingProperties - format, # type: str - description=None # type: str - ): + _PERMITTED_FORMATS = {"rfc3339"} + _RFC3339_REGEX = re_compile_full(r"\d\d\d\d-\d\d-\d\d") + _RFC3339_FORMAT = "%Y-%m-%d" + + def __init__( + self, + identifier, # type: str + hashing_properties, # type: FieldHashingProperties + format, # type: str + description=None, # type: str + ): # type: (...) -> None """ Make a DateSpec object, setting it attributes to values specified in keyword arguments. """ - super().__init__(identifier=identifier, - description=description, - hashing_properties=hashing_properties) + super().__init__( + identifier=identifier, + description=description, + hashing_properties=hashing_properties, + ) if format not in self._PERMITTED_FORMATS: - msg = 'No validation for date format: {}.'.format(format) + msg = "No validation for date format: {}.".format(format) raise NotImplementedError(msg) self.format = format @@ -476,11 +508,10 @@ def from_json_dict(cls, json_dict): :param json_dict: The properties dictionary. """ - result = cast(DateSpec, # For Mypy. - super().from_json_dict(json_dict)) + result = cast(DateSpec, super().from_json_dict(json_dict)) # For Mypy. - format_ = json_dict['format'] - result.format = format_['format'] + format_ = json_dict["format"] + result.format = format_["format"] return result @@ -500,19 +531,19 @@ def validate(self, str_in): """ super().validate(str_in) - if self.format == 'rfc3339': + if self.format == "rfc3339": if self._RFC3339_REGEX.match(str_in) is None: - msg = ('Date expected to conform to RFC3339. Read {}.' - .format(str_in)) + msg = ("Date expected to conform to RFC3339. Read {}.".format(str_in)) raise InvalidEntryError(msg) + try: datetime.strptime(str_in, self._RFC3339_FORMAT) except ValueError as e: - msg = 'Invalid date. Read {}.'.format(str_in) + msg = "Invalid date. 
Read {}.".format(str_in) raise_from(InvalidEntryError(msg), e) else: - msg = 'No validation for date format: {}.'.format(self.format) + msg = "No validation for date format: {}.".format(self.format) raise NotImplementedError(msg) @@ -523,19 +554,23 @@ class EnumSpec(FieldSpec): :ivar values: The set of permitted values. """ - def __init__(self, - identifier, # type: str - hashing_properties, # type: FieldHashingProperties - values, # type: Iterable[str] - description=None # type: str - ): + + def __init__( + self, + identifier, # type: str + hashing_properties, # type: FieldHashingProperties + values, # type: Iterable[str] + description=None, # type: str + ): # type: (...) -> None """ Make a EnumSpec object, setting it attributes to values specified in keyword arguments. """ - super().__init__(identifier=identifier, - description=description, - hashing_properties=hashing_properties) + super().__init__( + identifier=identifier, + description=description, + hashing_properties=hashing_properties, + ) self.values = set(values) @@ -550,11 +585,12 @@ def from_json_dict(cls, json_dict): addition, it must contain a `'hashing'` key, whose contents are passed to :class:`FieldHashingProperties`. """ - result = cast(EnumSpec, # Appease the gods of Mypy. - super().from_json_dict(json_dict)) + result = cast( + EnumSpec, super().from_json_dict(json_dict) + ) # Appease the gods of Mypy. - format_ = json_dict['format'] - result.values = set(format_['values']) + format_ = json_dict["format"] + result.values = set(format_["values"]) return result @@ -573,8 +609,9 @@ def validate(self, str_in): super().validate(str_in) if str_in not in self.values: - msg = ('Expected enum value is one of {}. Read {}.' - .format(self.values, str_in)) + msg = ( + "Expected enum value is one of {}. Read {}.".format(self.values, str_in) + ) raise InvalidEntryError(msg) @@ -582,9 +619,8 @@ class Ignore(FieldSpec): """ represent a field which will be ignored throughout the clk processing. """ - def __init__(self, - identifier=None # type: str - ): + + def __init__(self, identifier=None): # type: str # type: (...) -> None super().__init__(identifier, FieldHashingProperties(ngram=0, weight=0)) @@ -594,10 +630,7 @@ def validate(self, str_in): # Map type string (as defined in master schema) to FIELD_TYPE_MAP = { - 'string': StringSpec, - 'integer': IntegerSpec, - 'date': DateSpec, - 'enum': EnumSpec, + "string": StringSpec, "integer": IntegerSpec, "date": DateSpec, "enum": EnumSpec } @@ -609,8 +642,9 @@ def spec_from_json_dict(json_dict): :returns: An initialised instance of the appropriate FieldSpec subclass. """ - if 'ignored' in json_dict: - return Ignore(json_dict['identifier']) - type_str = json_dict['format']['type'] + if "ignored" in json_dict: + return Ignore(json_dict["identifier"]) + + type_str = json_dict["format"]["type"] spec_type = cast(FieldSpec, FIELD_TYPE_MAP[type_str]) return spec_type.from_json_dict(json_dict) diff --git a/clkhash/key_derivation.py b/clkhash/key_derivation.py index 7abce22f..9a15232c 100644 --- a/clkhash/key_derivation.py +++ b/clkhash/key_derivation.py @@ -12,14 +12,15 @@ class HKDFconfig: - supported_hash_algos = 'SHA256', 'SHA512' - - def __init__(self, - master_secret, # type: bytes - salt=None, # type: Optional[bytes] - info=None, # type: Optional[bytes] - hash_algo='SHA256' # type: str - ): # type: (...) 
-> None + supported_hash_algos = "SHA256", "SHA512" + + def __init__( + self, + master_secret, # type: bytes + salt=None, # type: Optional[bytes] + info=None, # type: Optional[bytes] + hash_algo="SHA256", # type: str + ): # type: (...) -> None """ The parameters for the HDKF are defined as follows: @@ -73,14 +74,18 @@ def __init__(self, if hash_algo in HKDFconfig.supported_hash_algos: self.hash_algo = hash_algo else: - raise ValueError('hash algorithm "{}" is not supported. Has to be one of {}'.format(hash_algo, - HKDFconfig.supported_hash_algos)) + raise ValueError( + 'hash algorithm "{}" is not supported. Has to be one of {}'.format( + hash_algo, HKDFconfig.supported_hash_algos + ) + ) @staticmethod def check_is_bytes(value): # type: (Any) -> bytes if isinstance(value, bytes): return value + else: raise TypeError('provided value is not of type "bytes"') @@ -89,6 +94,7 @@ def check_is_bytes_or_none(value): # type: (Any) -> Optional[bytes] if value is None: return value + else: return HKDFconfig.check_is_bytes(value) @@ -104,14 +110,17 @@ def hkdf(hkdf_config, num_keys, key_size=DEFAULT_KEY_SIZE): :param key_size: the size of the produced keys :return: Derived keys """ - hash_dict = { - 'SHA256': hashes.SHA256, - 'SHA512': hashes.SHA512 - } + hash_dict = {"SHA256": hashes.SHA256, "SHA512": hashes.SHA512} if not isinstance(hkdf_config, HKDFconfig): raise TypeError('provided config has to be of type "HKDFconfig"') - hkdf = HKDF(algorithm=hash_dict[hkdf_config.hash_algo](), length=num_keys * key_size, salt=hkdf_config.salt, - info=hkdf_config.info, backend=default_backend()) + + hkdf = HKDF( + algorithm=hash_dict[hkdf_config.hash_algo](), + length=num_keys * key_size, + salt=hkdf_config.salt, + info=hkdf_config.info, + backend=default_backend() + ) # hkdf.derive returns a block of num_keys * key_size bytes which we divide up into num_keys chunks, # each of size key_size keybytes = hkdf.derive(hkdf_config.master_secret) @@ -119,14 +128,15 @@ def hkdf(hkdf_config, num_keys, key_size=DEFAULT_KEY_SIZE): return keys -def generate_key_lists(master_secrets, # type: Sequence[Union[bytes, str]] - num_identifier, # type: int - key_size=DEFAULT_KEY_SIZE, # type: int - salt=None, # type: Optional[bytes] - info=None, # type: Optional[bytes] - kdf='HKDF', # type: str - hash_algo='SHA256' # type: str - ): +def generate_key_lists( + master_secrets, # type: Sequence[Union[bytes, str]] + num_identifier, # type: int + key_size=DEFAULT_KEY_SIZE, # type: int + salt=None, # type: Optional[bytes] + info=None, # type: Optional[bytes] + kdf="HKDF", # type: str + hash_algo="SHA256", # type: str +): # type: (...) -> Tuple[Tuple[bytes, ...], ...] """ Generates a derived key for each identifier for each master secret using a key derivation function (KDF). @@ -154,17 +164,25 @@ def generate_key_lists(master_secrets, # type: Sequence[Union[bytes if isinstance(key, bytes): keys.append(key) else: - keys.append(key.encode('UTF-8')) + keys.append(key.encode("UTF-8")) except AttributeError: - raise TypeError("provided 'master_secrets' have to be either of type bytes or strings.") - if kdf == 'HKDF': - key_lists = [hkdf(HKDFconfig(key, salt=salt, info=info, - hash_algo=hash_algo), - num_identifier, - key_size) - for key in keys] + raise TypeError( + "provided 'master_secrets' have to be either of type bytes or strings." 
+ ) + + if kdf == "HKDF": + key_lists = [ + hkdf( + HKDFconfig(key, salt=salt, info=info, hash_algo=hash_algo), + num_identifier, + key_size, + ) + for key in keys + ] # regroup such that we get a tuple of keys for each identifier return tuple(zip(*key_lists)) - if kdf == 'legacy': + + if kdf == "legacy": return tuple([tuple(keys) for _ in range(num_identifier)]) + raise ValueError('kdf: "{}" is not supported.'.format(kdf)) diff --git a/clkhash/randomnames.py b/clkhash/randomnames.py index 257f5f20..1754d7f6 100644 --- a/clkhash/randomnames.py +++ b/clkhash/randomnames.py @@ -26,24 +26,25 @@ from clkhash.schema import Schema from clkhash.field_formats import FieldSpec + def load_csv_data(resource_name): # type: (str) -> List[str] """Loads a specified CSV data file and returns the first column as a Python list """ - data_bytes = pkgutil.get_data('clkhash', 'data/{}'.format(resource_name)) + data_bytes = pkgutil.get_data("clkhash", "data/{}".format(resource_name)) if data_bytes is None: raise ValueError("No data resource found with name {}".format(resource_name)) + else: - data = data_bytes.decode('utf8') + data = data_bytes.decode("utf8") reader = csv.reader(data.splitlines()) next(reader, None) # skip the headers return [row[0] for row in reader] -def save_csv(data, # type: Iterable[Tuple[Union[str, int], ...]] - headers, # type: Iterable[str] - file # type: TextIO - ): +def save_csv( + data, headers, file +): # type: Iterable[Tuple[Union[str, int], ...]] # type: Iterable[str] # type: TextIO # type: (...) -> None """ Output generated data to file as CSV with header. @@ -53,7 +54,7 @@ def save_csv(data, # type: Iterable[Tuple[Union[str, int], ...]] :param file: A writeable stream in which to write the CSV """ - print(','.join(headers), file=file) + print(",".join(headers), file=file) writer = csv.writer(file) writer.writerows(data) @@ -77,9 +78,9 @@ class NameList: """ List of randomly generated names. 
""" - with open(os.path.join(os.path.dirname(__file__), - 'data', - 'randomnames-schema.json')) as f: + with open( + os.path.join(os.path.dirname(__file__), "data", "randomnames-schema.json") + ) as f: SCHEMA = Schema.from_json_file(f) del f @@ -106,17 +107,18 @@ def generate_random_person(self, n): tuple - (id: int, name: str('First Last'), birthdate: str('DD/MM/YYYY'), sex: str('M' | 'F') ) """ for i in range(n): - sex = 'M' if random.random() > 0.5 else 'F' - dob = random_date(self.earliest_birthday, self.latest_birthday).strftime("%Y/%m/%d") - first_name = random.choice(self.all_male_first_names) if sex == 'M' else random.choice(self.all_female_first_names) + sex = "M" if random.random() > 0.5 else "F" + dob = random_date(self.earliest_birthday, self.latest_birthday).strftime( + "%Y/%m/%d" + ) + first_name = random.choice( + self.all_male_first_names + ) if sex == "M" else random.choice( + self.all_female_first_names + ) last_name = random.choice(self.all_last_names) - yield ( - str(i), - first_name + ' ' + last_name, - dob, - sex - ) + yield (str(i), first_name + " " + last_name, dob, sex) def load_names(self): # type: () -> None @@ -128,9 +130,9 @@ def load_names(self): """ - self.all_male_first_names = load_csv_data('male-first-names.csv') - self.all_female_first_names = load_csv_data('female-first-names.csv') - self.all_last_names = load_csv_data('CSV_Database_of_Last_Names.csv') + self.all_male_first_names = load_csv_data("male-first-names.csv") + self.all_female_first_names = load_csv_data("female-first-names.csv") + self.all_last_names = load_csv_data("CSV_Database_of_Last_Names.csv") def generate_subsets(self, sz, overlap=0.8): """ @@ -146,7 +148,10 @@ def generate_subsets(self, sz, overlap=0.8): notoverlap = sz - int(math.floor(overlap * sz)) total_sz = sz + notoverlap if total_sz > len(self.names): - raise ValueError('Requested subset size and overlap demands more ' - + 'than the number of available names') + raise ValueError( + "Requested subset size and overlap demands more " + + "than the number of available names" + ) + sset = random.sample(self.names, total_sz) return sset[:sz], sset[notoverlap:] diff --git a/clkhash/stats.py b/clkhash/stats.py index cc8ecbf5..478d5beb 100644 --- a/clkhash/stats.py +++ b/clkhash/stats.py @@ -10,9 +10,7 @@ def __init__(self): self.n = 0 # type: int self.S = 0 # type: float - def update(self, - x # type: Sequence[Union[int, float]] - ): + def update(self, x): # type: Sequence[Union[int, float]] # type: (...) -> None """ updates the statistics with the given list of numbers @@ -24,15 +22,20 @@ def update(self, :return: nothing """ if any(math.isnan(float(i)) or math.isinf(float(i)) for i in x): - raise ValueError('input contains non-finite numbers like "nan" or "+/- inf"') + raise ValueError( + 'input contains non-finite numbers like "nan" or "+/- inf"' + ) + t = sum(x) m = float(len(x)) norm_t = t / m - S = sum((xi - norm_t)**2 for xi in x) + S = sum((xi - norm_t) ** 2 for xi in x) if self.n == 0: self.S = self.S + S else: - self.S = self.S + S + self.n / (m * (m + self.n)) * (m/self.n * self.t - t)**2 + self.S = self.S + S + self.n / (m * (m + self.n)) * ( + m / self.n * self.t - t + ) ** 2 self.t = self.t + t self.n = self.n + len(x) @@ -45,6 +48,7 @@ def mean(self): """ if self.n == 0: return 0 + return self.t / float(self.n) def variance(self): @@ -56,6 +60,7 @@ def variance(self): """ if self.n <= 1: return 0 + return self.S / (self.n - 1.) 
def std(self): diff --git a/clkhash/tokenizer.py b/clkhash/tokenizer.py index 5a41f170..8efc7e7d 100644 --- a/clkhash/tokenizer.py +++ b/clkhash/tokenizer.py @@ -25,25 +25,23 @@ def tokenize(n, positional, word, ignore=None): :return: Tuple of n-gram strings. """ if n < 0: - raise ValueError('`n` in `n`-gram must be non-negative.') + raise ValueError("`n` in `n`-gram must be non-negative.") if ignore is not None: - word = word.replace(ignore, '') + word = word.replace(ignore, "") if n > 1: - word = ' {} '.format(word) + word = " {} ".format(word) if positional: # These are 1-indexed. - return ('{} {}'.format(i + 1, word[i:i+n]) - for i in range(len(word) - n + 1)) - else: - return (word[i:i+n] for i in range(len(word) - n + 1)) + return ("{} {}".format(i + 1, word[i:i + n]) for i in range(len(word) - n + 1)) + else: + return (word[i:i + n] for i in range(len(word) - n + 1)) -def get_tokenizer(hash_settings # type: field_formats.FieldHashingProperties - ): +def get_tokenizer( + hash_settings, # type: field_formats.FieldHashingProperties +): # type: (...) -> Callable[[Text], Iterable[Text]] """ Get tokeniser function from the hash settings. diff --git a/clkhash/validate_data.py b/clkhash/validate_data.py index 05a81f5a..6646dcfa 100644 --- a/clkhash/validate_data.py +++ b/clkhash/validate_data.py @@ -23,9 +23,9 @@ class FormatError(ValueError): """ -def validate_data(fields, # type: Sequence[FieldSpec] - data # type: Sequence[Sequence[str]] - ): +def validate_data( + fields, # type: Sequence[FieldSpec] + data, # type: Sequence[Sequence[str]] +): # type: (...) -> None """ Validate the `data` entries according to the specification in `fields`. @@ -42,20 +42,21 @@ def validate_data(fields, # type: Sequence[FieldSpec] for row in data: if len(validators) != len(row): - msg = 'Row has {} entries when {} are expected.'.format( - len(row), len(validators)) + msg = "Row has {} entries when {} are expected.".format( + len(row), len(validators) + ) raise FormatError(msg) for entry, v in zip(row, validators): try: v(entry) except InvalidEntryError as e: - raise_from(EntryError('Invalid entry.'), e) + raise_from(EntryError("Invalid entry."), e) -def validate_header(fields, # type: Sequence[FieldSpec] - column_names # type: Sequence[str] - ): +def validate_header( + fields, # type: Sequence[FieldSpec] + column_names, # type: Sequence[str] +): # type: (...) -> None """ Validate the `column_names` according to the specification in `fields`. @@ -67,12 +68,14 @@ def validate_header(fields, # type: Sequence[FieldSpec] identifiers don't match the specification. """ if len(fields) != len(column_names): - msg = 'Header has {} columns when {} are expected'.format( - len(column_names), len(fields)) + msg = "Header has {} columns when {} are expected".format( + len(column_names), len(fields) + ) raise FormatError(msg) for f, column in zip(fields, column_names): if f.identifier != column: msg = "Column has identifier '{}' when '{}' is expected".format( - column, f.identifier) + column, f.identifier + ) raise FormatError(msg)
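Aside (editor's illustration, not part of the diff): the tokenizer reformatted above pads the word with spaces when n > 1 and, in positional mode, prefixes each n-gram with its 1-based index. The function below is copied from the new side of the clkhash/tokenizer.py hunk; the sample input "ann" is illustrative only.

    def tokenize(n, positional, word, ignore=None):
        # Produce the n-grams of `word`, optionally tagged with their 1-based position.
        if n < 0:
            raise ValueError("`n` in `n`-gram must be non-negative.")
        if ignore is not None:
            word = word.replace(ignore, "")
        if n > 1:
            word = " {} ".format(word)  # pad so edge characters appear in n-grams
        if positional:
            # These are 1-indexed.
            return ("{} {}".format(i + 1, word[i:i + n]) for i in range(len(word) - n + 1))
        else:
            return (word[i:i + n] for i in range(len(word) - n + 1))

    print(list(tokenize(2, True, "ann")))   # ['1  a', '2 an', '3 nn', '4 n ']
    print(list(tokenize(2, False, "ann")))  # [' a', 'an', 'nn', 'n ']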