test
", "input_description": "test", - "output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low", - "visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {}, - "samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C", - "spj_code": "", "spj_compile_ok": True, "test_case_id": "499b26290cc7994e0b497212e842ea85", - "test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0, - "stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e", - "input_size": 0, "score": 0}], - "io_mode": {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"}, - "share_submission": False, - "rule_type": "ACM", "hint": "test
", "source": "test"} +DEFAULT_PROBLEM_DATA = { + '_id': 'A-110', + 'title': 'test', + 'description': 'test
', + 'input_description': 'test', + 'output_description': 'test', + 'time_limit': 1000, + 'memory_limit': 256, + 'difficulty': 'Low', + 'visible': True, + 'tags': ['test'], + 'languages': ['C', 'C++', 'Java', 'Python2'], + 'template': {}, + 'samples': [{'input': 'test', 'output': 'test'}], + 'spj': False, + 'spj_language': 'C', + 'spj_code': '', + 'spj_compile_ok': True, + 'test_case_id': '499b26290cc7994e0b497212e842ea85', + 'test_case_score': [ + { + 'output_name': '1.out', + 'input_name': '1.in', + 'output_size': 0, + 'stripped_output_md5': 'd41d8cd98f00b204e9800998ecf8427e', + 'input_size': 0, + 'score': 0, + } + ], + 'io_mode': { + 'io_mode': ProblemIOMode.standard, + 'input': 'input.txt', + 'output': 'output.txt', + }, + 'share_submission': False, + 'rule_type': 'ACM', + 'hint': 'test
', + 'source': 'test', +} class ProblemCreateTestBase(APITestCase): @staticmethod def add_problem(problem_data, created_by): data = copy.deepcopy(problem_data) - if data["spj"]: - if not data["spj_language"] or not data["spj_code"]: - raise ValueError("Invalid spj") - data["spj_version"] = hashlib.md5( - (data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest() + if data['spj']: + if not data['spj_language'] or not data['spj_code']: + raise ValueError('Invalid spj') + data['spj_version'] = hashlib.md5( + (data['spj_language'] + ':' + data['spj_code']).encode('utf-8') + ).hexdigest() else: - data["spj_language"] = None - data["spj_code"] = None - if data["rule_type"] == ProblemRuleType.OI: + data['spj_language'] = None + data['spj_code'] = None + if data['rule_type'] == ProblemRuleType.OI: total_score = 0 - for item in data["test_case_score"]: - if item["score"] <= 0: - raise ValueError("invalid score") + for item in data['test_case_score']: + if item['score'] <= 0: + raise ValueError('invalid score') else: - total_score += item["score"] - data["total_score"] = total_score - data["created_by"] = created_by - tags = data.pop("tags") + total_score += item['score'] + data['total_score'] = total_score + data['created_by'] = created_by + tags = data.pop('tags') - data["languages"] = list(data["languages"]) + data['languages'] = list(data['languages']) problem = Problem.objects.create(**data) @@ -68,72 +97,85 @@ def add_problem(problem_data, created_by): class ProblemTagListAPITest(APITestCase): def test_get_tag_list(self): - ProblemTag.objects.create(name="name1") - ProblemTag.objects.create(name="name2") - resp = self.client.get(self.reverse("problem_tag_list_api")) + ProblemTag.objects.create(name='name1') + ProblemTag.objects.create(name='name2') + resp = self.client.get(self.reverse('problem_tag_list_api')) self.assertSuccess(resp) class TestCaseUploadAPITest(APITestCase): def setUp(self): self.api = TestCaseAPI() - self.url = self.reverse("test_case_api") + self.url = self.reverse('test_case_api') self.create_super_admin() def test_filter_file_name(self): - self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False), - ["1.in", "1.out"]) - self.assertEqual(self.api.filter_name_list(["2.in", "2.out"], spj=False), []) - - self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in"], spj=True), ["1.in", "2.in"]) - self.assertEqual(self.api.filter_name_list(["2.in", "3.in"], spj=True), []) + self.assertEqual( + self.api.filter_name_list( + ['1.in', '1.out', '2.in', '.DS_Store'], spj=False + ), + ['1.in', '1.out'], + ) + self.assertEqual(self.api.filter_name_list(['2.in', '2.out'], spj=False), []) + + self.assertEqual( + self.api.filter_name_list(['1.in', '1.out', '2.in'], spj=True), + ['1.in', '2.in'], + ) + self.assertEqual(self.api.filter_name_list(['2.in', '3.in'], spj=True), []) def make_test_case_zip(self): - base_dir = os.path.join("/tmp", "test_case") + base_dir = os.path.join('/tmp', 'test_case') shutil.rmtree(base_dir, ignore_errors=True) os.mkdir(base_dir) - file_names = ["1.in", "1.out", "2.in", ".DS_Store"] + file_names = ['1.in', '1.out', '2.in', '.DS_Store'] for item in file_names: - with open(os.path.join(base_dir, item), "w", encoding="utf-8") as f: - f.write(item + "\n" + item + "\r\n" + "end") - zip_file = os.path.join(base_dir, "test_case.zip") - with ZipFile(os.path.join(base_dir, "test_case.zip"), "w") as f: + with open(os.path.join(base_dir, item), 'w', encoding='utf-8') as f: + f.write(item + '\n' + item + '\r\n' + 'end') + zip_file = os.path.join(base_dir, 'test_case.zip') + with ZipFile(os.path.join(base_dir, 'test_case.zip'), 'w') as f: for item in file_names: f.write(os.path.join(base_dir, item), item) return zip_file def test_upload_spj_test_case_zip(self): - with open(self.make_test_case_zip(), "rb") as f: - resp = self.client.post(self.url, - data={"spj": "true", "file": f}, format="multipart") + with open(self.make_test_case_zip(), 'rb') as f: + resp = self.client.post( + self.url, data={'spj': 'true', 'file': f}, format='multipart' + ) self.assertSuccess(resp) - data = resp.data["data"] - self.assertEqual(data["spj"], True) - test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"]) + data = resp.data['data'] + self.assertEqual(data['spj'], True) + test_case_dir = os.path.join(settings.TEST_CASE_DIR, data['id']) self.assertTrue(os.path.exists(test_case_dir)) - for item in data["info"]: - name = item["input_name"] - with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f: - self.assertEqual(f.read(), name + "\n" + name + "\n" + "end") + for item in data['info']: + name = item['input_name'] + with open( + os.path.join(test_case_dir, name), 'r', encoding='utf-8' + ) as f: + self.assertEqual(f.read(), name + '\n' + name + '\n' + 'end') def test_upload_test_case_zip(self): - with open(self.make_test_case_zip(), "rb") as f: - resp = self.client.post(self.url, - data={"spj": "false", "file": f}, format="multipart") + with open(self.make_test_case_zip(), 'rb') as f: + resp = self.client.post( + self.url, data={'spj': 'false', 'file': f}, format='multipart' + ) self.assertSuccess(resp) - data = resp.data["data"] - self.assertEqual(data["spj"], False) - test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"]) + data = resp.data['data'] + self.assertEqual(data['spj'], False) + test_case_dir = os.path.join(settings.TEST_CASE_DIR, data['id']) self.assertTrue(os.path.exists(test_case_dir)) - for item in data["info"]: - name = item["input_name"] - with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f: - self.assertEqual(f.read(), name + "\n" + name + "\n" + "end") + for item in data['info']: + name = item['input_name'] + with open( + os.path.join(test_case_dir, name), 'r', encoding='utf-8' + ) as f: + self.assertEqual(f.read(), name + '\n' + name + '\n' + 'end') class ProblemAdminAPITest(APITestCase): def setUp(self): - self.url = self.reverse("problem_admin_api") + self.url = self.reverse('problem_admin_api') self.create_super_admin() self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA) @@ -146,16 +188,16 @@ def test_duplicate_display_id(self): self.test_create_problem() resp = self.client.post(self.url, data=self.data) - self.assertFailed(resp, "Display ID already exists") + self.assertFailed(resp, 'Display ID already exists') def test_spj(self): data = copy.deepcopy(self.data) - data["spj"] = True + data['spj'] = True resp = self.client.post(self.url, data) - self.assertFailed(resp, "Invalid spj") + self.assertFailed(resp, 'Invalid spj') - data["spj_code"] = "test" + data['spj_code'] = 'test' resp = self.client.post(self.url, data=data) self.assertSuccess(resp) @@ -165,122 +207,128 @@ def test_get_problem(self): self.assertSuccess(resp) def test_get_one_problem(self): - problem_id = self.test_create_problem().data["data"]["id"] - resp = self.client.get(self.url + "?id=" + str(problem_id)) + problem_id = self.test_create_problem().data['data']['id'] + resp = self.client.get(self.url + '?id=' + str(problem_id)) self.assertSuccess(resp) def test_edit_problem(self): - problem_id = self.test_create_problem().data["data"]["id"] + problem_id = self.test_create_problem().data['data']['id'] data = copy.deepcopy(self.data) - data["id"] = problem_id + data['id'] = problem_id resp = self.client.put(self.url, data=data) self.assertSuccess(resp) class ProblemAPITest(ProblemCreateTestBase): def setUp(self): - self.url = self.reverse("problem_api") + self.url = self.reverse('problem_api') admin = self.create_admin(login=False) self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) - self.create_user("test", "test123") + self.create_user('test', 'test123') def test_get_problem_list(self): - resp = self.client.get(f"{self.url}?limit=10") + resp = self.client.get(f'{self.url}?limit=10') self.assertSuccess(resp) def get_one_problem(self): - resp = self.client.get(self.url + "?id=" + self.problem._id) + resp = self.client.get(self.url + '?id=' + self.problem._id) self.assertSuccess(resp) class ContestProblemAdminTest(APITestCase): def setUp(self): - self.url = self.reverse("contest_problem_admin_api") + self.url = self.reverse('contest_problem_admin_api') self.create_admin() - self.contest = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"] + self.contest = self.client.post( + self.reverse('contest_admin_api'), data=DEFAULT_CONTEST_DATA + ).data['data'] def test_create_contest_problem(self): data = copy.deepcopy(DEFAULT_PROBLEM_DATA) - data["contest_id"] = self.contest["id"] + data['contest_id'] = self.contest['id'] resp = self.client.post(self.url, data=data) self.assertSuccess(resp) - return resp.data["data"] + return resp.data['data'] def test_get_contest_problem(self): self.test_create_contest_problem() - contest_id = self.contest["id"] - resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) + contest_id = self.contest['id'] + resp = self.client.get(self.url + '?contest_id=' + str(contest_id)) self.assertSuccess(resp) - self.assertEqual(len(resp.data["data"]["results"]), 1) + self.assertEqual(len(resp.data['data']['results']), 1) def test_get_one_contest_problem(self): contest_problem = self.test_create_contest_problem() - contest_id = self.contest["id"] - problem_id = contest_problem["id"] - resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}") + contest_id = self.contest['id'] + problem_id = contest_problem['id'] + resp = self.client.get(f'{self.url}?contest_id={contest_id}&id={problem_id}') self.assertSuccess(resp) class ContestProblemTest(ProblemCreateTestBase): def setUp(self): admin = self.create_admin() - url = self.reverse("contest_admin_api") + url = self.reverse('contest_admin_api') contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA) - contest_data["password"] = "" - contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1) - self.contest = self.client.post(url, data=contest_data).data["data"] + contest_data['password'] = '' + contest_data['start_time'] = contest_data['start_time'] + timedelta(hours=1) + self.contest = self.client.post(url, data=contest_data).data['data'] self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) - self.problem.contest_id = self.contest["id"] + self.problem.contest_id = self.contest['id'] self.problem.save() - self.url = self.reverse("contest_problem_api") + self.url = self.reverse('contest_problem_api') def test_admin_get_contest_problem_list(self): - contest_id = self.contest["id"] - resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) + contest_id = self.contest['id'] + resp = self.client.get(self.url + '?contest_id=' + str(contest_id)) self.assertSuccess(resp) - self.assertEqual(len(resp.data["data"]), 1) + self.assertEqual(len(resp.data['data']), 1) def test_admin_get_one_contest_problem(self): - contest_id = self.contest["id"] + contest_id = self.contest['id'] problem_id = self.problem._id - resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id)) + resp = self.client.get( + '{}?contest_id={}&problem_id={}'.format(self.url, contest_id, problem_id) + ) self.assertSuccess(resp) def test_regular_user_get_not_started_contest_problem(self): - self.create_user("test", "test123") - resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"])) - self.assertDictEqual(resp.data, {"error": "error", "data": "Contest has not started yet."}) + self.create_user('test', 'test123') + resp = self.client.get(self.url + '?contest_id=' + str(self.contest['id'])) + self.assertDictEqual( + resp.data, {'error': 'error', 'data': 'Contest has not started yet.'} + ) def test_reguar_user_get_started_contest_problem(self): - self.create_user("test", "test123") + self.create_user('test', 'test123') contest = Contest.objects.first() contest.start_time = contest.start_time - timedelta(hours=1) contest.save() - resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"])) + resp = self.client.get(self.url + '?contest_id=' + str(self.contest['id'])) self.assertSuccess(resp) class AddProblemFromPublicProblemAPITest(ProblemCreateTestBase): def setUp(self): admin = self.create_admin() - url = self.reverse("contest_admin_api") + url = self.reverse('contest_admin_api') contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA) - contest_data["password"] = "" - contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1) - self.contest = self.client.post(url, data=contest_data).data["data"] + contest_data['password'] = '' + contest_data['start_time'] = contest_data['start_time'] + timedelta(hours=1) + self.contest = self.client.post(url, data=contest_data).data['data'] self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) - self.url = self.reverse("add_contest_problem_from_public_api") + self.url = self.reverse('add_contest_problem_from_public_api') self.data = { - "display_id": "1000", - "contest_id": self.contest["id"], - "problem_id": self.problem.id + 'display_id': '1000', + 'contest_id': self.contest['id'], + 'problem_id': self.problem.id, } def test_add_contest_problem(self): resp = self.client.post(self.url, data=self.data) self.assertSuccess(resp) self.assertTrue(Problem.objects.all().exists()) - self.assertTrue(Problem.objects.filter(contest_id=self.contest["id"]).exists()) + self.assertTrue(Problem.objects.filter(contest_id=self.contest['id']).exists()) class ParseProblemTemplateTest(APITestCase): @@ -300,9 +348,9 @@ def test_parse(self): """ ret = parse_problem_template(template_str) - self.assertEqual(ret["prepend"], "aaa\n") - self.assertEqual(ret["template"], "bbb\n") - self.assertEqual(ret["append"], "ccc\n") + self.assertEqual(ret['prepend'], 'aaa\n') + self.assertEqual(ret['template'], 'bbb\n') + self.assertEqual(ret['append'], 'ccc\n') def test_parse1(self): template_str = """ @@ -319,6 +367,6 @@ def test_parse1(self): """ ret = parse_problem_template(template_str) - self.assertEqual(ret["prepend"], "aaa\n") - self.assertEqual(ret["template"], "") - self.assertEqual(ret["append"], "ccc\n") + self.assertEqual(ret['prepend'], 'aaa\n') + self.assertEqual(ret['template'], '') + self.assertEqual(ret['append'], 'ccc\n') diff --git a/problem/urls/admin.py b/problem/urls/admin.py index e3a921fca..7e0552c7c 100644 --- a/problem/urls/admin.py +++ b/problem/urls/admin.py @@ -1,17 +1,37 @@ from django.conf.urls import url -from ..views.admin import (ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView, - CompileSPJAPI, AddContestProblemAPI, ExportProblemAPI, ImportProblemAPI, - FPSProblemImport) +from ..views.admin import ( + ContestProblemAPI, + ProblemAPI, + TestCaseAPI, + MakeContestProblemPublicAPIView, + CompileSPJAPI, + AddContestProblemAPI, + ExportProblemAPI, + ImportProblemAPI, + FPSProblemImport, +) urlpatterns = [ - url(r"^test_case/?$", TestCaseAPI.as_view(), name="test_case_api"), - url(r"^compile_spj/?$", CompileSPJAPI.as_view(), name="compile_spj"), - url(r"^problem/?$", ProblemAPI.as_view(), name="problem_admin_api"), - url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_admin_api"), - url(r"^contest_problem/make_public/?$", MakeContestProblemPublicAPIView.as_view(), name="make_public_api"), - url(r"^contest/add_problem_from_public/?$", AddContestProblemAPI.as_view(), name="add_contest_problem_from_public_api"), - url(r"^export_problem/?$", ExportProblemAPI.as_view(), name="export_problem_api"), - url(r"^import_problem/?$", ImportProblemAPI.as_view(), name="import_problem_api"), - url(r"^import_fps/?$", FPSProblemImport.as_view(), name="fps_problem_api"), + url(r'^test_case/?$', TestCaseAPI.as_view(), name='test_case_api'), + url(r'^compile_spj/?$', CompileSPJAPI.as_view(), name='compile_spj'), + url(r'^problem/?$', ProblemAPI.as_view(), name='problem_admin_api'), + url( + r'^contest/problem/?$', + ContestProblemAPI.as_view(), + name='contest_problem_admin_api', + ), + url( + r'^contest_problem/make_public/?$', + MakeContestProblemPublicAPIView.as_view(), + name='make_public_api', + ), + url( + r'^contest/add_problem_from_public/?$', + AddContestProblemAPI.as_view(), + name='add_contest_problem_from_public_api', + ), + url(r'^export_problem/?$', ExportProblemAPI.as_view(), name='export_problem_api'), + url(r'^import_problem/?$', ImportProblemAPI.as_view(), name='import_problem_api'), + url(r'^import_fps/?$', FPSProblemImport.as_view(), name='fps_problem_api'), ] diff --git a/problem/urls/oj.py b/problem/urls/oj.py index f7cd3ae3a..bae1f7cf6 100644 --- a/problem/urls/oj.py +++ b/problem/urls/oj.py @@ -3,8 +3,10 @@ from ..views.oj import ProblemTagAPI, ProblemAPI, ContestProblemAPI, PickOneAPI urlpatterns = [ - url(r"^problem/tags/?$", ProblemTagAPI.as_view(), name="problem_tag_list_api"), - url(r"^problem/?$", ProblemAPI.as_view(), name="problem_api"), - url(r"^pickone/?$", PickOneAPI.as_view(), name="pick_one_api"), - url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_api"), + url(r'^problem/tags/?$', ProblemTagAPI.as_view(), name='problem_tag_list_api'), + url(r'^problem/?$', ProblemAPI.as_view(), name='problem_api'), + url(r'^pickone/?$', PickOneAPI.as_view(), name='pick_one_api'), + url( + r'^contest/problem/?$', ContestProblemAPI.as_view(), name='contest_problem_api' + ), ] diff --git a/problem/utils.py b/problem/utils.py index c53026394..3ac0bb33d 100644 --- a/problem/utils.py +++ b/problem/utils.py @@ -17,12 +17,14 @@ @lru_cache(maxsize=100) def parse_problem_template(template_str): - prepend = re.findall(r"//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str) - template = re.findall(r"//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END", template_str) - append = re.findall(r"//APPEND BEGIN\n([\s\S]+?)//APPEND END", template_str) - return {"prepend": prepend[0] if prepend else "", - "template": template[0] if template else "", - "append": append[0] if append else ""} + prepend = re.findall(r'//PREPEND BEGIN\n([\s\S]+?)//PREPEND END', template_str) + template = re.findall(r'//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END', template_str) + append = re.findall(r'//APPEND BEGIN\n([\s\S]+?)//APPEND END', template_str) + return { + 'prepend': prepend[0] if prepend else '', + 'template': template[0] if template else '', + 'append': append[0] if append else '', + } @lru_cache(maxsize=100) diff --git a/problem/views/admin.py b/problem/views/admin.py index 5ce9413b2..7f3988338 100644 --- a/problem/views/admin.py +++ b/problem/views/admin.py @@ -1,6 +1,7 @@ import hashlib import json import os + # import shutil import tempfile import zipfile @@ -22,25 +23,35 @@ from utils.shortcuts import rand_str, natural_sort_key from utils.tasks import delete_files from ..models import Problem, ProblemRuleType, ProblemTag -from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer, - CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer, - ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer, - AddContestProblemSerializer, ExportProblemSerializer, - ExportProblemRequestSerialzier, UploadProblemForm, ImportProblemSerializer, - FPSProblemSerializer) +from ..serializers import ( + CreateContestProblemSerializer, + CompileSPJSerializer, + CreateProblemSerializer, + EditProblemSerializer, + EditContestProblemSerializer, + ProblemAdminSerializer, + TestCaseUploadForm, + ContestProblemMakePublicSerializer, + AddContestProblemSerializer, + ExportProblemSerializer, + ExportProblemRequestSerialzier, + UploadProblemForm, + ImportProblemSerializer, + FPSProblemSerializer, +) from ..utils import TEMPLATE_BASE, build_problem_template class TestCaseZipProcessor(object): - def process_zip(self, uploaded_zip_file, spj, dir=""): + def process_zip(self, uploaded_zip_file, spj, dir=''): try: - zip_file = zipfile.ZipFile(uploaded_zip_file, "r") + zip_file = zipfile.ZipFile(uploaded_zip_file, 'r') except zipfile.BadZipFile: - raise APIError("Bad zip file") + raise APIError('Bad zip file') name_list = zip_file.namelist() test_case_list = self.filter_name_list(name_list, spj=spj, dir=dir) if not test_case_list: - raise APIError("Empty file") + raise APIError('Empty file') test_case_id = rand_str() test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id) @@ -51,34 +62,36 @@ def process_zip(self, uploaded_zip_file, spj, dir=""): md5_cache = {} for item in test_case_list: - with open(os.path.join(test_case_dir, item), "wb") as f: - content = zip_file.read(f"{dir}{item}").replace(b"\r\n", b"\n") + with open(os.path.join(test_case_dir, item), 'wb') as f: + content = zip_file.read(f'{dir}{item}').replace(b'\r\n', b'\n') size_cache[item] = len(content) - if item.endswith(".out"): + if item.endswith('.out'): md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest() f.write(content) - test_case_info = {"spj": spj, "test_cases": {}} + test_case_info = {'spj': spj, 'test_cases': {}} info = [] if spj: for index, item in enumerate(test_case_list): - data = {"input_name": item, "input_size": size_cache[item]} + data = {'input_name': item, 'input_size': size_cache[item]} info.append(data) - test_case_info["test_cases"][str(index + 1)] = data + test_case_info['test_cases'][str(index + 1)] = data else: # ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")] test_case_list = zip(*[test_case_list[i::2] for i in range(2)]) for index, item in enumerate(test_case_list): - data = {"stripped_output_md5": md5_cache[item[1]], - "input_size": size_cache[item[0]], - "output_size": size_cache[item[1]], - "input_name": item[0], - "output_name": item[1]} + data = { + 'stripped_output_md5': md5_cache[item[1]], + 'input_size': size_cache[item[0]], + 'output_size': size_cache[item[1]], + 'input_name': item[0], + 'output_name': item[1], + } info.append(data) - test_case_info["test_cases"][str(index + 1)] = data + test_case_info['test_cases'][str(index + 1)] = data - with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f: + with open(os.path.join(test_case_dir, 'info'), 'w', encoding='utf-8') as f: f.write(json.dumps(test_case_info, indent=4)) for item in os.listdir(test_case_dir): @@ -86,13 +99,13 @@ def process_zip(self, uploaded_zip_file, spj, dir=""): return info, test_case_id - def filter_name_list(self, name_list, spj, dir=""): + def filter_name_list(self, name_list, spj, dir=''): ret = [] prefix = 1 if spj: while True: - in_name = f"{prefix}.in" - if f"{dir}{in_name}" in name_list: + in_name = f'{prefix}.in' + if f'{dir}{in_name}' in name_list: ret.append(in_name) prefix += 1 continue @@ -100,9 +113,9 @@ def filter_name_list(self, name_list, spj, dir=""): return sorted(ret, key=natural_sort_key) else: while True: - in_name = f"{prefix}.in" - out_name = f"{prefix}.out" - if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list: + in_name = f'{prefix}.in' + out_name = f'{prefix}.out' + if f'{dir}{in_name}' in name_list and f'{dir}{out_name}' in name_list: ret.append(in_name) ret.append(out_name) prefix += 1 @@ -115,13 +128,13 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor): request_parsers = () def get(self, request): - problem_id = request.GET.get("problem_id") + problem_id = request.GET.get('problem_id') if not problem_id: - return self.error("Parameter error, problem_id is required") + return self.error('Parameter error, problem_id is required') try: problem = Problem.objects.get(id=problem_id) except Problem.DoesNotExist: - return self.error("Problem does not exists") + return self.error('Problem does not exists') if problem.contest: ensure_created_by(problem.contest, request.user) @@ -130,34 +143,37 @@ def get(self, request): test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) if not os.path.isdir(test_case_dir): - return self.error("Test case does not exists") + return self.error('Test case does not exists') name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj) - name_list.append("info") - file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip") - with zipfile.ZipFile(file_name, "w") as file: + name_list.append('info') + file_name = os.path.join(test_case_dir, problem.test_case_id + '.zip') + with zipfile.ZipFile(file_name, 'w') as file: for test_case in name_list: - file.write(f"{test_case_dir}/{test_case}", test_case) - response = StreamingHttpResponse(FileWrapper(open(file_name, "rb")), - content_type="application/octet-stream") - - response["Content-Disposition"] = f"attachment; filename=problem_{problem.id}_test_cases.zip" - response["Content-Length"] = os.path.getsize(file_name) + file.write(f'{test_case_dir}/{test_case}', test_case) + response = StreamingHttpResponse( + FileWrapper(open(file_name, 'rb')), content_type='application/octet-stream' + ) + + response[ + 'Content-Disposition' + ] = f'attachment; filename=problem_{problem.id}_test_cases.zip' + response['Content-Length'] = os.path.getsize(file_name) return response def post(self, request): form = TestCaseUploadForm(request.POST, request.FILES) if form.is_valid(): - spj = form.cleaned_data["spj"] == "true" - file = form.cleaned_data["file"] + spj = form.cleaned_data['spj'] == 'true' + file = form.cleaned_data['file'] else: - return self.error("Upload failed") - zip_file = f"/tmp/{rand_str()}.zip" - with open(zip_file, "wb") as f: + return self.error('Upload failed') + zip_file = f'/tmp/{rand_str()}.zip' + with open(zip_file, 'wb') as f: for chunk in file: f.write(chunk) info, test_case_id = self.process_zip(zip_file, spj=spj) os.remove(zip_file) - return self.success({"id": test_case_id, "info": info, "spj": spj}) + return self.success({'id': test_case_id, 'info': info, 'spj': spj}) class CompileSPJAPI(APIView): @@ -165,7 +181,9 @@ class CompileSPJAPI(APIView): def post(self, request): data = request.data spj_version = rand_str(8) - error = SPJCompiler(data["spj_code"], spj_version, data["spj_language"]).compile_spj() + error = SPJCompiler( + data['spj_code'], spj_version, data['spj_language'] + ).compile_spj() if error: return self.error(error) else: @@ -175,25 +193,26 @@ def post(self, request): class ProblemBase(APIView): def common_checks(self, request): data = request.data - if data["spj"]: - if not data["spj_language"] or not data["spj_code"]: - return "Invalid spj" - if not data["spj_compile_ok"]: - return "SPJ code must be compiled successfully" - data["spj_version"] = hashlib.md5( - (data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest() + if data['spj']: + if not data['spj_language'] or not data['spj_code']: + return 'Invalid spj' + if not data['spj_compile_ok']: + return 'SPJ code must be compiled successfully' + data['spj_version'] = hashlib.md5( + (data['spj_language'] + ':' + data['spj_code']).encode('utf-8') + ).hexdigest() else: - data["spj_language"] = None - data["spj_code"] = None - if data["rule_type"] == ProblemRuleType.OI: + data['spj_language'] = None + data['spj_code'] = None + if data['rule_type'] == ProblemRuleType.OI: total_score = 0 - for item in data["test_case_score"]: - if item["score"] <= 0: - return "Invalid score" + for item in data['test_case_score']: + if item['score'] <= 0: + return 'Invalid score' else: - total_score += item["score"] - data["total_score"] = total_score - data["languages"] = list(data["languages"]) + total_score += item['score'] + data['total_score'] = total_score + data['languages'] = list(data['languages']) class ProblemAPI(ProblemBase): @@ -201,19 +220,19 @@ class ProblemAPI(ProblemBase): @validate_serializer(CreateProblemSerializer) def post(self, request): data = request.data - _id = data["_id"] + _id = data['_id'] if not _id: - return self.error("Display ID is required") + return self.error('Display ID is required') if Problem.objects.filter(_id=_id, contest_id__isnull=True).exists(): - return self.error("Display ID already exists") + return self.error('Display ID already exists') error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info - tags = data.pop("tags") - data["created_by"] = request.user + tags = data.pop('tags') + data['created_by'] = request.user problem = Problem.objects.create(**data) for item in tags: @@ -226,8 +245,8 @@ def post(self, request): @problem_permission_required def get(self, request): - problem_id = request.GET.get("id") - rule_type = request.GET.get("rule_type") + problem_id = request.GET.get('id') + rule_type = request.GET.get('rule_type') user = request.user if problem_id: try: @@ -235,46 +254,56 @@ def get(self, request): ensure_created_by(problem, request.user) return self.success(ProblemAdminSerializer(problem).data) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') - problems = Problem.objects.filter(contest_id__isnull=True).order_by("-create_time") + problems = Problem.objects.filter(contest_id__isnull=True).order_by( + '-create_time' + ) if rule_type: if rule_type not in ProblemRuleType.choices(): - return self.error("Invalid rule_type") + return self.error('Invalid rule_type') else: problems = problems.filter(rule_type=rule_type) - keyword = request.GET.get("keyword", "").strip() + keyword = request.GET.get('keyword', '').strip() if keyword: - problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword)) + problems = problems.filter( + Q(title__icontains=keyword) | Q(_id__icontains=keyword) + ) if not user.can_mgmt_all_problem(): problems = problems.filter(created_by=user) - return self.success(self.paginate_data(request, problems, ProblemAdminSerializer)) + return self.success( + self.paginate_data(request, problems, ProblemAdminSerializer) + ) @problem_permission_required @validate_serializer(EditProblemSerializer) def put(self, request): data = request.data - problem_id = data.pop("id") + problem_id = data.pop('id') try: problem = Problem.objects.get(id=problem_id) ensure_created_by(problem, request.user) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') - _id = data["_id"] + _id = data['_id'] if not _id: - return self.error("Display ID is required") - if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest_id__isnull=True).exists(): - return self.error("Display ID already exists") + return self.error('Display ID is required') + if ( + Problem.objects.exclude(id=problem_id) + .filter(_id=_id, contest_id__isnull=True) + .exists() + ): + return self.error('Display ID already exists') error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info - tags = data.pop("tags") - data["languages"] = list(data["languages"]) + tags = data.pop('tags') + data['languages'] = list(data['languages']) for k, v in data.items(): setattr(problem, k, v) @@ -292,13 +321,13 @@ def put(self, request): @problem_permission_required def delete(self, request): - id = request.GET.get("id") + id = request.GET.get('id') if not id: - return self.error("Invalid parameter, id is required") + return self.error('Invalid parameter, id is required') try: problem = Problem.objects.get(id=id, contest_id__isnull=True) except Problem.DoesNotExist: - return self.error("Problem does not exists") + return self.error('Problem does not exists') ensure_created_by(problem, request.user) # d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) # if os.path.isdir(d): @@ -312,29 +341,29 @@ class ContestProblemAPI(ProblemBase): def post(self, request): data = request.data try: - contest = Contest.objects.get(id=data.pop("contest_id")) + contest = Contest.objects.get(id=data.pop('contest_id')) ensure_created_by(contest, request.user) except Contest.DoesNotExist: - return self.error("Contest does not exist") + return self.error('Contest does not exist') - if data["rule_type"] != contest.rule_type: - return self.error("Invalid rule type") + if data['rule_type'] != contest.rule_type: + return self.error('Invalid rule type') - _id = data["_id"] + _id = data['_id'] if not _id: - return self.error("Display ID is required") + return self.error('Display ID is required') if Problem.objects.filter(_id=_id, contest=contest).exists(): - return self.error("Duplicate Display id") + return self.error('Duplicate Display id') error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info - data["contest"] = contest - tags = data.pop("tags") - data["created_by"] = request.user + data['contest'] = contest + tags = data.pop('tags') + data['created_by'] = request.user problem = Problem.objects.create(**data) for item in tags: @@ -346,31 +375,33 @@ def post(self, request): return self.success(ProblemAdminSerializer(problem).data) def get(self, request): - problem_id = request.GET.get("id") - contest_id = request.GET.get("contest_id") + problem_id = request.GET.get('id') + contest_id = request.GET.get('contest_id') user = request.user if problem_id: try: problem = Problem.objects.get(id=problem_id) ensure_created_by(problem.contest, user) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') return self.success(ProblemAdminSerializer(problem).data) if not contest_id: - return self.error("Contest id is required") + return self.error('Contest id is required') try: contest = Contest.objects.get(id=contest_id) ensure_created_by(contest, user) except Contest.DoesNotExist: - return self.error("Contest does not exist") - problems = Problem.objects.filter(contest=contest).order_by("-create_time") + return self.error('Contest does not exist') + problems = Problem.objects.filter(contest=contest).order_by('-create_time') if user.is_admin(): problems = problems.filter(contest__created_by=user) - keyword = request.GET.get("keyword") + keyword = request.GET.get('keyword') if keyword: problems = problems.filter(title__contains=keyword) - return self.success(self.paginate_data(request, problems, ProblemAdminSerializer)) + return self.success( + self.paginate_data(request, problems, ProblemAdminSerializer) + ) @validate_serializer(EditContestProblemSerializer) def put(self, request): @@ -378,33 +409,37 @@ def put(self, request): user = request.user try: - contest = Contest.objects.get(id=data.pop("contest_id")) + contest = Contest.objects.get(id=data.pop('contest_id')) ensure_created_by(contest, user) except Contest.DoesNotExist: - return self.error("Contest does not exist") + return self.error('Contest does not exist') - if data["rule_type"] != contest.rule_type: - return self.error("Invalid rule type") + if data['rule_type'] != contest.rule_type: + return self.error('Invalid rule type') - problem_id = data.pop("id") + problem_id = data.pop('id') try: problem = Problem.objects.get(id=problem_id, contest=contest) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') - _id = data["_id"] + _id = data['_id'] if not _id: - return self.error("Display ID is required") - if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest=contest).exists(): - return self.error("Display ID already exists") + return self.error('Display ID is required') + if ( + Problem.objects.exclude(id=problem_id) + .filter(_id=_id, contest=contest) + .exists() + ): + return self.error('Display ID already exists') error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info - tags = data.pop("tags") - data["languages"] = list(data["languages"]) + tags = data.pop('tags') + data['languages'] = list(data['languages']) for k, v in data.items(): setattr(problem, k, v) @@ -420,13 +455,13 @@ def put(self, request): return self.success() def delete(self, request): - id = request.GET.get("id") + id = request.GET.get('id') if not id: - return self.error("Invalid parameter, id is required") + return self.error('Invalid parameter, id is required') try: problem = Problem.objects.get(id=id, contest_id__isnull=False) except Problem.DoesNotExist: - return self.error("Problem does not exists") + return self.error('Problem does not exists') ensure_created_by(problem.contest, request.user) if Submission.objects.filter(problem=problem).exists(): return self.error("Can't delete the problem as it has submissions") @@ -442,17 +477,17 @@ class MakeContestProblemPublicAPIView(APIView): @problem_permission_required def post(self, request): data = request.data - display_id = data.get("display_id") + display_id = data.get('display_id') if Problem.objects.filter(_id=display_id, contest_id__isnull=True).exists(): - return self.error("Duplicate display ID") + return self.error('Duplicate display ID') try: - problem = Problem.objects.get(id=data["id"]) + problem = Problem.objects.get(id=data['id']) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') if not problem.contest or problem.is_public: - return self.error("Already be a public problem") + return self.error('Already be a public problem') problem.is_public = True problem.save() # https://docs.djangoproject.com/en/1.11/topics/db/queries/#copying-model-instances @@ -473,22 +508,22 @@ class AddContestProblemAPI(APIView): def post(self, request): data = request.data try: - contest = Contest.objects.get(id=data["contest_id"]) - problem = Problem.objects.get(id=data["problem_id"]) + contest = Contest.objects.get(id=data['contest_id']) + problem = Problem.objects.get(id=data['problem_id']) except (Contest.DoesNotExist, Problem.DoesNotExist): - return self.error("Contest or Problem does not exist") + return self.error('Contest or Problem does not exist') if contest.status == ContestStatus.CONTEST_ENDED: - return self.error("Contest has ended") - if Problem.objects.filter(contest=contest, _id=data["display_id"]).exists(): - return self.error("Duplicate display id in this contest") + return self.error('Contest has ended') + if Problem.objects.filter(contest=contest, _id=data['display_id']).exists(): + return self.error('Duplicate display id in this contest') tags = problem.tags.all() problem.pk = None problem.contest = contest problem.is_public = True problem.visible = True - problem._id = request.data["display_id"] + problem._id = request.data['display_id'] problem.submission_number = problem.accepted_number = 0 problem.statistic_info = {} problem.save() @@ -500,49 +535,68 @@ class ExportProblemAPI(APIView): def choose_answers(self, user, problem): ret = [] for item in problem.languages: - submission = Submission.objects.filter(problem=problem, - user_id=user.id, - language=item, - result=JudgeStatus.ACCEPTED).order_by("-create_time").first() + submission = ( + Submission.objects.filter( + problem=problem, + user_id=user.id, + language=item, + result=JudgeStatus.ACCEPTED, + ) + .order_by('-create_time') + .first() + ) if submission: - ret.append({"language": submission.language, "code": submission.code}) + ret.append({'language': submission.language, 'code': submission.code}) return ret def process_one_problem(self, zip_file, user, problem, index): info = ExportProblemSerializer(problem).data - info["answers"] = self.choose_answers(user, problem=problem) + info['answers'] = self.choose_answers(user, problem=problem) compression = zipfile.ZIP_DEFLATED - zip_file.writestr(zinfo_or_arcname=f"{index}/problem.json", - data=json.dumps(info, indent=4), - compress_type=compression) - problem_test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) - with open(os.path.join(problem_test_case_dir, "info")) as f: + zip_file.writestr( + zinfo_or_arcname=f'{index}/problem.json', + data=json.dumps(info, indent=4), + compress_type=compression, + ) + problem_test_case_dir = os.path.join( + settings.TEST_CASE_DIR, problem.test_case_id + ) + with open(os.path.join(problem_test_case_dir, 'info')) as f: info = json.load(f) - for k, v in info["test_cases"].items(): - zip_file.write(filename=os.path.join(problem_test_case_dir, v["input_name"]), - arcname=f"{index}/testcase/{v['input_name']}", - compress_type=compression) - if not info["spj"]: - zip_file.write(filename=os.path.join(problem_test_case_dir, v["output_name"]), - arcname=f"{index}/testcase/{v['output_name']}", - compress_type=compression) + for k, v in info['test_cases'].items(): + zip_file.write( + filename=os.path.join(problem_test_case_dir, v['input_name']), + arcname=f"{index}/testcase/{v['input_name']}", + compress_type=compression, + ) + if not info['spj']: + zip_file.write( + filename=os.path.join(problem_test_case_dir, v['output_name']), + arcname=f"{index}/testcase/{v['output_name']}", + compress_type=compression, + ) @validate_serializer(ExportProblemRequestSerialzier) def get(self, request): - problems = Problem.objects.filter(id__in=request.data["problem_id"]) + problems = Problem.objects.filter(id__in=request.data['problem_id']) for problem in problems: if problem.contest: ensure_created_by(problem.contest, request.user) else: ensure_created_by(problem, request.user) - path = f"/tmp/{rand_str()}.zip" - with zipfile.ZipFile(path, "w") as zip_file: + path = f'/tmp/{rand_str()}.zip' + with zipfile.ZipFile(path, 'w') as zip_file: for index, problem in enumerate(problems): - self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1) + self.process_one_problem( + zip_file=zip_file, + user=request.user, + problem=problem, + index=index + 1, + ) delete_files.send_with_options(args=(path,), delay=300_000) - resp = FileResponse(open(path, "rb")) - resp["Content-Type"] = "application/zip" - resp["Content-Disposition"] = "attachment;filename=problem-export.zip" + resp = FileResponse(open(path, 'rb')) + resp['Content-Type'] = 'application/zip' + resp['Content-Disposition'] = 'attachment;filename=problem-export.zip' return resp @@ -552,128 +606,142 @@ class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor): def post(self, request): form = UploadProblemForm(request.POST, request.FILES) if form.is_valid(): - file = form.cleaned_data["file"] - tmp_file = f"/tmp/{rand_str()}.zip" - with open(tmp_file, "wb") as f: + file = form.cleaned_data['file'] + tmp_file = f'/tmp/{rand_str()}.zip' + with open(tmp_file, 'wb') as f: for chunk in file: f.write(chunk) else: - return self.error("Upload failed") + return self.error('Upload failed') count = 0 - with zipfile.ZipFile(tmp_file, "r") as zip_file: + with zipfile.ZipFile(tmp_file, 'r') as zip_file: name_list = zip_file.namelist() for item in name_list: - if "/problem.json" in item: + if '/problem.json' in item: count += 1 with transaction.atomic(): for i in range(1, count + 1): - with zip_file.open(f"{i}/problem.json") as f: + with zip_file.open(f'{i}/problem.json') as f: problem_info = json.load(f) serializer = ImportProblemSerializer(data=problem_info) if not serializer.is_valid(): - return self.error(f"Invalid problem format, error is {serializer.errors}") + return self.error( + f'Invalid problem format, error is {serializer.errors}' + ) else: problem_info = serializer.data - for item in problem_info["template"].keys(): + for item in problem_info['template'].keys(): if item not in SysOptions.language_names: - return self.error(f"Unsupported language {item}") + return self.error(f'Unsupported language {item}') - problem_info["display_id"] = problem_info["display_id"][:24] - for k, v in problem_info["template"].items(): - problem_info["template"][k] = build_problem_template(v["prepend"], v["template"], - v["append"]) + problem_info['display_id'] = problem_info['display_id'][:24] + for k, v in problem_info['template'].items(): + problem_info['template'][k] = build_problem_template( + v['prepend'], v['template'], v['append'] + ) - spj = problem_info["spj"] is not None - rule_type = problem_info["rule_type"] - test_case_score = problem_info["test_case_score"] + spj = problem_info['spj'] is not None + rule_type = problem_info['rule_type'] + test_case_score = problem_info['test_case_score'] # process test case - _, test_case_id = self.process_zip(tmp_file, spj=spj, dir=f"{i}/testcase/") - - problem_obj = Problem.objects.create(_id=problem_info["display_id"], - title=problem_info["title"], - description=problem_info["description"]["value"], - input_description=problem_info["input_description"][ - "value"], - output_description=problem_info["output_description"][ - "value"], - hint=problem_info["hint"]["value"], - test_case_score=test_case_score if test_case_score else [], - time_limit=problem_info["time_limit"], - memory_limit=problem_info["memory_limit"], - samples=problem_info["samples"], - template=problem_info["template"], - rule_type=problem_info["rule_type"], - source=problem_info["source"], - spj=spj, - spj_code=problem_info["spj"]["code"] if spj else None, - spj_language=problem_info["spj"][ - "language"] if spj else None, - spj_version=rand_str(8) if spj else "", - languages=SysOptions.language_names, - created_by=request.user, - visible=False, - difficulty=Difficulty.MID, - total_score=sum(item["score"] for item in test_case_score) - if rule_type == ProblemRuleType.OI else 0, - test_case_id=test_case_id - ) - for tag_name in problem_info["tags"]: + _, test_case_id = self.process_zip( + tmp_file, spj=spj, dir=f'{i}/testcase/' + ) + + problem_obj = Problem.objects.create( + _id=problem_info['display_id'], + title=problem_info['title'], + description=problem_info['description']['value'], + input_description=problem_info['input_description'][ + 'value' + ], + output_description=problem_info['output_description'][ + 'value' + ], + hint=problem_info['hint']['value'], + test_case_score=test_case_score if test_case_score else [], + time_limit=problem_info['time_limit'], + memory_limit=problem_info['memory_limit'], + samples=problem_info['samples'], + template=problem_info['template'], + rule_type=problem_info['rule_type'], + source=problem_info['source'], + spj=spj, + spj_code=problem_info['spj']['code'] if spj else None, + spj_language=problem_info['spj']['language'] + if spj + else None, + spj_version=rand_str(8) if spj else '', + languages=SysOptions.language_names, + created_by=request.user, + visible=False, + difficulty=Difficulty.MID, + total_score=sum(item['score'] for item in test_case_score) + if rule_type == ProblemRuleType.OI + else 0, + test_case_id=test_case_id, + ) + for tag_name in problem_info['tags']: tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name) problem_obj.tags.add(tag_obj) - return self.success({"import_count": count}) + return self.success({'import_count': count}) class FPSProblemImport(CSRFExemptAPIView): request_parsers = () def _create_problem(self, problem_data, creator): - if problem_data["time_limit"]["unit"] == "ms": - time_limit = problem_data["time_limit"]["value"] + if problem_data['time_limit']['unit'] == 'ms': + time_limit = problem_data['time_limit']['value'] else: - time_limit = problem_data["time_limit"]["value"] * 1000 + time_limit = problem_data['time_limit']['value'] * 1000 template = {} prepend = {} append = {} - for t in problem_data["prepend"]: - prepend[t["language"]] = t["code"] - for t in problem_data["append"]: - append[t["language"]] = t["code"] - for t in problem_data["template"]: - our_lang = lang = t["language"] - if lang == "Python": - our_lang = "Python3" - template[our_lang] = TEMPLATE_BASE.format(prepend.get(lang, ""), t["code"], append.get(lang, "")) - spj = problem_data["spj"] is not None - Problem.objects.create(_id=f"fps-{rand_str(4)}", - title=problem_data["title"], - description=problem_data["description"], - input_description=problem_data["input"], - output_description=problem_data["output"], - hint=problem_data["hint"], - test_case_score=problem_data["test_case_score"], - time_limit=time_limit, - memory_limit=problem_data["memory_limit"]["value"], - samples=problem_data["samples"], - template=template, - rule_type=ProblemRuleType.ACM, - source=problem_data.get("source", ""), - spj=spj, - spj_code=problem_data["spj"]["code"] if spj else None, - spj_language=problem_data["spj"]["language"] if spj else None, - spj_version=rand_str(8) if spj else "", - visible=False, - languages=SysOptions.language_names, - created_by=creator, - difficulty=Difficulty.MID, - test_case_id=problem_data["test_case_id"]) + for t in problem_data['prepend']: + prepend[t['language']] = t['code'] + for t in problem_data['append']: + append[t['language']] = t['code'] + for t in problem_data['template']: + our_lang = lang = t['language'] + if lang == 'Python': + our_lang = 'Python3' + template[our_lang] = TEMPLATE_BASE.format( + prepend.get(lang, ''), t['code'], append.get(lang, '') + ) + spj = problem_data['spj'] is not None + Problem.objects.create( + _id=f'fps-{rand_str(4)}', + title=problem_data['title'], + description=problem_data['description'], + input_description=problem_data['input'], + output_description=problem_data['output'], + hint=problem_data['hint'], + test_case_score=problem_data['test_case_score'], + time_limit=time_limit, + memory_limit=problem_data['memory_limit']['value'], + samples=problem_data['samples'], + template=template, + rule_type=ProblemRuleType.ACM, + source=problem_data.get('source', ''), + spj=spj, + spj_code=problem_data['spj']['code'] if spj else None, + spj_language=problem_data['spj']['language'] if spj else None, + spj_version=rand_str(8) if spj else '', + visible=False, + languages=SysOptions.language_names, + created_by=creator, + difficulty=Difficulty.MID, + test_case_id=problem_data['test_case_id'], + ) def post(self, request): form = UploadProblemForm(request.POST, request.FILES) if form.is_valid(): - file = form.cleaned_data["file"] - with tempfile.NamedTemporaryFile("wb") as tf: + file = form.cleaned_data['file'] + with tempfile.NamedTemporaryFile('wb') as tf: for chunk in file.chunks(4096): tf.file.write(chunk) @@ -682,7 +750,7 @@ def post(self, request): problems = FPSParser(tf.name).parse() else: - return self.error("Parse upload file error") + return self.error('Parse upload file error') helper = FPSHelper() with transaction.atomic(): @@ -691,15 +759,24 @@ def post(self, request): test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id) os.mkdir(test_case_dir) score = [] - for item in helper.save_test_case(_problem, test_case_dir)["test_cases"].values(): - score.append({"score": 0, "input_name": item["input_name"], - "output_name": item.get("output_name")}) - problem_data = helper.save_image(_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX) + for item in helper.save_test_case(_problem, test_case_dir)[ + 'test_cases' + ].values(): + score.append( + { + 'score': 0, + 'input_name': item['input_name'], + 'output_name': item.get('output_name'), + } + ) + problem_data = helper.save_image( + _problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX + ) s = FPSProblemSerializer(data=problem_data) if not s.is_valid(): - return self.error(f"Parse FPS file error: {s.errors}") + return self.error(f'Parse FPS file error: {s.errors}') problem_data = s.data - problem_data["test_case_id"] = test_case_id - problem_data["test_case_score"] = score + problem_data['test_case_id'] = test_case_id + problem_data['test_case_score'] = score self._create_problem(problem_data, request.user) - return self.success({"import_count": len(problems)}) + return self.success({'import_count': len(problems)}) diff --git a/problem/views/oj.py b/problem/views/oj.py index 534fa6059..0ecd50dbf 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -10,10 +10,10 @@ class ProblemTagAPI(APIView): def get(self, request): qs = ProblemTag.objects - keyword = request.GET.get("keyword") + keyword = request.GET.get('keyword') if keyword: qs = ProblemTag.objects.filter(name__icontains=keyword) - tags = qs.annotate(problem_count=Count("problem")).filter(problem_count__gt=0) + tags = qs.annotate(problem_count=Count('problem')).filter(problem_count__gt=0) return self.success(TagSerializer(tags, many=True).data) @@ -22,7 +22,7 @@ def get(self, request): problems = Problem.objects.filter(contest_id__isnull=True, visible=True) count = problems.count() if count == 0: - return self.error("No problem to pick") + return self.error('No problem to pick') return self.success(problems[random.randint(0, count - 1)]._id) @@ -31,50 +31,61 @@ class ProblemAPI(APIView): def _add_problem_status(request, queryset_values): if request.user.is_authenticated: profile = request.user.userprofile - acm_problems_status = profile.acm_problems_status.get("problems", {}) - oi_problems_status = profile.oi_problems_status.get("problems", {}) + acm_problems_status = profile.acm_problems_status.get('problems', {}) + oi_problems_status = profile.oi_problems_status.get('problems', {}) # paginate data - results = queryset_values.get("results") + results = queryset_values.get('results') if results is not None: problems = results else: - problems = [queryset_values, ] + problems = [ + queryset_values, + ] for problem in problems: - if problem["rule_type"] == ProblemRuleType.ACM: - problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status") + if problem['rule_type'] == ProblemRuleType.ACM: + problem['my_status'] = acm_problems_status.get( + str(problem['id']), {} + ).get('status') else: - problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status") + problem['my_status'] = oi_problems_status.get( + str(problem['id']), {} + ).get('status') def get(self, request): # 问题详情页 - problem_id = request.GET.get("problem_id") + problem_id = request.GET.get('problem_id') if problem_id: try: - problem = Problem.objects.select_related("created_by") \ - .get(_id=problem_id, contest_id__isnull=True, visible=True) + problem = Problem.objects.select_related('created_by').get( + _id=problem_id, contest_id__isnull=True, visible=True + ) problem_data = ProblemSerializer(problem).data self._add_problem_status(request, problem_data) return self.success(problem_data) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') - limit = request.GET.get("limit") + limit = request.GET.get('limit') if not limit: - return self.error("Limit is needed") + return self.error('Limit is needed') - problems = Problem.objects.select_related("created_by").filter(contest_id__isnull=True, visible=True) + problems = Problem.objects.select_related('created_by').filter( + contest_id__isnull=True, visible=True + ) # 按照标签筛选 - tag_text = request.GET.get("tag") + tag_text = request.GET.get('tag') if tag_text: problems = problems.filter(tags__name=tag_text) # 搜索的情况 - keyword = request.GET.get("keyword", "").strip() + keyword = request.GET.get('keyword', '').strip() if keyword: - problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword)) + problems = problems.filter( + Q(title__icontains=keyword) | Q(_id__icontains=keyword) + ) # 难度筛选 - difficulty = request.GET.get("difficulty") + difficulty = request.GET.get('difficulty') if difficulty: problems = problems.filter(difficulty=difficulty) # 根据profile 为做过的题目添加标记 @@ -88,30 +99,41 @@ def _add_problem_status(self, request, queryset_values): if request.user.is_authenticated: profile = request.user.userprofile if self.contest.rule_type == ContestRuleType.ACM: - problems_status = profile.acm_problems_status.get("contest_problems", {}) + problems_status = profile.acm_problems_status.get( + 'contest_problems', {} + ) else: - problems_status = profile.oi_problems_status.get("contest_problems", {}) + problems_status = profile.oi_problems_status.get('contest_problems', {}) for problem in queryset_values: - problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status") + problem['my_status'] = problems_status.get(str(problem['id']), {}).get( + 'status' + ) - @check_contest_permission(check_type="problems") + @check_contest_permission(check_type='problems') def get(self, request): - problem_id = request.GET.get("problem_id") + problem_id = request.GET.get('problem_id') if problem_id: try: - problem = Problem.objects.select_related("created_by").get(_id=problem_id, - contest=self.contest, - visible=True) + problem = Problem.objects.select_related('created_by').get( + _id=problem_id, contest=self.contest, visible=True + ) except Problem.DoesNotExist: - return self.error("Problem does not exist.") + return self.error('Problem does not exist.') if self.contest.problem_details_permission(request.user): problem_data = ProblemSerializer(problem).data - self._add_problem_status(request, [problem_data, ]) + self._add_problem_status( + request, + [ + problem_data, + ], + ) else: problem_data = ProblemSafeSerializer(problem).data return self.success(problem_data) - contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True) + contest_problems = Problem.objects.select_related('created_by').filter( + contest=self.contest, visible=True + ) if self.contest.problem_details_permission(request.user): data = ProblemSerializer(contest_problems, many=True).data self._add_problem_status(request, data) diff --git a/run_test.py b/run_test.py index a6c714f6e..627781742 100644 --- a/run_test.py +++ b/run_test.py @@ -2,26 +2,30 @@ import os import sys -opts, args = getopt.getopt(sys.argv[1:], "cm:", ["coverage=", "module="]) +opts, args = getopt.getopt(sys.argv[1:], 'cm:', ['coverage=', 'module=']) is_coverage = False -test_module = "" -setting = "oj.settings" +test_module = '' +setting = 'oj.settings' for opt, arg in opts: - if opt in ["-c", "--coverage"]: + if opt in ['-c', '--coverage']: is_coverage = True - if opt in ["-m", "--module"]: + if opt in ['-m', '--module']: test_module = arg -print("Coverage: {cov}".format(cov=is_coverage)) -print("Module: {mod}".format(mod=(test_module if test_module else "All"))) +print('Coverage: {cov}'.format(cov=is_coverage)) +print('Module: {mod}'.format(mod=(test_module if test_module else 'All'))) -print("running flake8...") -if os.system("flake8 --statistics ."): +print('running flake8...') +if os.system('flake8 --statistics .'): exit() -ret = os.system('coverage run --include="$PWD/*" manage.py test {module} --settings={setting}'.format(module=test_module, setting=setting)) +ret = os.system( + 'coverage run --include="$PWD/*" manage.py test {module} --settings={setting}'.format( + module=test_module, setting=setting + ) +) if not ret and is_coverage: - os.system("coverage html && open htmlcov/index.html") + os.system('coverage html && open htmlcov/index.html') diff --git a/submission/migrations/0001_initial.py b/submission/migrations/0001_initial.py index 42a5352f2..63ceb7c34 100644 --- a/submission/migrations/0001_initial.py +++ b/submission/migrations/0001_initial.py @@ -9,17 +9,24 @@ class Migration(migrations.Migration): - initial = True - dependencies = [ - ] + dependencies = [] operations = [ migrations.CreateModel( name='Submission', fields=[ - ('id', models.CharField(db_index=True, default=utils.shortcuts.rand_str, max_length=32, primary_key=True, serialize=False)), + ( + 'id', + models.CharField( + db_index=True, + default=utils.shortcuts.rand_str, + max_length=32, + primary_key=True, + serialize=False, + ), + ), ('contest_id', models.IntegerField(db_index=True, null=True)), ('problem_id', models.IntegerField(db_index=True)), ('created_time', models.DateTimeField(auto_now_add=True)), diff --git a/submission/migrations/0002_auto_20170509_1203.py b/submission/migrations/0002_auto_20170509_1203.py index 78dcbe9bc..c5642f19d 100644 --- a/submission/migrations/0002_auto_20170509_1203.py +++ b/submission/migrations/0002_auto_20170509_1203.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0001_initial'), ] @@ -34,5 +33,5 @@ class Migration(migrations.Migration): migrations.AlterModelOptions( name='submission', options={'ordering': ('-create_time',)}, - ) + ), ] diff --git a/submission/migrations/0005_submission_username.py b/submission/migrations/0005_submission_username.py index 68a324357..24c341a5b 100644 --- a/submission/migrations/0005_submission_username.py +++ b/submission/migrations/0005_submission_username.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0002_auto_20170509_1203'), ] @@ -15,7 +14,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name='submission', name='username', - field=models.CharField(default="", max_length=30), + field=models.CharField(default='', max_length=30), preserve_default=False, ), ] diff --git a/submission/migrations/0006_auto_20170830_1154.py b/submission/migrations/0006_auto_20170830_1154.py index 675cc8659..b492924f8 100644 --- a/submission/migrations/0006_auto_20170830_1154.py +++ b/submission/migrations/0006_auto_20170830_1154.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0005_submission_username'), ] diff --git a/submission/migrations/0007_auto_20170923_1318.py b/submission/migrations/0007_auto_20170923_1318.py index 635668074..6b8059d52 100644 --- a/submission/migrations/0007_auto_20170923_1318.py +++ b/submission/migrations/0007_auto_20170923_1318.py @@ -8,7 +8,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0006_auto_20170830_1154'), ] @@ -17,12 +16,18 @@ class Migration(migrations.Migration): migrations.AlterField( model_name='submission', name='contest_id', - field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, to='contest.Contest'), + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + to='contest.Contest', + ), ), migrations.AlterField( model_name='submission', name='problem_id', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problem.Problem'), + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to='problem.Problem' + ), ), migrations.RenameField( model_name='submission', diff --git a/submission/migrations/0008_submission_ip.py b/submission/migrations/0008_submission_ip.py index e60841bdc..44f2de468 100644 --- a/submission/migrations/0008_submission_ip.py +++ b/submission/migrations/0008_submission_ip.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0007_auto_20170923_1318'), ] diff --git a/submission/migrations/0009_delete_user_output.py b/submission/migrations/0009_delete_user_output.py index 66c90c41c..26d1843c8 100644 --- a/submission/migrations/0009_delete_user_output.py +++ b/submission/migrations/0009_delete_user_output.py @@ -5,16 +5,15 @@ def delete_user_output(apps, schema_editor): - Submission = apps.get_model("submission", "Submission") + Submission = apps.get_model('submission', 'Submission') for item in Submission.objects.all(): - if "data" in item.info and isinstance(item.info["data"], list): - for index in range(len(item.info["data"])): - item.info["data"][index]["output"] = "" + if 'data' in item.info and isinstance(item.info['data'], list): + for index in range(len(item.info['data'])): + item.info['data'][index]['output'] = '' item.save() class Migration(migrations.Migration): - dependencies = [ ('submission', '0008_submission_ip'), ] diff --git a/submission/migrations/0011_fix_submission_number.py b/submission/migrations/0011_fix_submission_number.py index f3d5e6b76..3879f0624 100644 --- a/submission/migrations/0011_fix_submission_number.py +++ b/submission/migrations/0011_fix_submission_number.py @@ -5,15 +5,15 @@ def fix_rejudge_bugs(apps, schema_editor): - Submission = apps.get_model("submission", "Submission") - User = apps.get_model("account", "User") + Submission = apps.get_model('submission', 'Submission') + User = apps.get_model('account', 'User') for user in User.objects.all(): submissions = Submission.objects.filter(user_id=user.id, contest__isnull=True) profile = user.userprofile profile.submission_number = submissions.count() profile.accepted_number = submissions.filter(result=0).count() - profile.save(update_fields=["submission_number", "accepted_number"]) + profile.save(update_fields=['submission_number', 'accepted_number']) class Migration(migrations.Migration): diff --git a/submission/migrations/0012_auto_20180501_0436.py b/submission/migrations/0012_auto_20180501_0436.py index b81087298..78ffbb720 100644 --- a/submission/migrations/0012_auto_20180501_0436.py +++ b/submission/migrations/0012_auto_20180501_0436.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0011_fix_submission_number'), ] @@ -16,7 +15,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name='submission', name='id', - field=models.TextField(db_index=True, default=utils.shortcuts.rand_str, primary_key=True, serialize=False), + field=models.TextField( + db_index=True, + default=utils.shortcuts.rand_str, + primary_key=True, + serialize=False, + ), ), migrations.AlterField( model_name='submission', diff --git a/submission/models.py b/submission/models.py index 2918f8808..857eac677 100644 --- a/submission/models.py +++ b/submission/models.py @@ -41,7 +41,12 @@ class Submission(models.Model): ip = models.TextField(null=True) def check_user_permission(self, user, check_share=True): - if self.user_id == user.id or user.is_super_admin() or user.can_mgmt_all_problem() or self.problem.created_by_id == user.id: + if ( + self.user_id == user.id + or user.is_super_admin() + or user.can_mgmt_all_problem() + or self.problem.created_by_id == user.id + ): return True if check_share: @@ -52,8 +57,8 @@ def check_user_permission(self, user, check_share=True): return False class Meta: - db_table = "submission" - ordering = ("-create_time",) + db_table = 'submission' + ordering = ('-create_time',) def __str__(self): return self.id diff --git a/submission/serializers.py b/submission/serializers.py index 5e48f3e76..47a00816b 100644 --- a/submission/serializers.py +++ b/submission/serializers.py @@ -17,32 +17,31 @@ class ShareSubmissionSerializer(serializers.Serializer): class SubmissionModelSerializer(serializers.ModelSerializer): - class Meta: model = Submission - fields = "__all__" + fields = '__all__' # 不显示submission info的serializer, 用于ACM rule_type class SubmissionSafeModelSerializer(serializers.ModelSerializer): - problem = serializers.SlugRelatedField(read_only=True, slug_field="_id") + problem = serializers.SlugRelatedField(read_only=True, slug_field='_id') class Meta: model = Submission - exclude = ("info", "contest", "ip") + exclude = ('info', 'contest', 'ip') class SubmissionListSerializer(serializers.ModelSerializer): - problem = serializers.SlugRelatedField(read_only=True, slug_field="_id") + problem = serializers.SlugRelatedField(read_only=True, slug_field='_id') show_link = serializers.SerializerMethodField() def __init__(self, *args, **kwargs): - self.user = kwargs.pop("user", None) + self.user = kwargs.pop('user', None) super().__init__(*args, **kwargs) class Meta: model = Submission - exclude = ("info", "contest", "code", "ip") + exclude = ('info', 'contest', 'code', 'ip') def get_show_link(self, obj): # 没传user或为匿名user diff --git a/submission/tests.py b/submission/tests.py index 08ccbd46a..292c25257 100644 --- a/submission/tests.py +++ b/submission/tests.py @@ -5,25 +5,48 @@ from utils.api.tests import APITestCase from .models import Submission -DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "test
", "input_description": "test", - "output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low", - "visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {}, - "samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C", - "spj_code": "", "test_case_id": "499b26290cc7994e0b497212e842ea85", - "test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0, - "stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e", - "input_size": 0, "score": 0}], - "rule_type": "ACM", "hint": "test
", "source": "test"} +DEFAULT_PROBLEM_DATA = { + '_id': 'A-110', + 'title': 'test', + 'description': 'test
', + 'input_description': 'test', + 'output_description': 'test', + 'time_limit': 1000, + 'memory_limit': 256, + 'difficulty': 'Low', + 'visible': True, + 'tags': ['test'], + 'languages': ['C', 'C++', 'Java', 'Python2'], + 'template': {}, + 'samples': [{'input': 'test', 'output': 'test'}], + 'spj': False, + 'spj_language': 'C', + 'spj_code': '', + 'test_case_id': '499b26290cc7994e0b497212e842ea85', + 'test_case_score': [ + { + 'output_name': '1.out', + 'input_name': '1.in', + 'output_size': 0, + 'stripped_output_md5': 'd41d8cd98f00b204e9800998ecf8427e', + 'input_size': 0, + 'score': 0, + } + ], + 'rule_type': 'ACM', + 'hint': 'test
', + 'source': 'test', +} DEFAULT_SUBMISSION_DATA = { - "problem_id": "1", - "user_id": 1, - "username": "test", - "code": "xxxxxxxxxxxxxx", - "result": -2, - "info": {}, - "language": "C", - "statistic_info": {} + 'problem_id': '1', + 'user_id': 1, + 'username': 'test', + 'code': 'xxxxxxxxxxxxxx', + 'result': -2, + 'info': {}, + 'language': 'C', + 'statistic_info': {}, } @@ -32,37 +55,37 @@ class SubmissionPrepare(APITestCase): def _create_problem_and_submission(self): - user = self.create_admin("test", "test123", login=False) + user = self.create_admin('test', 'test123', login=False) problem_data = deepcopy(DEFAULT_PROBLEM_DATA) - tags = problem_data.pop("tags") - problem_data["created_by"] = user + tags = problem_data.pop('tags') + problem_data['created_by'] = user self.problem = Problem.objects.create(**problem_data) for tag in tags: tag = ProblemTag.objects.create(name=tag) self.problem.tags.add(tag) self.problem.save() self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA) - self.submission_data["problem_id"] = self.problem.id + self.submission_data['problem_id'] = self.problem.id self.submission = Submission.objects.create(**self.submission_data) class SubmissionListTest(SubmissionPrepare): def setUp(self): self._create_problem_and_submission() - self.create_user("123", "345") - self.url = self.reverse("submission_list_api") + self.create_user('123', '345') + self.url = self.reverse('submission_list_api') def test_get_submission_list(self): - resp = self.client.get(self.url, data={"limit": "10"}) + resp = self.client.get(self.url, data={'limit': '10'}) self.assertSuccess(resp) -@mock.patch("submission.views.oj.judge_task.send") +@mock.patch('submission.views.oj.judge_task.send') class SubmissionAPITest(SubmissionPrepare): def setUp(self): self._create_problem_and_submission() - self.user = self.create_user("123", "test123") - self.url = self.reverse("submission_api") + self.user = self.create_user('123', 'test123') + self.url = self.reverse('submission_api') def test_create_submission(self, judge_task): resp = self.client.post(self.url, self.submission_data) @@ -70,9 +93,11 @@ def test_create_submission(self, judge_task): judge_task.assert_called() def test_create_submission_with_wrong_language(self, judge_task): - self.submission_data.update({"language": "Python3"}) + self.submission_data.update({'language': 'Python3'}) resp = self.client.post(self.url, self.submission_data) self.assertFailed(resp) - self.assertDictEqual(resp.data, {"error": "error", - "data": "Python3 is now allowed in the problem"}) + self.assertDictEqual( + resp.data, + {'error': 'error', 'data': 'Python3 is now allowed in the problem'}, + ) judge_task.assert_not_called() diff --git a/submission/urls/admin.py b/submission/urls/admin.py index bf86022f6..55de961ed 100644 --- a/submission/urls/admin.py +++ b/submission/urls/admin.py @@ -3,5 +3,9 @@ from ..views.admin import SubmissionRejudgeAPI urlpatterns = [ - url(r"^submission/rejudge?$", SubmissionRejudgeAPI.as_view(), name="submission_rejudge_api"), + url( + r'^submission/rejudge?$', + SubmissionRejudgeAPI.as_view(), + name='submission_rejudge_api', + ), ] diff --git a/submission/urls/oj.py b/submission/urls/oj.py index 49116b9ff..a975c8527 100644 --- a/submission/urls/oj.py +++ b/submission/urls/oj.py @@ -1,10 +1,23 @@ from django.conf.urls import url -from ..views.oj import SubmissionAPI, SubmissionListAPI, ContestSubmissionListAPI, SubmissionExistsAPI +from ..views.oj import ( + SubmissionAPI, + SubmissionListAPI, + ContestSubmissionListAPI, + SubmissionExistsAPI, +) urlpatterns = [ - url(r"^submission/?$", SubmissionAPI.as_view(), name="submission_api"), - url(r"^submissions/?$", SubmissionListAPI.as_view(), name="submission_list_api"), - url(r"^submission_exists/?$", SubmissionExistsAPI.as_view(), name="submission_exists"), - url(r"^contest_submissions/?$", ContestSubmissionListAPI.as_view(), name="contest_submission_list_api"), + url(r'^submission/?$', SubmissionAPI.as_view(), name='submission_api'), + url(r'^submissions/?$', SubmissionListAPI.as_view(), name='submission_list_api'), + url( + r'^submission_exists/?$', + SubmissionExistsAPI.as_view(), + name='submission_exists', + ), + url( + r'^contest_submissions/?$', + ContestSubmissionListAPI.as_view(), + name='contest_submission_list_api', + ), ] diff --git a/submission/views/admin.py b/submission/views/admin.py index 879725614..2e78c417f 100644 --- a/submission/views/admin.py +++ b/submission/views/admin.py @@ -1,5 +1,6 @@ from account.decorators import super_admin_required from judge.tasks import judge_task + # from judge.dispatcher import JudgeDispatcher from utils.api import APIView from ..models import Submission @@ -8,13 +9,15 @@ class SubmissionRejudgeAPI(APIView): @super_admin_required def get(self, request): - id = request.GET.get("id") + id = request.GET.get('id') if not id: - return self.error("Parameter error, id is required") + return self.error('Parameter error, id is required') try: - submission = Submission.objects.select_related("problem").get(id=id, contest_id__isnull=True) + submission = Submission.objects.select_related('problem').get( + id=id, contest_id__isnull=True + ) except Submission.DoesNotExist: - return self.error("Submission does not exists") + return self.error('Submission does not exists') submission.statistic_info = {} submission.save() diff --git a/submission/views/oj.py b/submission/views/oj.py index 0f7c4a584..63ea8e6f0 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -4,6 +4,7 @@ from contest.models import ContestStatus, ContestRuleType from judge.tasks import judge_task from options.options import SysOptions + # from judge.dispatcher import JudgeDispatcher from problem.models import Problem, ProblemRuleType from utils.api import APIView, validate_serializer @@ -11,22 +12,26 @@ from utils.captcha import Captcha from utils.throttling import TokenBucket from ..models import Submission -from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer, - ShareSubmissionSerializer) +from ..serializers import ( + CreateSubmissionSerializer, + SubmissionModelSerializer, + ShareSubmissionSerializer, +) from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer class SubmissionAPI(APIView): def throttling(self, request): # 使用 open_api 的请求暂不做限制 - auth_method = getattr(request, "auth_method", "") - if auth_method == "api_key": + auth_method = getattr(request, 'auth_method', '') + if auth_method == 'api_key': return - user_bucket = TokenBucket(key=str(request.user.id), - redis_conn=cache, **SysOptions.throttling["user"]) + user_bucket = TokenBucket( + key=str(request.user.id), redis_conn=cache, **SysOptions.throttling['user'] + ) can_consume, wait = user_bucket.consume() if not can_consume: - return "Please wait %d seconds" % (int(wait)) + return 'Please wait %d seconds' % (int(wait)) # ip_bucket = TokenBucket(key=request.session["ip"], # redis_conn=cache, **SysOptions.throttling["ip"]) @@ -34,23 +39,26 @@ def throttling(self, request): # if not can_consume: # return "Captcha is required" - @check_contest_permission(check_type="problems") + @check_contest_permission(check_type='problems') def check_contest_permission(self, request): contest = self.contest if contest.status == ContestStatus.CONTEST_ENDED: - return self.error("The contest have ended") + return self.error('The contest have ended') if not request.user.is_contest_admin(contest): - user_ip = ipaddress.ip_address(request.session.get("ip")) + user_ip = ipaddress.ip_address(request.session.get('ip')) if contest.allowed_ip_ranges: - if not any(user_ip in ipaddress.ip_network(cidr, strict=False) for cidr in contest.allowed_ip_ranges): - return self.error("Your IP is not allowed in this contest") + if not any( + user_ip in ipaddress.ip_network(cidr, strict=False) + for cidr in contest.allowed_ip_ranges + ): + return self.error('Your IP is not allowed in this contest') @validate_serializer(CreateSubmissionSerializer) @login_required def post(self, request): data = request.data hide_id = False - if data.get("contest_id"): + if data.get('contest_id'): error = self.check_contest_permission(request) if error: return error @@ -58,52 +66,63 @@ def post(self, request): if not contest.problem_details_permission(request.user): hide_id = True - if data.get("captcha"): - if not Captcha(request).check(data["captcha"]): - return self.error("Invalid captcha") + if data.get('captcha'): + if not Captcha(request).check(data['captcha']): + return self.error('Invalid captcha') error = self.throttling(request) if error: return self.error(error) try: - problem = Problem.objects.get(id=data["problem_id"], contest_id=data.get("contest_id"), visible=True) + problem = Problem.objects.get( + id=data['problem_id'], contest_id=data.get('contest_id'), visible=True + ) except Problem.DoesNotExist: - return self.error("Problem not exist") - if data["language"] not in problem.languages: + return self.error('Problem not exist') + if data['language'] not in problem.languages: return self.error(f"{data['language']} is now allowed in the problem") - submission = Submission.objects.create(user_id=request.user.id, - username=request.user.username, - language=data["language"], - code=data["code"], - problem_id=problem.id, - ip=request.session["ip"], - contest_id=data.get("contest_id")) + submission = Submission.objects.create( + user_id=request.user.id, + username=request.user.username, + language=data['language'], + code=data['code'], + problem_id=problem.id, + ip=request.session['ip'], + contest_id=data.get('contest_id'), + ) # use this for debug # JudgeDispatcher(submission.id, problem.id).judge() judge_task.send(submission.id, problem.id) if hide_id: return self.success() else: - return self.success({"submission_id": submission.id}) + return self.success({'submission_id': submission.id}) @login_required def get(self, request): - submission_id = request.GET.get("id") + submission_id = request.GET.get('id') if not submission_id: return self.error("Parameter id doesn't exist") try: - submission = Submission.objects.select_related("problem").get(id=submission_id) + submission = Submission.objects.select_related('problem').get( + id=submission_id + ) except Submission.DoesNotExist: return self.error("Submission doesn't exist") if not submission.check_user_permission(request.user): - return self.error("No permission for this submission") + return self.error('No permission for this submission') - if submission.problem.rule_type == ProblemRuleType.OI or request.user.is_admin_role(): + if ( + submission.problem.rule_type == ProblemRuleType.OI + or request.user.is_admin_role() + ): submission_data = SubmissionModelSerializer(submission).data else: submission_data = SubmissionSafeModelSerializer(submission).data # 是否有权限取消共享 - submission_data["can_unshare"] = submission.check_user_permission(request.user, check_share=False) + submission_data['can_unshare'] = submission.check_user_permission( + request.user, check_share=False + ) return self.success(submission_data) @validate_serializer(ShareSubmissionSerializer) @@ -113,67 +132,82 @@ def put(self, request): share submission """ try: - submission = Submission.objects.select_related("problem").get(id=request.data["id"]) + submission = Submission.objects.select_related('problem').get( + id=request.data['id'] + ) except Submission.DoesNotExist: return self.error("Submission doesn't exist") if not submission.check_user_permission(request.user, check_share=False): - return self.error("No permission to share the submission") - if submission.contest and submission.contest.status == ContestStatus.CONTEST_UNDERWAY: - return self.error("Can not share submission now") - submission.shared = request.data["shared"] - submission.save(update_fields=["shared"]) + return self.error('No permission to share the submission') + if ( + submission.contest + and submission.contest.status == ContestStatus.CONTEST_UNDERWAY + ): + return self.error('Can not share submission now') + submission.shared = request.data['shared'] + submission.save(update_fields=['shared']) return self.success() class SubmissionListAPI(APIView): def get(self, request): - if not request.GET.get("limit"): - return self.error("Limit is needed") - if request.GET.get("contest_id"): - return self.error("Parameter error") - - submissions = Submission.objects.filter(contest_id__isnull=True).select_related("problem__created_by") - problem_id = request.GET.get("problem_id") - myself = request.GET.get("myself") - result = request.GET.get("result") - username = request.GET.get("username") + if not request.GET.get('limit'): + return self.error('Limit is needed') + if request.GET.get('contest_id'): + return self.error('Parameter error') + + submissions = Submission.objects.filter(contest_id__isnull=True).select_related( + 'problem__created_by' + ) + problem_id = request.GET.get('problem_id') + myself = request.GET.get('myself') + result = request.GET.get('result') + username = request.GET.get('username') if problem_id: try: - problem = Problem.objects.get(_id=problem_id, contest_id__isnull=True, visible=True) + problem = Problem.objects.get( + _id=problem_id, contest_id__isnull=True, visible=True + ) except Problem.DoesNotExist: return self.error("Problem doesn't exist") submissions = submissions.filter(problem=problem) - if (myself and myself == "1") or not SysOptions.submission_list_show_all: + if (myself and myself == '1') or not SysOptions.submission_list_show_all: submissions = submissions.filter(user_id=request.user.id) elif username: submissions = submissions.filter(username__icontains=username) if result: submissions = submissions.filter(result=result) data = self.paginate_data(request, submissions) - data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data + data['results'] = SubmissionListSerializer( + data['results'], many=True, user=request.user + ).data return self.success(data) class ContestSubmissionListAPI(APIView): - @check_contest_permission(check_type="submissions") + @check_contest_permission(check_type='submissions') def get(self, request): - if not request.GET.get("limit"): - return self.error("Limit is needed") + if not request.GET.get('limit'): + return self.error('Limit is needed') contest = self.contest - submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by") - problem_id = request.GET.get("problem_id") - myself = request.GET.get("myself") - result = request.GET.get("result") - username = request.GET.get("username") + submissions = Submission.objects.filter(contest_id=contest.id).select_related( + 'problem__created_by' + ) + problem_id = request.GET.get('problem_id') + myself = request.GET.get('myself') + result = request.GET.get('result') + username = request.GET.get('username') if problem_id: try: - problem = Problem.objects.get(_id=problem_id, contest_id=contest.id, visible=True) + problem = Problem.objects.get( + _id=problem_id, contest_id=contest.id, visible=True + ) except Problem.DoesNotExist: return self.error("Problem doesn't exist") submissions = submissions.filter(problem=problem) - if myself and myself == "1": + if myself and myself == '1': submissions = submissions.filter(user_id=request.user.id) elif username: submissions = submissions.filter(username__icontains=username) @@ -186,18 +220,25 @@ def get(self, request): # 封榜的时候只能看到自己的提交 if contest.rule_type == ContestRuleType.ACM: - if not contest.real_time_rank and not request.user.is_contest_admin(contest): + if not contest.real_time_rank and not request.user.is_contest_admin( + contest + ): submissions = submissions.filter(user_id=request.user.id) data = self.paginate_data(request, submissions) - data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data + data['results'] = SubmissionListSerializer( + data['results'], many=True, user=request.user + ).data return self.success(data) class SubmissionExistsAPI(APIView): def get(self, request): - if not request.GET.get("problem_id"): - return self.error("Parameter error, problem_id is required") - return self.success(request.user.is_authenticated and - Submission.objects.filter(problem_id=request.GET["problem_id"], - user_id=request.user.id).exists()) + if not request.GET.get('problem_id'): + return self.error('Parameter error, problem_id is required') + return self.success( + request.user.is_authenticated + and Submission.objects.filter( + problem_id=request.GET['problem_id'], user_id=request.user.id + ).exists() + ) diff --git a/utils/api/_serializers.py b/utils/api/_serializers.py index 745c69601..8827c6b1c 100644 --- a/utils/api/_serializers.py +++ b/utils/api/_serializers.py @@ -7,7 +7,7 @@ class UsernameSerializer(serializers.Serializer): real_name = serializers.SerializerMethodField() def __init__(self, *args, **kwargs): - self.need_real_name = kwargs.pop("need_real_name", False) + self.need_real_name = kwargs.pop('need_real_name', False) super().__init__(*args, **kwargs) def get_real_name(self, obj): diff --git a/utils/api/api.py b/utils/api/api.py index a603bfb97..fb192395f 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -7,7 +7,7 @@ from django.views.decorators.csrf import csrf_exempt from django.views.generic import View -logger = logging.getLogger("") +logger = logging.getLogger('') class APIError(Exception): @@ -18,10 +18,10 @@ def __init__(self, msg, err=None): class ContentType(object): - json_request = "application/json" - json_response = "application/json;charset=UTF-8" - url_encoded_request = "application/x-www-form-urlencoded" - binary_response = "application/octet-stream" + json_request = 'application/json' + json_response = 'application/json;charset=UTF-8' + url_encoded_request = 'application/x-www-form-urlencoded' + binary_response = 'application/octet-stream' class JSONParser(object): @@ -29,7 +29,7 @@ class JSONParser(object): @staticmethod def parse(body): - return json.loads(body.decode("utf-8")) + return json.loads(body.decode('utf-8')) class URLEncodedParser(object): @@ -59,15 +59,16 @@ class APIView(View): - self.response 返回一个django HttpResponse, 具体在self.response_class中实现 - parse请求的类需要定义在request_parser中, 目前只支持json和urlencoded的类型, 用来解析请求的数据 """ + request_parsers = (JSONParser, URLEncodedParser) response_class = JSONResponse def _get_request_data(self, request): - if request.method not in ["GET", "DELETE"]: + if request.method not in ['GET', 'DELETE']: body = request.body - content_type = request.META.get("CONTENT_TYPE") + content_type = request.META.get('CONTENT_TYPE') if not content_type: - raise ValueError("content_type is required") + raise ValueError('content_type is required') for parser in self.request_parsers: if content_type.startswith(parser.content_type): break @@ -83,15 +84,15 @@ def response(self, data): return self.response_class.response(data) def success(self, data=None): - return self.response({"error": None, "data": data}) + return self.response({'error': None, 'data': data}) - def error(self, msg="error", err="error"): - return self.response({"error": err, "data": msg}) + def error(self, msg='error', err='error'): + return self.response({'error': err, 'data': msg}) - def extract_errors(self, errors, key="field"): + def extract_errors(self, errors, key='field'): if isinstance(errors, dict): if not errors: - return key, "Invalid field" + return key, 'Invalid field' key = list(errors.keys())[0] return self.extract_errors(errors.pop(key), key) elif isinstance(errors, list): @@ -101,14 +102,14 @@ def extract_errors(self, errors, key="field"): def invalid_serializer(self, serializer): key, error = self.extract_errors(serializer.errors) - if key == "non_field_errors": + if key == 'non_field_errors': msg = error else: - msg = f"{key}: {error}" - return self.error(err=f"invalid-{key}", msg=msg) + msg = f'{key}: {error}' + return self.error(err=f'invalid-{key}', msg=msg) def server_error(self): - return self.error(err="server-error", msg="server error") + return self.error(err='server-error', msg='server error') def paginate_data(self, request, query_set, object_serializer=None): """ @@ -118,25 +119,24 @@ def paginate_data(self, request, query_set, object_serializer=None): :return: """ try: - limit = int(request.GET.get("limit", "10")) + limit = int(request.GET.get('limit', '10')) except ValueError: limit = 10 if limit < 0 or limit > 250: limit = 10 try: - offset = int(request.GET.get("offset", "0")) + offset = int(request.GET.get('offset', '0')) except ValueError: offset = 0 if offset < 0: offset = 0 - results = query_set[offset:offset + limit] + results = query_set[offset : offset + limit] if object_serializer: count = query_set.count() results = object_serializer(results, many=True).data else: count = query_set.count() - data = {"results": results, - "total": count} + data = {'results': results, 'total': count} return data def dispatch(self, request, *args, **kwargs): @@ -144,13 +144,13 @@ def dispatch(self, request, *args, **kwargs): try: request.data = self._get_request_data(self.request) except ValueError as e: - return self.error(err="invalid-request", msg=str(e)) + return self.error(err='invalid-request', msg=str(e)) try: return super(APIView, self).dispatch(request, *args, **kwargs) except APIError as e: - ret = {"msg": e.msg} + ret = {'msg': e.msg} if e.err: - ret["err"] = e.err + ret['err'] = e.err return self.error(**ret) except Exception as e: logger.exception(e) @@ -169,6 +169,7 @@ def validate_serializer(serializer): def post(self, request): return self.success(request.data) """ + def validate(view_method): @functools.wraps(view_method) def handle(*args, **kwargs): diff --git a/utils/api/tests.py b/utils/api/tests.py index 0f0a79b84..c2f8e9d1a 100644 --- a/utils/api/tests.py +++ b/utils/api/tests.py @@ -8,9 +8,19 @@ class APITestCase(TestCase): client_class = APIClient - def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, - problem_permission=ProblemPermission.NONE): - user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission) + def create_user( + self, + username, + password, + admin_type=AdminType.REGULAR_USER, + login=True, + problem_permission=ProblemPermission.NONE, + ): + user = User.objects.create( + username=username, + admin_type=admin_type, + problem_permission=problem_permission, + ) user.set_password(password) UserProfile.objects.create(user=user) user.save() @@ -18,23 +28,34 @@ def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, log self.client.login(username=username, password=password) return user - def create_admin(self, username="admin", password="admin", login=True): - return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, - problem_permission=ProblemPermission.OWN, - login=login) - - def create_super_admin(self, username="root", password="root", login=True): - return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, - problem_permission=ProblemPermission.ALL, login=login) + def create_admin(self, username='admin', password='admin', login=True): + return self.create_user( + username=username, + password=password, + admin_type=AdminType.ADMIN, + problem_permission=ProblemPermission.OWN, + login=login, + ) + + def create_super_admin(self, username='root', password='root', login=True): + return self.create_user( + username=username, + password=password, + admin_type=AdminType.SUPER_ADMIN, + problem_permission=ProblemPermission.ALL, + login=login, + ) def reverse(self, url_name, *args, **kwargs): return reverse(url_name, *args, **kwargs) def assertSuccess(self, response): - if not response.data["error"] is None: - raise AssertionError("response with errors, response: " + str(response.data)) + if response.data['error'] is not None: + raise AssertionError( + 'response with errors, response: ' + str(response.data) + ) def assertFailed(self, response, msg=None): - self.assertTrue(response.data["error"] is not None) + self.assertTrue(response.data['error'] is not None) if msg: - self.assertEqual(response.data["data"], msg) + self.assertEqual(response.data['data'], msg) diff --git a/utils/captcha/__init__.py b/utils/captcha/__init__.py index 8b8375cdd..c94dd8096 100644 --- a/utils/captcha/__init__.py +++ b/utils/captcha/__init__.py @@ -24,8 +24,8 @@ def __init__(self, request): 初始化,设置各种属性 """ self.django_request = request - self.session_key = "_django_captcha_key" - self.captcha_expires_time = "_django_captcha_expires_time" + self.session_key = '_django_captcha_key' + self.captcha_expires_time = '_django_captcha_expires_time' # 验证码图片尺寸 self.img_width = 90 @@ -50,20 +50,33 @@ def _make_code(self): """ 生成随机数或随机字符串 """ - string = random.sample("abcdefghkmnpqrstuvwxyzABCDEFGHGKMNOPQRSTUVWXYZ23456789", 4) - self._set_answer("".join(string)) + string = random.sample( + 'abcdefghkmnpqrstuvwxyzABCDEFGHGKMNOPQRSTUVWXYZ23456789', 4 + ) + self._set_answer(''.join(string)) return string def get(self): """ 生成验证码图片,返回值为图片的bytes """ - background = (random.randrange(200, 255), random.randrange(200, 255), random.randrange(200, 255)) - code_color = (random.randrange(0, 50), random.randrange(0, 50), random.randrange(0, 50), 255) + background = ( + random.randrange(200, 255), + random.randrange(200, 255), + random.randrange(200, 255), + ) + code_color = ( + random.randrange(0, 50), + random.randrange(0, 50), + random.randrange(0, 50), + 255, + ) - font_path = os.path.join(os.path.normpath(os.path.dirname(__file__)), "timesbi.ttf") + font_path = os.path.join( + os.path.normpath(os.path.dirname(__file__)), 'timesbi.ttf' + ) - image = Image.new("RGB", (self.img_width, self.img_height), background) + image = Image.new('RGB', (self.img_width, self.img_height), background) code = self._make_code() font_size = self._get_font_size(code) draw = ImageDraw.Draw(image) @@ -75,19 +88,21 @@ def get(self): # 字符y坐标 y = random.randrange(1, 7) # 随机字符大小 - font = ImageFont.truetype(font_path.replace("\\", "/"), font_size + random.randrange(-3, 7)) + font = ImageFont.truetype( + font_path.replace('\\', '/'), font_size + random.randrange(-3, 7) + ) draw.text((x, y), i, font=font, fill=code_color) # 随机化字符之间的距离 字符粘连可以降低识别率 x += font_size * random.randrange(6, 8) / 10 - self.django_request.session[self.session_key] = "".join(code) + self.django_request.session[self.session_key] = ''.join(code) return image def check(self, code): """ 检查用户输入的验证码是否正确 """ - _code = self.django_request.session.get(self.session_key) or "" + _code = self.django_request.session.get(self.session_key) or '' if not _code: return False expires_time = self.django_request.session.get(self.captcha_expires_time) or 0 diff --git a/utils/constants.py b/utils/constants.py index 749b20da7..f38a7360a 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -2,35 +2,35 @@ class Choices: @classmethod def choices(cls): d = cls.__dict__ - return [d[item] for item in d.keys() if not item.startswith("__")] + return [d[item] for item in d.keys() if not item.startswith('__')] class ContestType: - PUBLIC_CONTEST = "Public" - PASSWORD_PROTECTED_CONTEST = "Password Protected" + PUBLIC_CONTEST = 'Public' + PASSWORD_PROTECTED_CONTEST = 'Password Protected' class ContestStatus: - CONTEST_NOT_START = "1" - CONTEST_ENDED = "-1" - CONTEST_UNDERWAY = "0" + CONTEST_NOT_START = '1' + CONTEST_ENDED = '-1' + CONTEST_UNDERWAY = '0' class ContestRuleType(Choices): - ACM = "ACM" - OI = "OI" + ACM = 'ACM' + OI = 'OI' class CacheKey: - waiting_queue = "waiting_queue" - contest_rank_cache = "contest_rank_cache" - website_config = "website_config" + waiting_queue = 'waiting_queue' + contest_rank_cache = 'contest_rank_cache' + website_config = 'website_config' class Difficulty(Choices): - LOW = "Low" - MID = "Mid" - HIGH = "High" + LOW = 'Low' + MID = 'Mid' + HIGH = 'High' -CONTEST_PASSWORD_SESSION_KEY = "contest_password" +CONTEST_PASSWORD_SESSION_KEY = 'contest_password' diff --git a/utils/management/commands/inituser.py b/utils/management/commands/inituser.py index e77693675..2efe823d0 100644 --- a/utils/management/commands/inituser.py +++ b/utils/management/commands/inituser.py @@ -6,39 +6,48 @@ class Command(BaseCommand): def add_arguments(self, parser): - parser.add_argument("--username", type=str) - parser.add_argument("--password", type=str) - parser.add_argument("--action", type=str) + parser.add_argument('--username', type=str) + parser.add_argument('--password', type=str) + parser.add_argument('--action', type=str) def handle(self, *args, **options): - username = options["username"] - password = options["password"] - action = options["action"] + username = options['username'] + password = options['password'] + action = options['action'] - if not(username and password and action): - self.stdout.write(self.style.ERROR("Invalid args")) + if not (username and password and action): + self.stdout.write(self.style.ERROR('Invalid args')) exit(1) - if action == "create_super_admin": + if action == 'create_super_admin': if User.objects.filter(id=1).exists(): - self.stdout.write(self.style.SUCCESS(f"User {username} exists, operation ignored")) + self.stdout.write( + self.style.SUCCESS(f'User {username} exists, operation ignored') + ) exit() - user = User.objects.create(username=username, admin_type=AdminType.SUPER_ADMIN, - problem_permission=ProblemPermission.ALL) + user = User.objects.create( + username=username, + admin_type=AdminType.SUPER_ADMIN, + problem_permission=ProblemPermission.ALL, + ) user.set_password(password) user.save() UserProfile.objects.create(user=user) - self.stdout.write(self.style.SUCCESS("User created")) - elif action == "reset": + self.stdout.write(self.style.SUCCESS('User created')) + elif action == 'reset': try: user = User.objects.get(username=username) user.set_password(password) user.save() - self.stdout.write(self.style.SUCCESS("Password is rested")) + self.stdout.write(self.style.SUCCESS('Password is rested')) except User.DoesNotExist: - self.stdout.write(self.style.ERROR(f"User {username} doesnot exist, operation ignored")) + self.stdout.write( + self.style.ERROR( + f'User {username} doesnot exist, operation ignored' + ) + ) exit(1) else: - raise ValueError("Invalid action") + raise ValueError('Invalid action') diff --git a/utils/migrate_data.py b/utils/migrate_data.py index b9b9bd306..73d307921 100644 --- a/utils/migrate_data.py +++ b/utils/migrate_data.py @@ -7,8 +7,8 @@ import hashlib from json.decoder import JSONDecodeError -sys.path.append("../") -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings") +sys.path.append('../') +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'oj.settings') django.setup() from django.conf import settings from account.models import User, UserProfile, AdminType, ProblemPermission @@ -17,14 +17,10 @@ admin_type_map = { 0: AdminType.REGULAR_USER, 1: AdminType.ADMIN, - 2: AdminType.SUPER_ADMIN + 2: AdminType.SUPER_ADMIN, } -languages_map = { - 1: "C", - 2: "C++", - 3: "Java" -} -email_regex = re.compile(r"(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)") +languages_map = {1: 'C', 2: 'C++', 3: 'Java'} +email_regex = re.compile(r'(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)') # pk -> name tags = {} @@ -37,25 +33,28 @@ def get_input_result(): while True: resp = input() - if resp not in ["yes", "no"]: - print("Please input yes or no") + if resp not in ['yes', 'no']: + print('Please input yes or no') continue - return resp == "yes" + return resp == 'yes' def set_problem_display_id_prefix(): while True: - print("Please input a prefix which will be used in all the imported problem's displayID") + print( + "Please input a prefix which will be used in all the imported problem's displayID" + ) print( "For example, if your input is 'old'(no quote), the problems' display id will be old1, old2, old3..\ninput:", - end="") + end='', + ) resp = input() if resp.strip(): return resp.strip() else: - print("Empty prefix detected, sure to do that? (yes/no)") + print('Empty prefix detected, sure to do that? (yes/no)') if get_input_result(): - return "" + return '' def get_stripped_output_md5(test_case_id, output_name): @@ -65,148 +64,158 @@ def get_stripped_output_md5(test_case_id, output_name): def get_test_case_score(test_case_id): - info_path = os.path.join(settings.TEST_CASE_DIR, test_case_id, "info") + info_path = os.path.join(settings.TEST_CASE_DIR, test_case_id, 'info') if not os.path.exists(info_path): return [] - with open(info_path, "r") as info_file: + with open(info_path, 'r') as info_file: info = json.load(info_file) test_case_score = [] need_rewrite = True - for test_case in info["test_cases"].values(): - if test_case.__contains__("stripped_output_md5"): + for test_case in info['test_cases'].values(): + if test_case.__contains__('stripped_output_md5'): need_rewrite = False - elif test_case.__contains__("striped_output_md5"): - test_case["stripped_output_md5"] = test_case.pop("striped_output_md5") + elif test_case.__contains__('striped_output_md5'): + test_case['stripped_output_md5'] = test_case.pop('striped_output_md5') else: - test_case["stripped_output_md5"] = get_stripped_output_md5(test_case_id, test_case["output_name"]) - test_case_score.append({"input_name": test_case["input_name"], - "output_name": test_case.get("output_name", "-"), - "score": 0}) + test_case['stripped_output_md5'] = get_stripped_output_md5( + test_case_id, test_case['output_name'] + ) + test_case_score.append( + { + 'input_name': test_case['input_name'], + 'output_name': test_case.get('output_name', '-'), + 'score': 0, + } + ) if need_rewrite: - with open(info_path, "w") as f: + with open(info_path, 'w') as f: f.write(json.dumps(info)) return test_case_score def import_users(): i = 0 - print("Find %d users in old data." % len(users.keys())) - print("import users now? (yes/no)") + print('Find %d users in old data.' % len(users.keys())) + print('import users now? (yes/no)') if get_input_result(): for data in users.values(): - if not email_regex.match(data["email"]): - print("%s will not be created due to invalid email: %s" % (data["username"], data["email"])) + if not email_regex.match(data['email']): + print( + '%s will not be created due to invalid email: %s' + % (data['username'], data['email']) + ) continue - data["username"] = data["username"].lower() - user, created = User.objects.get_or_create(username=data["username"]) + data['username'] = data['username'].lower() + user, created = User.objects.get_or_create(username=data['username']) if not created: - print("%s already exists, omitted" % user.username) + print('%s already exists, omitted' % user.username) continue - user.password = data["password"] - user.email = data["email"] - admin_type = admin_type_map[data["admin_type"]] + user.password = data['password'] + user.email = data['email'] + admin_type = admin_type_map[data['admin_type']] user.admin_type = admin_type if admin_type == AdminType.ADMIN: user.problem_permission = ProblemPermission.OWN elif admin_type == AdminType.SUPER_ADMIN: user.problem_permission = ProblemPermission.ALL user.save() - UserProfile.objects.create(user=user, real_name=data["real_name"]) + UserProfile.objects.create(user=user, real_name=data['real_name']) i += 1 - print("%s imported successfully" % user.username) - print("%d users have successfully imported\n" % i) + print('%s imported successfully' % user.username) + print('%d users have successfully imported\n' % i) def import_tags(): i = 0 - print("\nFind these tags in old data:") - print(", ".join(tags.values()), '\n') - print("import tags now? (yes/no)") + print('\nFind these tags in old data:') + print(', '.join(tags.values()), '\n') + print('import tags now? (yes/no)') if get_input_result(): for tagname in tags.values(): tag, created = ProblemTag.objects.get_or_create(name=tagname) if not created: - print("%s already exists, omitted" % tagname) + print('%s already exists, omitted' % tagname) else: - print("%s tag created successfully" % tagname) + print('%s tag created successfully' % tagname) i += 1 - print("%d tags have successfully imported\n" % i) + print('%d tags have successfully imported\n' % i) else: - print("Problem depends on problem_tags and users, exit..") + print('Problem depends on problem_tags and users, exit..') exit(1) def import_problems(): i = 0 - print("\nFind %d problems in old data" % len(problems)) + print('\nFind %d problems in old data' % len(problems)) prefix = set_problem_display_id_prefix() - print("import problems using prefix: %s? (yes/no)" % prefix) + print('import problems using prefix: %s? (yes/no)' % prefix) if get_input_result(): default_creator = User.objects.first() for data in problems: - data["_id"] = prefix + str(data.pop("id")) - if Problem.objects.filter(_id=data["_id"]).exists(): - print("%s has the same display_id with the db problem" % data["title"]) + data['_id'] = prefix + str(data.pop('id')) + if Problem.objects.filter(_id=data['_id']).exists(): + print('%s has the same display_id with the db problem' % data['title']) continue try: - creator_id = \ - User.objects.filter(username=users[data["created_by"]]["username"]).values_list("id", flat=True)[0] + creator_id = User.objects.filter( + username=users[data['created_by']]['username'] + ).values_list('id', flat=True)[0] except (User.DoesNotExist, IndexError): - print("The origin creator does not exist, set it to default_creator") + print('The origin creator does not exist, set it to default_creator') creator_id = default_creator.id - data["created_by_id"] = creator_id - data.pop("created_by") - data["difficulty"] = ProblemDifficulty.Mid - if data["spj_language"]: - data["spj_language"] = languages_map[data["spj_language"]] - data["samples"] = json.loads(data["samples"]) - data["languages"] = ["C", "C++"] - test_case_score = get_test_case_score(data["test_case_id"]) + data['created_by_id'] = creator_id + data.pop('created_by') + data['difficulty'] = ProblemDifficulty.Mid + if data['spj_language']: + data['spj_language'] = languages_map[data['spj_language']] + data['samples'] = json.loads(data['samples']) + data['languages'] = ['C', 'C++'] + test_case_score = get_test_case_score(data['test_case_id']) if not test_case_score: - print("%s test_case files don't exist, omitted" % data["title"]) + print("%s test_case files don't exist, omitted" % data['title']) continue - data["test_case_score"] = test_case_score - data["rule_type"] = ProblemRuleType.ACM - data["template"] = {} - data.pop("total_submit_number") - data.pop("total_accepted_number") - tag_ids = data.pop("tags") + data['test_case_score'] = test_case_score + data['rule_type'] = ProblemRuleType.ACM + data['template'] = {} + data.pop('total_submit_number') + data.pop('total_accepted_number') + tag_ids = data.pop('tags') problem = Problem.objects.create(**data) - problem.create_time = data["create_time"] + problem.create_time = data['create_time'] problem.save() for tag_id in tag_ids: tag, _ = ProblemTag.objects.get_or_create(name=tags[tag_id]) problem.tags.add(tag) i += 1 - print("%s imported successfully" % data["title"]) - print("%d problems have successfully imported" % i) + print('%s imported successfully' % data['title']) + print('%d problems have successfully imported' % i) -if __name__ == "__main__": +if __name__ == '__main__': if len(sys.argv) == 1: - print("Usage: python3 %s [old_data_path]" % sys.argv[0]) + print('Usage: python3 %s [old_data_path]' % sys.argv[0]) exit(0) data_path = sys.argv[1] if not os.path.isfile(data_path): - print("Data file does not exist") + print('Data file does not exist') exit(1) try: - with open(data_path, "r") as data_file: + with open(data_path, 'r') as data_file: old_data = json.load(data_file) except JSONDecodeError: print("Data file format error, ensure it's a valid json file!") exit(1) - print("Read old data successfully.\n") + print('Read old data successfully.\n') for obj in old_data: - if obj["model"] == "problem.problemtag": - tags[obj["pk"]] = obj["fields"]["name"] - elif obj["model"] == "account.user": - users[obj["pk"]] = obj["fields"] - elif obj["model"] == "problem.problem": - obj["fields"]["id"] = obj["pk"] - problems.append(obj["fields"]) + if obj['model'] == 'problem.problemtag': + tags[obj['pk']] = obj['fields']['name'] + elif obj['model'] == 'account.user': + users[obj['pk']] = obj['fields'] + elif obj['model'] == 'problem.problem': + obj['fields']['id'] = obj['pk'] + problems.append(obj['fields']) import_users() import_tags() import_problems() diff --git a/utils/models.py b/utils/models.py index 7577f1fab..e5f6936c0 100644 --- a/utils/models.py +++ b/utils/models.py @@ -7,4 +7,4 @@ class RichTextField(models.TextField): def get_prep_value(self, value): with XSSHtml() as parser: - return parser.clean(value or "") + return parser.clean(value or '') diff --git a/utils/serializers.py b/utils/serializers.py index c543c56f1..2de377292 100644 --- a/utils/serializers.py +++ b/utils/serializers.py @@ -5,7 +5,7 @@ class InvalidLanguage(serializers.ValidationError): def __init__(self, name): - super().__init__(detail=f"{name} is not a valid language") + super().__init__(detail=f'{name} is not a valid language') class LanguageNameChoiceField(serializers.CharField): diff --git a/utils/shortcuts.py b/utils/shortcuts.py index 84e14fd28..05ad5f4f7 100644 --- a/utils/shortcuts.py +++ b/utils/shortcuts.py @@ -9,51 +9,60 @@ from envelopes import Envelope -def rand_str(length=32, type="lower_hex"): +def rand_str(length=32, type='lower_hex'): """ 生成指定长度的随机字符串或者数字, 可以用于密钥等安全场景 :param length: 字符串或者数字的长度 :param type: str 代表随机字符串,num 代表随机数字 :return: 字符串 """ - if type == "str": - return get_random_string(length, allowed_chars="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") - elif type == "lower_str": - return get_random_string(length, allowed_chars="abcdefghijklmnopqrstuvwxyz0123456789") - elif type == "lower_hex": - return random.choice("123456789abcdef") + get_random_string(length - 1, allowed_chars="0123456789abcdef") + if type == 'str': + return get_random_string( + length, + allowed_chars='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789', + ) + elif type == 'lower_str': + return get_random_string( + length, allowed_chars='abcdefghijklmnopqrstuvwxyz0123456789' + ) + elif type == 'lower_hex': + return random.choice('123456789abcdef') + get_random_string( + length - 1, allowed_chars='0123456789abcdef' + ) else: - return random.choice("123456789") + get_random_string(length - 1, allowed_chars="0123456789") + return random.choice('123456789') + get_random_string( + length - 1, allowed_chars='0123456789' + ) def build_query_string(kv_data, ignore_none=True): # {"a": 1, "b": "test"} -> "?a=1&b=test" - query_string = "" + query_string = '' for k, v in kv_data.items(): if ignore_none is True and kv_data[k] is None: continue - if query_string != "": - query_string += "&" + if query_string != '': + query_string += '&' else: - query_string = "?" - query_string += (k + "=" + str(v)) + query_string = '?' + query_string += k + '=' + str(v) return query_string def img2base64(img): with BytesIO() as buf: - img.save(buf, "gif") + img.save(buf, 'gif') buf_str = buf.getvalue() - img_prefix = "data:image/png;base64," - b64_str = img_prefix + b64encode(buf_str).decode("utf-8") + img_prefix = 'data:image/png;base64,' + b64_str = img_prefix + b64encode(buf_str).decode('utf-8') return b64_str -def datetime2str(value, format="iso-8601"): - if format.lower() == "iso-8601": +def datetime2str(value, format='iso-8601'): + if format.lower() == 'iso-8601': value = value.isoformat() - if value.endswith("+00:00"): - value = value[:-6] + "Z" + if value.endswith('+00:00'): + value = value[:-6] + 'Z' return value return value.strftime(format) @@ -62,29 +71,34 @@ def timestamp2utcstr(value): return datetime.datetime.utcfromtimestamp(value).isoformat() -def natural_sort_key(s, _nsre=re.compile(r"(\d+)")): - return [int(text) if text.isdigit() else text.lower() - for text in re.split(_nsre, s)] +def natural_sort_key(s, _nsre=re.compile(r'(\d+)')): + return [ + int(text) if text.isdigit() else text.lower() for text in re.split(_nsre, s) + ] def send_email(smtp_config, from_name, to_email, to_name, subject, content): - envelope = Envelope(from_addr=(smtp_config["email"], from_name), - to_addr=(to_email, to_name), - subject=subject, - html_body=content) - return envelope.send(smtp_config["server"], - login=smtp_config["email"], - password=smtp_config["password"], - port=smtp_config["port"], - tls=smtp_config["tls"]) - - -def get_env(name, default=""): + envelope = Envelope( + from_addr=(smtp_config['email'], from_name), + to_addr=(to_email, to_name), + subject=subject, + html_body=content, + ) + return envelope.send( + smtp_config['server'], + login=smtp_config['email'], + password=smtp_config['password'], + port=smtp_config['port'], + tls=smtp_config['tls'], + ) + + +def get_env(name, default=''): return os.environ.get(name, default) def DRAMATIQ_WORKER_ARGS(time_limit=3600_000, max_retries=0, max_age=7200_000): - return {"max_retries": max_retries, "time_limit": time_limit, "max_age": max_age} + return {'max_retries': max_retries, 'time_limit': time_limit, 'max_age': max_age} def check_is_id(value): diff --git a/utils/throttling.py b/utils/throttling.py index bab1184a7..62cae16b9 100644 --- a/utils/throttling.py +++ b/utils/throttling.py @@ -5,6 +5,7 @@ class TokenBucket: """ 注意:对于单个key的操作不是线程安全的 """ + def __init__(self, key, capacity, fill_rate, default_capacity, redis_conn): """ :param capacity: 最大容量 @@ -18,8 +19,8 @@ def __init__(self, key, capacity, fill_rate, default_capacity, redis_conn): self._default_capacity = default_capacity self._redis_conn = redis_conn - self._last_capacity_key = "last_capacity" - self._last_timestamp_key = "last_timestamp" + self._last_capacity_key = 'last_capacity' + self._last_timestamp_key = 'last_timestamp' def _init_key(self): self._last_capacity = self._default_capacity diff --git a/utils/urls.py b/utils/urls.py index 7e0128e5c..bd99df08f 100644 --- a/utils/urls.py +++ b/utils/urls.py @@ -3,6 +3,6 @@ from .views import SimditorImageUploadAPIView, SimditorFileUploadAPIView urlpatterns = [ - url(r"^upload_image/?$", SimditorImageUploadAPIView.as_view(), name="upload_image"), - url(r"^upload_file/?$", SimditorFileUploadAPIView.as_view(), name="upload_file") + url(r'^upload_image/?$', SimditorImageUploadAPIView.as_view(), name='upload_image'), + url(r'^upload_file/?$', SimditorFileUploadAPIView.as_view(), name='upload_file'), ] diff --git a/utils/views.py b/utils/views.py index ce30c717c..acb1e3abb 100644 --- a/utils/views.py +++ b/utils/views.py @@ -14,34 +14,34 @@ class SimditorImageUploadAPIView(CSRFExemptAPIView): def post(self, request): form = ImageUploadForm(request.POST, request.FILES) if form.is_valid(): - img = form.cleaned_data["image"] + img = form.cleaned_data['image'] else: - return self.response({ - "success": False, - "msg": "Upload failed", - "file_path": ""}) + return self.response( + {'success': False, 'msg': 'Upload failed', 'file_path': ''} + ) suffix = os.path.splitext(img.name)[-1].lower() - if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]: - return self.response({ - "success": False, - "msg": "Unsupported file format", - "file_path": ""}) + if suffix not in ['.gif', '.jpg', '.jpeg', '.bmp', '.png']: + return self.response( + {'success': False, 'msg': 'Unsupported file format', 'file_path': ''} + ) img_name = rand_str(10) + suffix try: - with open(os.path.join(settings.UPLOAD_DIR, img_name), "wb") as imgFile: + with open(os.path.join(settings.UPLOAD_DIR, img_name), 'wb') as imgFile: for chunk in img: imgFile.write(chunk) except IOError as e: logger.error(e) - return self.response({ - "success": False, - "msg": "Upload Error", - "file_path": ""}) - return self.response({ - "success": True, - "msg": "Success", - "file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"}) + return self.response( + {'success': False, 'msg': 'Upload Error', 'file_path': ''} + ) + return self.response( + { + 'success': True, + 'msg': 'Success', + 'file_path': f'{settings.UPLOAD_PREFIX}/{img_name}', + } + ) class SimditorFileUploadAPIView(CSRFExemptAPIView): @@ -50,26 +50,24 @@ class SimditorFileUploadAPIView(CSRFExemptAPIView): def post(self, request): form = FileUploadForm(request.POST, request.FILES) if form.is_valid(): - file = form.cleaned_data["file"] + file = form.cleaned_data['file'] else: - return self.response({ - "success": False, - "msg": "Upload failed" - }) + return self.response({'success': False, 'msg': 'Upload failed'}) suffix = os.path.splitext(file.name)[-1].lower() file_name = rand_str(10) + suffix try: - with open(os.path.join(settings.UPLOAD_DIR, file_name), "wb") as f: + with open(os.path.join(settings.UPLOAD_DIR, file_name), 'wb') as f: for chunk in file: f.write(chunk) except IOError as e: logger.error(e) - return self.response({ - "success": False, - "msg": "Upload Error"}) - return self.response({ - "success": True, - "msg": "Success", - "file_path": f"{settings.UPLOAD_PREFIX}/{file_name}", - "file_name": file.name}) + return self.response({'success': False, 'msg': 'Upload Error'}) + return self.response( + { + 'success': True, + 'msg': 'Success', + 'file_path': f'{settings.UPLOAD_PREFIX}/{file_name}', + 'file_name': file.name, + } + ) diff --git a/utils/xss_filter.py b/utils/xss_filter.py index fe4f7aa8c..904f29a1c 100644 --- a/utils/xss_filter.py +++ b/utils/xss_filter.py @@ -31,19 +31,63 @@ class XSSHtml(HTMLParser): - allow_tags = ['a', 'img', 'br', 'strong', 'b', 'code', 'pre', - 'p', 'div', 'em', 'span', 'h1', 'h2', 'h3', 'h4', - 'h5', 'h6', 'blockquote', 'ul', 'ol', 'tr', 'th', 'td', - 'hr', 'li', 'u', 'embed', 's', 'table', 'thead', 'tbody', - 'caption', 'small', 'q', 'sup', 'sub', 'font'] - common_attrs = ["style", "class", "name"] - nonend_tags = ["img", "hr", "br", "embed"] + allow_tags = [ + 'a', + 'img', + 'br', + 'strong', + 'b', + 'code', + 'pre', + 'p', + 'div', + 'em', + 'span', + 'h1', + 'h2', + 'h3', + 'h4', + 'h5', + 'h6', + 'blockquote', + 'ul', + 'ol', + 'tr', + 'th', + 'td', + 'hr', + 'li', + 'u', + 'embed', + 's', + 'table', + 'thead', + 'tbody', + 'caption', + 'small', + 'q', + 'sup', + 'sub', + 'font', + ] + common_attrs = ['style', 'class', 'name'] + nonend_tags = ['img', 'hr', 'br', 'embed'] tags_own_attrs = { - "img": ["src", "width", "height", "alt", "align"], - "a": ["href", "target", "rel", "title"], - "embed": ["src", "width", "height", "type", "allowfullscreen", "loop", "play", "wmode", "menu"], - "table": ["border", "cellpadding", "cellspacing"], - "font": ["color"] + 'img': ['src', 'width', 'height', 'alt', 'align'], + 'a': ['href', 'target', 'rel', 'title'], + 'embed': [ + 'src', + 'width', + 'height', + 'type', + 'allowfullscreen', + 'loop', + 'play', + 'wmode', + 'menu', + ], + 'table': ['border', 'cellpadding', 'cellspacing'], + 'font': ['color'], } def __init__(self, allows=[]): @@ -86,13 +130,13 @@ def handle_starttag(self, tag, attrs): attdict[attr[0]] = attr[1] attdict = self._wash_attr(attdict, tag) - if hasattr(self, "node_%s" % tag): - attdict = getattr(self, "node_%s" % tag)(attdict) + if hasattr(self, 'node_%s' % tag): + attdict = getattr(self, 'node_%s' % tag)(attdict) else: attdict = self.node_default(attdict) attrs = [] - for (key, value) in attdict.items(): + for key, value in attdict.items(): attrs.append('%s="%s"' % (key, self._htmlspecialchars(value))) attrs = (' ' + ' '.join(attrs)) if attrs else '' self.result.append('<' + tag + attrs + end_diagonal + '>') @@ -107,11 +151,11 @@ def handle_data(self, data): def handle_entityref(self, name): if name.isalpha(): - self.result.append("&%s;" % name) + self.result.append('&%s;' % name) def handle_charref(self, name): if name.isdigit(): - self.result.append("%s;" % name) + self.result.append('%s;' % name) def node_default(self, attrs): attrs = self._common_attr(attrs) @@ -119,44 +163,45 @@ def node_default(self, attrs): def node_a(self, attrs): attrs = self._common_attr(attrs) - attrs = self._get_link(attrs, "href") - attrs = self._set_attr_default(attrs, "target", "_blank") - attrs = self._limit_attr(attrs, { - "target": ["_blank", "_self"] - }) + attrs = self._get_link(attrs, 'href') + attrs = self._set_attr_default(attrs, 'target', '_blank') + attrs = self._limit_attr(attrs, {'target': ['_blank', '_self']}) return attrs def node_embed(self, attrs): attrs = self._common_attr(attrs) - attrs = self._get_link(attrs, "src") - attrs = self._limit_attr(attrs, { - "type": ["application/x-shockwave-flash"], - "wmode": ["transparent", "window", "opaque"], - "play": ["true", "false"], - "loop": ["true", "false"], - "menu": ["true", "false"], - "allowfullscreen": ["true", "false"] - }) - attrs["allowscriptaccess"] = "never" - attrs["allownetworking"] = "none" + attrs = self._get_link(attrs, 'src') + attrs = self._limit_attr( + attrs, + { + 'type': ['application/x-shockwave-flash'], + 'wmode': ['transparent', 'window', 'opaque'], + 'play': ['true', 'false'], + 'loop': ['true', 'false'], + 'menu': ['true', 'false'], + 'allowfullscreen': ['true', 'false'], + }, + ) + attrs['allowscriptaccess'] = 'never' + attrs['allownetworking'] = 'none' return attrs def _true_url(self, url): - prog = re.compile(r"(^(http|https|ftp)://.+)|(^/)", re.I | re.S) + prog = re.compile(r'(^(http|https|ftp)://.+)|(^/)', re.I | re.S) if prog.match(url): return url else: - return "http://%s" % url + return 'http://%s' % url def _true_style(self, style): if style: - style = re.sub(r"(\\||/\*|\*/)", "_", style) - style = re.sub(r"e.*x.*p.*r.*e.*s.*s.*i.*o.*n", "_", style) + style = re.sub(r'(\\||/\*|\*/)', '_', style) + style = re.sub(r'e.*x.*p.*r.*e.*s.*s.*i.*o.*n', '_', style) return style def _get_style(self, attrs): - if "style" in attrs: - attrs["style"] = self._true_style(attrs.get("style")) + if 'style' in attrs: + attrs['style'] = self._true_style(attrs.get('style')) return attrs def _get_link(self, attrs, name): @@ -185,24 +230,28 @@ def _set_attr_default(self, attrs, name, default=''): return attrs def _limit_attr(self, attrs, limit={}): - for (key, value) in limit.items(): + for key, value in limit.items(): if key in attrs and attrs[key] not in value: del attrs[key] return attrs def _htmlspecialchars(self, html): - return html.replace("<", "<") \ - .replace(">", ">") \ - .replace('"', """) \ - .replace("'", "'") + return ( + html.replace('<', '<') + .replace('>', '>') + .replace('"', '"') + .replace("'", ''') + ) -if "__main__" == __name__: +if '__main__' == __name__: with XSSHtml() as parser: - ret = parser.clean(""">M
- """) + """ + ) print(ret)