diff --git a/phe/util.py b/phe/util.py index a2a6ae8..19b346b 100644 --- a/phe/util.py +++ b/phe/util.py @@ -14,9 +14,9 @@ # along with pyphe. If not, see . import os -import sys import random -import base64 +from base64 import urlsafe_b64encode, urlsafe_b64decode +from binascii import hexlify, unhexlify try: import gmpy2 @@ -85,42 +85,30 @@ def getprimeover(N): raise NotImplementedError("No pure python implementation sorry") -# b64 utils from https://github.com/GehirnInc/python-jwt/blob/master/jwt/utils.py -if sys.version_info[0] == 3: - ord = lambda i: i +# base64 utils from jwcrypto +def base64url_encode(payload): + if not isinstance(payload, bytes): + payload = payload.encode('utf-8') + encode = urlsafe_b64encode(payload) + return encode.decode('utf-8').rstrip('=') -def b64_encode(source): - if not isinstance(source, bytes): - source = source.encode('ascii') - encoded = base64.urlsafe_b64encode(source).replace(b'=', b'') - return str(encoded.decode('ascii')) - - -def b64_decode(source): - if not isinstance(source, bytes): - source = source.encode('ascii') - - source += b'=' * (4 - (len(source) % 4)) - return base64.urlsafe_b64decode(source) +def base64url_decode(payload): + l = len(payload) % 4 + if l == 2: + payload += '==' + elif l == 3: + payload += '=' + elif l != 0: + raise ValueError('Invalid base64 string') + return urlsafe_b64decode(payload.encode('utf-8')) def base64_to_int(source): - if not isinstance(source, bytes): - source = source.encode('ascii') - - result = 0 - for b in b64_decode(source): - result = (result << 8) + ord(b) - - return result + int(hexlify(base64url_decode(source)), 16) def int_to_base64(source): - result_reversed = [] - while source: - source, remainder = divmod(source, 256) - result_reversed.append(remainder) - - return b64_encode(bytes(bytearray(reversed(result_reversed)))) + I = hex(source).rstrip("L").lstrip("0x") + return base64url_encode(unhexlify((len(I) % 2) * '0' + I))