diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 5af648026..000000000 --- a/.travis.yml +++ /dev/null @@ -1,17 +0,0 @@ -language: python -python: - - "3.8" -services: - - docker - - postgresql -install: - - pip install -r deploy/requirements.txt - - echo `cat /dev/urandom | head -1 | md5sum | head -c 32` > data/config/secret.key - - ./init_db.sh -script: - - docker ps -a - - flake8 --config=./.flake8 . - - coverage run --include="$PWD/*" manage.py test - - coverage report -notifications: - slack: onlinejudgeteam:BzBz8UFgmS5crpiblof17K2W diff --git a/account/decorators.py b/account/decorators.py index 0b6f236d4..107121e32 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -17,17 +17,17 @@ def __get__(self, obj, obj_type): return functools.partial(self.__call__, obj) def error(self, data): - return JSONResponse.response({"error": "permission-denied", "data": data}) + return JSONResponse.response({'error': 'permission-denied', 'data': data}) def __call__(self, *args, **kwargs): self.request = args[1] if self.check_permission(): if self.request.user.is_disabled: - return self.error("Your account is disabled") + return self.error('Your account is disabled') return self.func(*args, **kwargs) else: - return self.error("Please login first") + return self.error('Please login first') def check_permission(self): raise NotImplementedError() @@ -67,13 +67,18 @@ def check_contest_password(password, contest_password): else: # sig#timestamp 这种形式的密码也可以,但是在界面上没提供支持 # sig = sha256(contest_password + timestamp)[:8] - if "#" in password: - s = password.split("#") + if '#' in password: + s = password.split('#') if len(s) != 2: return False sig, ts = s[0], s[1] - if sig == hashlib.sha256((contest_password + ts).encode("utf-8")).hexdigest()[:8]: + if ( + sig + == hashlib.sha256((contest_password + ts).encode('utf-8')).hexdigest()[ + :8 + ] + ): try: ts = int(ts) except Exception: @@ -85,7 +90,7 @@ def check_contest_password(password, contest_password): return False -def check_contest_permission(check_type="details"): +def check_contest_permission(check_type='details'): """ 只供Class based view 使用,检查用户是否有权进入该contest, check_type 可选 details, problems, ranks, submissions 若通过验证,在view中可通过self.contest获得该contest @@ -96,22 +101,24 @@ def _check_permission(*args, **kwargs): self = args[0] request = args[1] user = request.user - if request.data.get("contest_id"): - contest_id = request.data["contest_id"] + if request.data.get('contest_id'): + contest_id = request.data['contest_id'] else: - contest_id = request.GET.get("contest_id") + contest_id = request.GET.get('contest_id') if not contest_id: - return self.error("Parameter error, contest_id is required") + return self.error('Parameter error, contest_id is required') try: # use self.contest to avoid query contest again in view. - self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True) + self.contest = Contest.objects.select_related('created_by').get( + id=contest_id, visible=True + ) except Contest.DoesNotExist: return self.error("Contest %s doesn't exist" % contest_id) # Anonymous if not user.is_authenticated: - return self.error("Please login first.") + return self.error('Please login first.') # creator or owner if user.is_contest_admin(self.contest): @@ -119,25 +126,40 @@ def _check_permission(*args, **kwargs): if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST: # password error - if not check_contest_password(request.session.get(CONTEST_PASSWORD_SESSION_KEY, {}).get(self.contest.id), self.contest.password): - return self.error("Wrong password or password expired") + if not check_contest_password( + request.session.get(CONTEST_PASSWORD_SESSION_KEY, {}).get( + self.contest.id + ), + self.contest.password, + ): + return self.error('Wrong password or password expired') # regular user get contest problems, ranks etc. before contest started - if self.contest.status == ContestStatus.CONTEST_NOT_START and check_type != "details": - return self.error("Contest has not started yet.") + if ( + self.contest.status == ContestStatus.CONTEST_NOT_START + and check_type != 'details' + ): + return self.error('Contest has not started yet.') # check does user have permission to get ranks, submissions in OI Contest - if self.contest.status == ContestStatus.CONTEST_UNDERWAY and self.contest.rule_type == ContestRuleType.OI: - if not self.contest.real_time_rank and (check_type == "ranks" or check_type == "submissions"): - return self.error(f"No permission to get {check_type}") + if ( + self.contest.status == ContestStatus.CONTEST_UNDERWAY + and self.contest.rule_type == ContestRuleType.OI + ): + if not self.contest.real_time_rank and ( + check_type == 'ranks' or check_type == 'submissions' + ): + return self.error(f'No permission to get {check_type}') return func(*args, **kwargs) + return _check_permission + return decorator def ensure_created_by(obj, user): - e = APIError(msg=f"{obj.__class__.__name__} does not exist") + e = APIError(msg=f'{obj.__class__.__name__} does not exist') if not user.is_admin_role(): raise e if user.is_super_admin(): diff --git a/account/middleware.py b/account/middleware.py index 05b732640..822768497 100644 --- a/account/middleware.py +++ b/account/middleware.py @@ -1,4 +1,3 @@ -from django.conf import settings from django.db import connection from django.utils.timezone import now from django.utils.deprecation import MiddlewareMixin @@ -9,12 +8,14 @@ class APITokenAuthMiddleware(MiddlewareMixin): def process_request(self, request): - appkey = request.META.get("HTTP_APPKEY") + appkey = request.META.get('HTTP_APPKEY') if appkey: try: - request.user = User.objects.get(open_api_appkey=appkey, open_api=True, is_disabled=False) + request.user = User.objects.get( + open_api_appkey=appkey, open_api=True, is_disabled=False + ) request.csrf_processing_done = True - request.auth_method = "api_key" + request.auth_method = 'api_key' except User.DoesNotExist: pass @@ -25,12 +26,12 @@ def process_request(self, request): if forwarded: request.ip = forwarded.split(',')[0].strip() else: - request.ip = request.META["REMOTE_ADDR"] + request.ip = request.META['REMOTE_ADDR'] if request.user.is_authenticated: session = request.session - session["user_agent"] = request.META.get("HTTP_USER_AGENT", "") - session["ip"] = request.ip - session["last_activity"] = now() + session['user_agent'] = request.META.get('HTTP_USER_AGENT', '') + session['ip'] = request.ip + session['last_activity'] = now() user_sessions = request.user.session_keys if session.session_key not in user_sessions: user_sessions.append(session.session_key) @@ -40,18 +41,20 @@ def process_request(self, request): class AdminRoleRequiredMiddleware(MiddlewareMixin): def process_request(self, request): path = request.path_info - if path.startswith("/admin/") or path.startswith("/api/admin/"): + if path.startswith('/admin/') or path.startswith('/api/admin/'): if not (request.user.is_authenticated and request.user.is_admin_role()): - return JSONResponse.response({"error": "login-required", "data": "Please login in first"}) + return JSONResponse.response( + {'error': 'login-required', 'data': 'Please login in first'} + ) class LogSqlMiddleware(MiddlewareMixin): def process_response(self, request, response): - print("\033[94m", "#" * 30, "\033[0m") + print('\033[94m', '#' * 30, '\033[0m') time_threshold = 0.03 for query in connection.queries: - if float(query["time"]) > time_threshold: - print("\033[93m", query, "\n", "-" * 30, "\033[0m") + if float(query['time']) > time_threshold: + print('\033[93m', query, '\n', '-' * 30, '\033[0m') else: - print(query, "\n", "-" * 30) + print(query, '\n', '-' * 30) return response diff --git a/account/migrations/0001_initial.py b/account/migrations/0001_initial.py index e1e588eec..06a5c46fb 100644 --- a/account/migrations/0001_initial.py +++ b/account/migrations/0001_initial.py @@ -11,19 +11,30 @@ class Migration(migrations.Migration): - initial = True - dependencies = [ - ] + dependencies = [] operations = [ migrations.CreateModel( name='User', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('password', models.CharField(max_length=128, verbose_name='password')), - ('last_login', models.DateTimeField(blank=True, null=True, verbose_name='last login')), + ( + 'last_login', + models.DateTimeField( + blank=True, null=True, verbose_name='last login' + ), + ), ('username', models.CharField(max_length=30, unique=True)), ('real_name', models.CharField(max_length=30, null=True)), ('email', models.EmailField(max_length=254, null=True)), @@ -48,20 +59,37 @@ class Migration(migrations.Migration): migrations.CreateModel( name='UserProfile', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('problems_status', jsonfield.fields.JSONField(default={})), - ('avatar', models.CharField(default="default.png", max_length=50)), + ('avatar', models.CharField(default='default.png', max_length=50)), ('blog', models.URLField(blank=True, null=True)), ('mood', models.CharField(blank=True, max_length=200, null=True)), ('accepted_problem_number', models.IntegerField(default=0)), ('submission_number', models.IntegerField(default=0)), - ('phone_number', models.CharField(blank=True, max_length=15, null=True)), + ( + 'phone_number', + models.CharField(blank=True, max_length=15, null=True), + ), ('school', models.CharField(blank=True, max_length=200, null=True)), ('major', models.CharField(blank=True, max_length=200, null=True)), ('student_id', models.CharField(blank=True, max_length=15, null=True)), ('time_zone', models.CharField(blank=True, max_length=32, null=True)), ('language', models.CharField(blank=True, max_length=32, null=True)), - ('user', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ( + 'user', + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), ], options={ 'db_table': 'user_profile', diff --git a/account/migrations/0002_auto_20170209_1028.py b/account/migrations/0002_auto_20170209_1028.py index c5698be22..2d59c815f 100644 --- a/account/migrations/0002_auto_20170209_1028.py +++ b/account/migrations/0002_auto_20170209_1028.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('account', '0001_initial'), ] diff --git a/account/migrations/0003_userprofile_total_score.py b/account/migrations/0003_userprofile_total_score.py index f7efe8835..dd12d6933 100644 --- a/account/migrations/0003_userprofile_total_score.py +++ b/account/migrations/0003_userprofile_total_score.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('account', '0002_auto_20170209_1028'), ] @@ -25,5 +24,5 @@ class Migration(migrations.Migration): migrations.RemoveField( model_name='userprofile', name='time_zone', - ) + ), ] diff --git a/account/migrations/0005_auto_20170830_1154.py b/account/migrations/0005_auto_20170830_1154.py index 1ba8a9473..d7fa9bb19 100644 --- a/account/migrations/0005_auto_20170830_1154.py +++ b/account/migrations/0005_auto_20170830_1154.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('account', '0003_userprofile_total_score'), ] diff --git a/account/migrations/0006_user_session_keys.py b/account/migrations/0006_user_session_keys.py index 6dc991a21..4ece08c85 100644 --- a/account/migrations/0006_user_session_keys.py +++ b/account/migrations/0006_user_session_keys.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('account', '0005_auto_20170830_1154'), ] diff --git a/account/migrations/0008_auto_20171011_1214.py b/account/migrations/0008_auto_20171011_1214.py index f27cac8e8..e3c0ef4cf 100644 --- a/account/migrations/0008_auto_20171011_1214.py +++ b/account/migrations/0008_auto_20171011_1214.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('account', '0006_user_session_keys'), ] @@ -70,7 +69,9 @@ class Migration(migrations.Migration): migrations.AlterField( model_name='userprofile', name='avatar', - field=models.CharField(default='/static/avatar/default.png', max_length=256), + field=models.CharField( + default='/static/avatar/default.png', max_length=256 + ), ), migrations.AlterField( model_name='userprofile', diff --git a/account/migrations/0009_auto_20171125_1514.py b/account/migrations/0009_auto_20171125_1514.py index b476b7892..b71bd88b3 100644 --- a/account/migrations/0009_auto_20171125_1514.py +++ b/account/migrations/0009_auto_20171125_1514.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('account', '0008_auto_20171011_1214'), ] @@ -15,6 +14,8 @@ class Migration(migrations.Migration): migrations.AlterField( model_name='userprofile', name='avatar', - field=models.CharField(default='/public/avatar/default.png', max_length=256), + field=models.CharField( + default='/public/avatar/default.png', max_length=256 + ), ), ] diff --git a/account/migrations/0010_auto_20180501_0436.py b/account/migrations/0010_auto_20180501_0436.py index 1e7d65c6f..f29dc9281 100644 --- a/account/migrations/0010_auto_20180501_0436.py +++ b/account/migrations/0010_auto_20180501_0436.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('account', '0009_auto_20171125_1514'), ] diff --git a/account/migrations/0011_auto_20180501_0456.py b/account/migrations/0011_auto_20180501_0456.py index 5f170fbda..8f0276b28 100644 --- a/account/migrations/0011_auto_20180501_0456.py +++ b/account/migrations/0011_auto_20180501_0456.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('account', '0010_auto_20180501_0436'), ] diff --git a/account/migrations/0012_userprofile_language.py b/account/migrations/0012_userprofile_language.py index a87948866..885412c98 100644 --- a/account/migrations/0012_userprofile_language.py +++ b/account/migrations/0012_userprofile_language.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('account', '0011_auto_20180501_0456'), ] diff --git a/account/models.py b/account/models.py index afd644edb..043b4ae8c 100644 --- a/account/models.py +++ b/account/models.py @@ -5,22 +5,22 @@ class AdminType(object): - REGULAR_USER = "Regular User" - ADMIN = "Admin" - SUPER_ADMIN = "Super Admin" + REGULAR_USER = 'Regular User' + ADMIN = 'Admin' + SUPER_ADMIN = 'Super Admin' class ProblemPermission(object): - NONE = "None" - OWN = "Own" - ALL = "All" + NONE = 'None' + OWN = 'Own' + ALL = 'All' class UserManager(models.Manager): use_in_migrations = True def get_by_natural_key(self, username): - return self.get(**{f"{self.model.USERNAME_FIELD}__iexact": username}) + return self.get(**{f'{self.model.USERNAME_FIELD}__iexact': username}) class User(AbstractBaseUser): @@ -42,7 +42,7 @@ class User(AbstractBaseUser): open_api_appkey = models.TextField(null=True) is_disabled = models.BooleanField(default=False) - USERNAME_FIELD = "username" + USERNAME_FIELD = 'username' REQUIRED_FIELDS = [] objects = UserManager() @@ -60,10 +60,12 @@ def can_mgmt_all_problem(self): return self.problem_permission == ProblemPermission.ALL def is_contest_admin(self, contest): - return self.is_authenticated and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN) + return self.is_authenticated and ( + contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN + ) class Meta: - db_table = "user" + db_table = 'user' class UserProfile(models.Model): @@ -88,7 +90,7 @@ class UserProfile(models.Model): oi_problems_status = JSONField(default=dict) real_name = models.TextField(null=True) - avatar = models.TextField(default=f"{settings.AVATAR_URI_PREFIX}/default.png") + avatar = models.TextField(default=f'{settings.AVATAR_URI_PREFIX}/default.png') blog = models.URLField(null=True) mood = models.TextField(null=True) github = models.TextField(null=True) @@ -102,18 +104,18 @@ class UserProfile(models.Model): submission_number = models.IntegerField(default=0) def add_accepted_problem_number(self): - self.accepted_number = models.F("accepted_number") + 1 + self.accepted_number = models.F('accepted_number') + 1 self.save() def add_submission_number(self): - self.submission_number = models.F("submission_number") + 1 + self.submission_number = models.F('submission_number') + 1 self.save() # 计算总分时, 应先减掉上次该题所得分数, 然后再加上本次所得分数 def add_score(self, this_time_score, last_time_score=None): last_time_score = last_time_score or 0 - self.total_score = models.F("total_score") - last_time_score + this_time_score + self.total_score = models.F('total_score') - last_time_score + this_time_score self.save() class Meta: - db_table = "user_profile" + db_table = 'user_profile' diff --git a/account/serializers.py b/account/serializers.py index faec66e2e..918791c08 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -45,7 +45,8 @@ class GenerateUserSerializer(serializers.Serializer): class ImportUserSeralizer(serializers.Serializer): users = serializers.ListField( - child=serializers.ListField(child=serializers.CharField(max_length=64))) + child=serializers.ListField(child=serializers.CharField(max_length=64)) + ) class UserAdminSerializer(serializers.ModelSerializer): @@ -53,8 +54,19 @@ class UserAdminSerializer(serializers.ModelSerializer): class Meta: model = User - fields = ["id", "username", "email", "admin_type", "problem_permission", "real_name", - "create_time", "last_login", "two_factor_auth", "open_api", "is_disabled"] + fields = [ + 'id', + 'username', + 'email', + 'admin_type', + 'problem_permission', + 'real_name', + 'create_time', + 'last_login', + 'two_factor_auth', + 'open_api', + 'is_disabled', + ] def get_real_name(self, obj): return obj.userprofile.real_name @@ -63,8 +75,18 @@ def get_real_name(self, obj): class UserSerializer(serializers.ModelSerializer): class Meta: model = User - fields = ["id", "username", "email", "admin_type", "problem_permission", - "create_time", "last_login", "two_factor_auth", "open_api", "is_disabled"] + fields = [ + 'id', + 'username', + 'email', + 'admin_type', + 'problem_permission', + 'create_time', + 'last_login', + 'two_factor_auth', + 'open_api', + 'is_disabled', + ] class UserProfileSerializer(serializers.ModelSerializer): @@ -73,10 +95,10 @@ class UserProfileSerializer(serializers.ModelSerializer): class Meta: model = UserProfile - fields = "__all__" + fields = '__all__' def __init__(self, *args, **kwargs): - self.show_real_name = kwargs.pop("show_real_name", False) + self.show_real_name = kwargs.pop('show_real_name', False) super(UserProfileSerializer, self).__init__(*args, **kwargs) def get_real_name(self, obj): @@ -87,11 +109,16 @@ class EditUserSerializer(serializers.Serializer): id = serializers.IntegerField() username = serializers.CharField(max_length=32) real_name = serializers.CharField(max_length=32, allow_blank=True, allow_null=True) - password = serializers.CharField(min_length=6, allow_blank=True, required=False, default=None) + password = serializers.CharField( + min_length=6, allow_blank=True, required=False, default=None + ) email = serializers.EmailField(max_length=64) - admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN)) - problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN, - ProblemPermission.ALL)) + admin_type = serializers.ChoiceField( + choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN) + ) + problem_permission = serializers.ChoiceField( + choices=(ProblemPermission.NONE, ProblemPermission.OWN, ProblemPermission.ALL) + ) open_api = serializers.BooleanField() two_factor_auth = serializers.BooleanField() is_disabled = serializers.BooleanField() @@ -140,4 +167,4 @@ class RankInfoSerializer(serializers.ModelSerializer): class Meta: model = UserProfile - fields = "__all__" + fields = '__all__' diff --git a/account/tasks.py b/account/tasks.py index 513599914..8707d6721 100644 --- a/account/tasks.py +++ b/account/tasks.py @@ -12,11 +12,13 @@ def send_email_async(from_name, to_email, to_name, subject, content): if not SysOptions.smtp_config: return try: - send_email(smtp_config=SysOptions.smtp_config, - from_name=from_name, - to_email=to_email, - to_name=to_name, - subject=subject, - content=content) + send_email( + smtp_config=SysOptions.smtp_config, + from_name=from_name, + to_email=to_email, + to_name=to_name, + subject=subject, + content=content, + ) except Exception as e: logger.exception(e) diff --git a/account/tests.py b/account/tests.py index c727f03ae..69c618156 100644 --- a/account/tests.py +++ b/account/tests.py @@ -18,9 +18,9 @@ class PermissionDecoratorTest(APITestCase): def setUp(self): - self.regular_user = User.objects.create(username="regular_user") - self.admin = User.objects.create(username="admin") - self.super_admin = User.objects.create(username="super_admin") + self.regular_user = User.objects.create(username='regular_user') + self.admin = User.objects.create(username='admin') + self.super_admin = User.objects.create(username='super_admin') self.request = mock.MagicMock() self.request.user.is_authenticated = mock.MagicMock() @@ -36,57 +36,59 @@ def test_super_admin_required(self): class DuplicateUserCheckAPITest(APITestCase): def setUp(self): - user = self.create_user("test", "test123", login=False) - user.email = "test@test.com" + user = self.create_user('test', 'test123', login=False) + user.email = 'test@test.com' user.save() - self.url = self.reverse("check_username_or_email") + self.url = self.reverse('check_username_or_email') def test_duplicate_username(self): - resp = self.client.post(self.url, data={"username": "test"}) - data = resp.data["data"] - self.assertEqual(data["username"], True) - resp = self.client.post(self.url, data={"username": "Test"}) - self.assertEqual(resp.data["data"]["username"], True) + resp = self.client.post(self.url, data={'username': 'test'}) + data = resp.data['data'] + self.assertEqual(data['username'], True) + resp = self.client.post(self.url, data={'username': 'Test'}) + self.assertEqual(resp.data['data']['username'], True) def test_ok_username(self): - resp = self.client.post(self.url, data={"username": "test1"}) - data = resp.data["data"] - self.assertFalse(data["username"]) + resp = self.client.post(self.url, data={'username': 'test1'}) + data = resp.data['data'] + self.assertFalse(data['username']) def test_duplicate_email(self): - resp = self.client.post(self.url, data={"email": "test@test.com"}) - self.assertEqual(resp.data["data"]["email"], True) - resp = self.client.post(self.url, data={"email": "Test@Test.com"}) - self.assertTrue(resp.data["data"]["email"]) + resp = self.client.post(self.url, data={'email': 'test@test.com'}) + self.assertEqual(resp.data['data']['email'], True) + resp = self.client.post(self.url, data={'email': 'Test@Test.com'}) + self.assertTrue(resp.data['data']['email']) def test_ok_email(self): - resp = self.client.post(self.url, data={"email": "aa@test.com"}) - self.assertFalse(resp.data["data"]["email"]) + resp = self.client.post(self.url, data={'email': 'aa@test.com'}) + self.assertFalse(resp.data['data']['email']) class TFARequiredCheckAPITest(APITestCase): def setUp(self): - self.url = self.reverse("tfa_required_check") - self.create_user("test", "test123", login=False) + self.url = self.reverse('tfa_required_check') + self.create_user('test', 'test123', login=False) def test_not_required_tfa(self): - resp = self.client.post(self.url, data={"username": "test"}) + resp = self.client.post(self.url, data={'username': 'test'}) self.assertSuccess(resp) - self.assertEqual(resp.data["data"]["result"], False) + self.assertEqual(resp.data['data']['result'], False) def test_required_tfa(self): user = User.objects.first() user.two_factor_auth = True user.save() - resp = self.client.post(self.url, data={"username": "test"}) - self.assertEqual(resp.data["data"]["result"], True) + resp = self.client.post(self.url, data={'username': 'test'}) + self.assertEqual(resp.data['data']['result'], True) class UserLoginAPITest(APITestCase): def setUp(self): - self.username = self.password = "test" - self.user = self.create_user(username=self.username, password=self.password, login=False) - self.login_url = self.reverse("user_login_api") + self.username = self.password = 'test' + self.user = self.create_user( + username=self.username, password=self.password, login=False + ) + self.login_url = self.reverse('user_login_api') def _set_tfa(self): self.user.two_factor_auth = True @@ -96,23 +98,31 @@ def _set_tfa(self): return tfa_token def test_login_with_correct_info(self): - response = self.client.post(self.login_url, - data={"username": self.username, "password": self.password}) - self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"}) + response = self.client.post( + self.login_url, data={'username': self.username, 'password': self.password} + ) + self.assertDictEqual(response.data, {'error': None, 'data': 'Succeeded'}) user = auth.get_user(self.client) self.assertTrue(user.is_authenticated) def test_login_with_correct_info_upper_username(self): - resp = self.client.post(self.login_url, data={"username": self.username.upper(), "password": self.password}) - self.assertDictEqual(resp.data, {"error": None, "data": "Succeeded"}) + resp = self.client.post( + self.login_url, + data={'username': self.username.upper(), 'password': self.password}, + ) + self.assertDictEqual(resp.data, {'error': None, 'data': 'Succeeded'}) user = auth.get_user(self.client) self.assertTrue(user.is_authenticated) def test_login_with_wrong_info(self): - response = self.client.post(self.login_url, - data={"username": self.username, "password": "invalid_password"}) - self.assertDictEqual(response.data, {"error": "error", "data": "Invalid username or password"}) + response = self.client.post( + self.login_url, + data={'username': self.username, 'password': 'invalid_password'}, + ) + self.assertDictEqual( + response.data, {'error': 'error', 'data': 'Invalid username or password'} + ) user = auth.get_user(self.client) self.assertFalse(user.is_authenticated) @@ -121,33 +131,44 @@ def test_tfa_login(self): token = self._set_tfa() code = OtpAuth(token).totp() if len(str(code)) < 6: - code = (6 - len(str(code))) * "0" + str(code) - response = self.client.post(self.login_url, - data={"username": self.username, - "password": self.password, - "tfa_code": code}) - self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"}) + code = (6 - len(str(code))) * '0' + str(code) + response = self.client.post( + self.login_url, + data={ + 'username': self.username, + 'password': self.password, + 'tfa_code': code, + }, + ) + self.assertDictEqual(response.data, {'error': None, 'data': 'Succeeded'}) user = auth.get_user(self.client) self.assertTrue(user.is_authenticated) def test_tfa_login_wrong_code(self): self._set_tfa() - response = self.client.post(self.login_url, - data={"username": self.username, - "password": self.password, - "tfa_code": "qqqqqq"}) - self.assertDictEqual(response.data, {"error": "error", "data": "Invalid two factor verification code"}) + response = self.client.post( + self.login_url, + data={ + 'username': self.username, + 'password': self.password, + 'tfa_code': 'qqqqqq', + }, + ) + self.assertDictEqual( + response.data, + {'error': 'error', 'data': 'Invalid two factor verification code'}, + ) user = auth.get_user(self.client) self.assertFalse(user.is_authenticated) def test_tfa_login_without_code(self): self._set_tfa() - response = self.client.post(self.login_url, - data={"username": self.username, - "password": self.password}) - self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"}) + response = self.client.post( + self.login_url, data={'username': self.username, 'password': self.password} + ) + self.assertDictEqual(response.data, {'error': 'error', 'data': 'tfa_required'}) user = auth.get_user(self.client) self.assertFalse(user.is_authenticated) @@ -155,16 +176,19 @@ def test_tfa_login_without_code(self): def test_user_disabled(self): self.user.is_disabled = True self.user.save() - resp = self.client.post(self.login_url, data={"username": self.username, - "password": self.password}) - self.assertDictEqual(resp.data, {"error": "error", "data": "Your account has been disabled"}) + resp = self.client.post( + self.login_url, data={'username': self.username, 'password': self.password} + ) + self.assertDictEqual( + resp.data, {'error': 'error', 'data': 'Your account has been disabled'} + ) class CaptchaTest(APITestCase): def _set_captcha(self, session): captcha = rand_str(4) - session["_django_captcha_key"] = captcha - session["_django_captcha_expires_time"] = int(time.time()) + 30 + session['_django_captcha_key'] = captcha + session['_django_captcha_expires_time'] = int(time.time()) + 30 session.save() return captcha @@ -172,60 +196,73 @@ def _set_captcha(self, session): class UserRegisterAPITest(CaptchaTest): def setUp(self): self.client = APIClient() - self.register_url = self.reverse("user_register_api") + self.register_url = self.reverse('user_register_api') self.captcha = rand_str(4) - self.data = {"username": "test_user", "password": "testuserpassword", - "real_name": "real_name", "email": "test@qduoj.com", - "captcha": self._set_captcha(self.client.session)} + self.data = { + 'username': 'test_user', + 'password': 'testuserpassword', + 'real_name': 'real_name', + 'email': 'test@qduoj.com', + 'captcha': self._set_captcha(self.client.session), + } def test_website_config_limit(self): SysOptions.allow_register = False resp = self.client.post(self.register_url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Register function has been disabled by admin"}) + self.assertDictEqual( + resp.data, + {'error': 'error', 'data': 'Register function has been disabled by admin'}, + ) def test_invalid_captcha(self): - self.data["captcha"] = "****" + self.data['captcha'] = '****' response = self.client.post(self.register_url, data=self.data) - self.assertDictEqual(response.data, {"error": "error", "data": "Invalid captcha"}) + self.assertDictEqual( + response.data, {'error': 'error', 'data': 'Invalid captcha'} + ) - self.data.pop("captcha") + self.data.pop('captcha') response = self.client.post(self.register_url, data=self.data) - self.assertTrue(response.data["error"] is not None) + self.assertTrue(response.data['error'] is not None) def test_register_with_correct_info(self): response = self.client.post(self.register_url, data=self.data) - self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"}) + self.assertDictEqual(response.data, {'error': None, 'data': 'Succeeded'}) def test_username_already_exists(self): self.test_register_with_correct_info() - self.data["captcha"] = self._set_captcha(self.client.session) - self.data["email"] = "test1@qduoj.com" + self.data['captcha'] = self._set_captcha(self.client.session) + self.data['email'] = 'test1@qduoj.com' response = self.client.post(self.register_url, data=self.data) - self.assertDictEqual(response.data, {"error": "error", "data": "Username already exists"}) + self.assertDictEqual( + response.data, {'error': 'error', 'data': 'Username already exists'} + ) def test_email_already_exists(self): self.test_register_with_correct_info() - self.data["captcha"] = self._set_captcha(self.client.session) - self.data["username"] = "test_user1" + self.data['captcha'] = self._set_captcha(self.client.session) + self.data['username'] = 'test_user1' response = self.client.post(self.register_url, data=self.data) - self.assertDictEqual(response.data, {"error": "error", "data": "Email already exists"}) + self.assertDictEqual( + response.data, {'error': 'error', 'data': 'Email already exists'} + ) class SessionManagementAPITest(APITestCase): def setUp(self): - self.create_user("test", "test123") - self.url = self.reverse("session_management_api") + self.create_user('test', 'test123') + self.url = self.reverse('session_management_api') # launch a request to provide session data - login_url = self.reverse("user_login_api") - self.client.post(login_url, data={"username": "test", "password": "test123"}) + login_url = self.reverse('user_login_api') + self.client.post(login_url, data={'username': 'test', 'password': 'test123'}) def test_get_sessions(self): resp = self.client.get(self.url) self.assertSuccess(resp) - data = resp.data["data"] + data = resp.data['data'] self.assertEqual(len(data), 1) # def test_delete_session_key(self): @@ -233,44 +270,50 @@ def test_get_sessions(self): # self.assertSuccess(resp) def test_delete_session_with_invalid_key(self): - resp = self.client.delete(self.url + "?session_key=aaaaaaaaaa") - self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid session_key"}) + resp = self.client.delete(self.url + '?session_key=aaaaaaaaaa') + self.assertDictEqual( + resp.data, {'error': 'error', 'data': 'Invalid session_key'} + ) class UserProfileAPITest(APITestCase): def setUp(self): - self.url = self.reverse("user_profile_api") + self.url = self.reverse('user_profile_api') def test_get_profile_without_login(self): resp = self.client.get(self.url) - self.assertDictEqual(resp.data, {"error": None, "data": None}) + self.assertDictEqual(resp.data, {'error': None, 'data': None}) def test_get_profile(self): - self.create_user("test", "test123") + self.create_user('test', 'test123') resp = self.client.get(self.url) self.assertSuccess(resp) def test_update_profile(self): - self.create_user("test", "test123") - update_data = {"real_name": "zemal", "submission_number": 233, "language": "en-US"} + self.create_user('test', 'test123') + update_data = { + 'real_name': 'zemal', + 'submission_number': 233, + 'language': 'en-US', + } resp = self.client.put(self.url, data=update_data) self.assertSuccess(resp) - data = resp.data["data"] - self.assertEqual(data["real_name"], "zemal") - self.assertEqual(data["submission_number"], 0) - self.assertEqual(data["language"], "en-US") + data = resp.data['data'] + self.assertEqual(data['real_name'], 'zemal') + self.assertEqual(data['submission_number'], 0) + self.assertEqual(data['language'], 'en-US') class TwoFactorAuthAPITest(APITestCase): def setUp(self): - self.url = self.reverse("two_factor_auth_api") - self.create_user("test", "test123") + self.url = self.reverse('two_factor_auth_api') + self.create_user('test', 'test123') def _get_tfa_code(self): user = User.objects.first() code = OtpAuth(user.tfa_token).totp() if len(str(code)) < 6: - code = (6 - len(str(code))) * "0" + str(code) + code = (6 - len(str(code))) * '0' + str(code) return code def test_get_image(self): @@ -279,43 +322,46 @@ def test_get_image(self): def test_open_tfa_with_invalid_code(self): self.test_get_image() - resp = self.client.post(self.url, data={"code": "000000"}) - self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"}) + resp = self.client.post(self.url, data={'code': '000000'}) + self.assertDictEqual(resp.data, {'error': 'error', 'data': 'Invalid code'}) def test_open_tfa_with_correct_code(self): self.test_get_image() code = self._get_tfa_code() - resp = self.client.post(self.url, data={"code": code}) + resp = self.client.post(self.url, data={'code': code}) self.assertSuccess(resp) user = User.objects.first() self.assertEqual(user.two_factor_auth, True) def test_close_tfa_with_invalid_code(self): self.test_open_tfa_with_correct_code() - resp = self.client.post(self.url, data={"code": "000000"}) - self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"}) + resp = self.client.post(self.url, data={'code': '000000'}) + self.assertDictEqual(resp.data, {'error': 'error', 'data': 'Invalid code'}) def test_close_tfa_with_correct_code(self): self.test_open_tfa_with_correct_code() code = self._get_tfa_code() - resp = self.client.put(self.url, data={"code": code}) + resp = self.client.put(self.url, data={'code': code}) self.assertSuccess(resp) user = User.objects.first() self.assertEqual(user.two_factor_auth, False) -@mock.patch("account.views.oj.send_email_async.send") +@mock.patch('account.views.oj.send_email_async.send') class ApplyResetPasswordAPITest(CaptchaTest): def setUp(self): - self.create_user("test", "test123", login=False) + self.create_user('test', 'test123', login=False) user = User.objects.first() - user.email = "test@oj.com" + user.email = 'test@oj.com' user.save() - self.url = self.reverse("apply_reset_password_api") - self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)} + self.url = self.reverse('apply_reset_password_api') + self.data = { + 'email': 'test@oj.com', + 'captcha': self._set_captcha(self.client.session), + } def _refresh_captcha(self): - self.data["captcha"] = self._set_captcha(self.client.session) + self.data['captcha'] = self._set_captcha(self.client.session) def test_apply_reset_password(self, send_email_send): resp = self.client.post(self.url, data=self.data) @@ -327,7 +373,13 @@ def test_apply_reset_password_twice_in_20_mins(self, send_email_send): send_email_send.reset_mock() self._refresh_captcha() resp = self.client.post(self.url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"}) + self.assertDictEqual( + resp.data, + { + 'error': 'error', + 'data': 'You can only reset password once per 20 minutes', + }, + ) send_email_send.assert_not_called() def test_apply_reset_password_again_after_20_mins(self, send_email_send): @@ -341,120 +393,146 @@ def test_apply_reset_password_again_after_20_mins(self, send_email_send): class ResetPasswordAPITest(CaptchaTest): def setUp(self): - self.create_user("test", "test123", login=False) - self.url = self.reverse("reset_password_api") + self.create_user('test', 'test123', login=False) + self.url = self.reverse('reset_password_api') user = User.objects.first() - user.reset_password_token = "online_judge?" + user.reset_password_token = 'online_judge?' user.reset_password_token_expire_time = now() + timedelta(minutes=20) user.save() - self.data = {"token": user.reset_password_token, - "captcha": self._set_captcha(self.client.session), - "password": "test456"} + self.data = { + 'token': user.reset_password_token, + 'captcha': self._set_captcha(self.client.session), + 'password': 'test456', + } def test_reset_password_with_correct_token(self): resp = self.client.post(self.url, data=self.data) self.assertSuccess(resp) - self.assertTrue(self.client.login(username="test", password="test456")) + self.assertTrue(self.client.login(username='test', password='test456')) def test_reset_password_with_invalid_token(self): - self.data["token"] = "aaaaaaaaaaa" + self.data['token'] = 'aaaaaaaaaaa' resp = self.client.post(self.url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Token does not exist"}) + self.assertDictEqual( + resp.data, {'error': 'error', 'data': 'Token does not exist'} + ) def test_reset_password_with_expired_token(self): user = User.objects.first() user.reset_password_token_expire_time = now() - timedelta(seconds=30) user.save() resp = self.client.post(self.url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Token has expired"}) + self.assertDictEqual(resp.data, {'error': 'error', 'data': 'Token has expired'}) class UserChangeEmailAPITest(APITestCase): def setUp(self): - self.url = self.reverse("user_change_email_api") - self.user = self.create_user("test", "test123") - self.new_mail = "test@oj.com" - self.data = {"password": "test123", "new_email": self.new_mail} + self.url = self.reverse('user_change_email_api') + self.user = self.create_user('test', 'test123') + self.new_mail = 'test@oj.com' + self.data = {'password': 'test123', 'new_email': self.new_mail} def test_change_email_success(self): resp = self.client.post(self.url, data=self.data) self.assertSuccess(resp) def test_wrong_password(self): - self.data["password"] = "aaaa" + self.data['password'] = 'aaaa' resp = self.client.post(self.url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"}) + self.assertDictEqual(resp.data, {'error': 'error', 'data': 'Wrong password'}) def test_duplicate_email(self): - u = self.create_user("aa", "bb", login=False) + u = self.create_user('aa', 'bb', login=False) u.email = self.new_mail u.save() resp = self.client.post(self.url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "The email is owned by other account"}) + self.assertDictEqual( + resp.data, {'error': 'error', 'data': 'The email is owned by other account'} + ) class UserChangePasswordAPITest(APITestCase): def setUp(self): - self.url = self.reverse("user_change_password_api") + self.url = self.reverse('user_change_password_api') # Create user at first - self.username = "test_user" - self.old_password = "testuserpassword" - self.new_password = "new_password" - self.user = self.create_user(username=self.username, password=self.old_password, login=False) + self.username = 'test_user' + self.old_password = 'testuserpassword' + self.new_password = 'new_password' + self.user = self.create_user( + username=self.username, password=self.old_password, login=False + ) - self.data = {"old_password": self.old_password, "new_password": self.new_password} + self.data = { + 'old_password': self.old_password, + 'new_password': self.new_password, + } def _get_tfa_code(self): user = User.objects.first() code = OtpAuth(user.tfa_token).totp() if len(str(code)) < 6: - code = (6 - len(str(code))) * "0" + str(code) + code = (6 - len(str(code))) * '0' + str(code) return code def test_login_required(self): response = self.client.post(self.url, data=self.data) - self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login first"}) + self.assertEqual( + response.data, {'error': 'permission-denied', 'data': 'Please login first'} + ) def test_valid_ola_password(self): - self.assertTrue(self.client.login(username=self.username, password=self.old_password)) + self.assertTrue( + self.client.login(username=self.username, password=self.old_password) + ) response = self.client.post(self.url, data=self.data) - self.assertEqual(response.data, {"error": None, "data": "Succeeded"}) - self.assertTrue(self.client.login(username=self.username, password=self.new_password)) + self.assertEqual(response.data, {'error': None, 'data': 'Succeeded'}) + self.assertTrue( + self.client.login(username=self.username, password=self.new_password) + ) def test_invalid_old_password(self): - self.assertTrue(self.client.login(username=self.username, password=self.old_password)) - self.data["old_password"] = "invalid" + self.assertTrue( + self.client.login(username=self.username, password=self.old_password) + ) + self.data['old_password'] = 'invalid' response = self.client.post(self.url, data=self.data) - self.assertEqual(response.data, {"error": "error", "data": "Invalid old password"}) + self.assertEqual( + response.data, {'error': 'error', 'data': 'Invalid old password'} + ) def test_tfa_code_required(self): self.user.two_factor_auth = True - self.user.tfa_token = "tfa_token" + self.user.tfa_token = 'tfa_token' self.user.save() - self.assertTrue(self.client.login(username=self.username, password=self.old_password)) - self.data["tfa_code"] = rand_str(6) + self.assertTrue( + self.client.login(username=self.username, password=self.old_password) + ) + self.data['tfa_code'] = rand_str(6) resp = self.client.post(self.url, data=self.data) - self.assertEqual(resp.data, {"error": "error", "data": "Invalid two factor verification code"}) + self.assertEqual( + resp.data, + {'error': 'error', 'data': 'Invalid two factor verification code'}, + ) - self.data["tfa_code"] = self._get_tfa_code() + self.data['tfa_code'] = self._get_tfa_code() resp = self.client.post(self.url, data=self.data) self.assertSuccess(resp) class UserRankAPITest(APITestCase): def setUp(self): - self.url = self.reverse("user_rank_api") - self.create_user("test1", "test123", login=False) - self.create_user("test2", "test123", login=False) - test1 = User.objects.get(username="test1") + self.url = self.reverse('user_rank_api') + self.create_user('test1', 'test123', login=False) + self.create_user('test2', 'test123', login=False) + test1 = User.objects.get(username='test1') profile1 = test1.userprofile profile1.submission_number = 10 profile1.accepted_number = 10 profile1.total_score = 240 profile1.save() - test2 = User.objects.get(username="test2") + test2 = User.objects.get(username='test2') profile2 = test2.userprofile profile2.submission_number = 15 profile2.accepted_number = 10 @@ -462,34 +540,34 @@ def setUp(self): profile2.save() def test_get_acm_rank(self): - resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM}) + resp = self.client.get(self.url, data={'rule': ContestRuleType.ACM}) self.assertSuccess(resp) - data = resp.data["data"]["results"] - self.assertEqual(data[0]["user"]["username"], "test1") - self.assertEqual(data[1]["user"]["username"], "test2") + data = resp.data['data']['results'] + self.assertEqual(data[0]['user']['username'], 'test1') + self.assertEqual(data[1]['user']['username'], 'test2') def test_get_oi_rank(self): - resp = self.client.get(self.url, data={"rule": ContestRuleType.OI}) + resp = self.client.get(self.url, data={'rule': ContestRuleType.OI}) self.assertSuccess(resp) - data = resp.data["data"]["results"] - self.assertEqual(data[0]["user"]["username"], "test2") - self.assertEqual(data[1]["user"]["username"], "test1") + data = resp.data['data']['results'] + self.assertEqual(data[0]['user']['username'], 'test2') + self.assertEqual(data[1]['user']['username'], 'test1') def test_admin_role_filted(self): - self.create_admin("admin", "admin123") - admin = User.objects.get(username="admin") + self.create_admin('admin', 'admin123') + admin = User.objects.get(username='admin') profile1 = admin.userprofile profile1.submission_number = 20 profile1.accepted_number = 5 profile1.total_score = 300 profile1.save() - resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM}) + resp = self.client.get(self.url, data={'rule': ContestRuleType.ACM}) self.assertSuccess(resp) - self.assertEqual(len(resp.data["data"]), 2) + self.assertEqual(len(resp.data['data']), 2) - resp = self.client.get(self.url, data={"rule": ContestRuleType.OI}) + resp = self.client.get(self.url, data={'rule': ContestRuleType.OI}) self.assertSuccess(resp) - self.assertEqual(len(resp.data["data"]), 2) + self.assertEqual(len(resp.data['data']), 2) class ProfileProblemDisplayIDRefreshAPITest(APITestCase): @@ -500,13 +578,22 @@ def setUp(self): class AdminUserTest(APITestCase): def setUp(self): self.user = self.create_super_admin(login=True) - self.username = self.password = "test" - self.regular_user = self.create_user(username=self.username, password=self.password, login=False) - self.url = self.reverse("user_admin_api") - self.data = {"id": self.regular_user.id, "username": self.username, "real_name": "test_name", - "email": "test@qq.com", "admin_type": AdminType.REGULAR_USER, - "problem_permission": ProblemPermission.OWN, "open_api": True, - "two_factor_auth": False, "is_disabled": False} + self.username = self.password = 'test' + self.regular_user = self.create_user( + username=self.username, password=self.password, login=False + ) + self.url = self.reverse('user_admin_api') + self.data = { + 'id': self.regular_user.id, + 'username': self.username, + 'real_name': 'test_name', + 'email': 'test@qq.com', + 'admin_type': AdminType.REGULAR_USER, + 'problem_permission': ProblemPermission.OWN, + 'open_api': True, + 'two_factor_auth': False, + 'is_disabled': False, + } def test_user_list(self): response = self.client.get(self.url) @@ -515,20 +602,20 @@ def test_user_list(self): def test_edit_user_successfully(self): response = self.client.put(self.url, data=self.data) self.assertSuccess(response) - resp_data = response.data["data"] - self.assertEqual(resp_data["username"], self.username) - self.assertEqual(resp_data["email"], "test@qq.com") - self.assertEqual(resp_data["open_api"], True) - self.assertEqual(resp_data["two_factor_auth"], False) - self.assertEqual(resp_data["is_disabled"], False) - self.assertEqual(resp_data["problem_permission"], ProblemPermission.NONE) + resp_data = response.data['data'] + self.assertEqual(resp_data['username'], self.username) + self.assertEqual(resp_data['email'], 'test@qq.com') + self.assertEqual(resp_data['open_api'], True) + self.assertEqual(resp_data['two_factor_auth'], False) + self.assertEqual(resp_data['is_disabled'], False) + self.assertEqual(resp_data['problem_permission'], ProblemPermission.NONE) - self.assertTrue(self.regular_user.check_password("test")) + self.assertTrue(self.regular_user.check_password('test')) def test_edit_user_password(self): data = self.data - new_password = "testpassword" - data["password"] = new_password + new_password = 'testpassword' + data['password'] = new_password response = self.client.put(self.url, data=data) self.assertSuccess(response) user = User.objects.get(id=self.regular_user.id) @@ -538,64 +625,72 @@ def test_edit_user_password(self): def test_edit_user_tfa(self): data = self.data self.assertIsNone(self.regular_user.tfa_token) - data["two_factor_auth"] = True + data['two_factor_auth'] = True response = self.client.put(self.url, data=data) self.assertSuccess(response) - resp_data = response.data["data"] + resp_data = response.data['data'] # if `tfa_token` is None, a new value will be generated - self.assertTrue(resp_data["two_factor_auth"]) + self.assertTrue(resp_data['two_factor_auth']) token = User.objects.get(id=self.regular_user.id).tfa_token self.assertIsNotNone(token) response = self.client.put(self.url, data=data) self.assertSuccess(response) - resp_data = response.data["data"] + resp_data = response.data['data'] # if `tfa_token` is not None, the value is not changed - self.assertTrue(resp_data["two_factor_auth"]) + self.assertTrue(resp_data['two_factor_auth']) self.assertEqual(User.objects.get(id=self.regular_user.id).tfa_token, token) def test_edit_user_openapi(self): data = self.data self.assertIsNone(self.regular_user.open_api_appkey) - data["open_api"] = True + data['open_api'] = True response = self.client.put(self.url, data=data) self.assertSuccess(response) - resp_data = response.data["data"] + resp_data = response.data['data'] # if `open_api_appkey` is None, a new value will be generated - self.assertTrue(resp_data["open_api"]) + self.assertTrue(resp_data['open_api']) key = User.objects.get(id=self.regular_user.id).open_api_appkey self.assertIsNotNone(key) response = self.client.put(self.url, data=data) self.assertSuccess(response) - resp_data = response.data["data"] + resp_data = response.data['data'] # if `openapi_app_key` is not None, the value is not changed - self.assertTrue(resp_data["open_api"]) + self.assertTrue(resp_data['open_api']) self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key) def test_import_users(self): - data = {"users": [["user1", "pass1", "eami1@e.com", "user1"], - ["user2", "pass3", "eamil3@e.com", "user2"]] - } + data = { + 'users': [ + ['user1', 'pass1', 'eami1@e.com', 'user1'], + ['user2', 'pass3', 'eamil3@e.com', 'user2'], + ] + } resp = self.client.post(self.url, data) self.assertSuccess(resp) # successfully created 2 users self.assertEqual(User.objects.all().count(), 4) def test_import_duplicate_user(self): - data = {"users": [["user1", "pass1", "eami1@e.com", "user1"], - ["user1", "pass1", "eami1@e.com", "user1"]] - } + data = { + 'users': [ + ['user1', 'pass1', 'eami1@e.com', 'user1'], + ['user1', 'pass1', 'eami1@e.com', 'user1'], + ] + } resp = self.client.post(self.url, data) - self.assertFailed(resp, "DETAIL: Key (username)=(user1) already exists.") + self.assertFailed(resp, 'DETAIL: Key (username)=(user1) already exists.') # no user is created self.assertEqual(User.objects.all().count(), 2) def test_delete_users(self): self.test_import_users() - user_ids = User.objects.filter(username__in=["user1", "user2"]).values_list("id", flat=True) - user_ids = ",".join([str(id) for id in user_ids]) - resp = self.client.delete(self.url + "?id=" + user_ids) + user_ids = User.objects.filter(username__in=['user1', 'user2']).values_list( + 'id', flat=True + ) + user_ids = ','.join([str(id) for id in user_ids]) + resp = self.client.delete(self.url + '?id=' + user_ids) self.assertSuccess(resp) self.assertEqual(User.objects.all().count(), 2) @@ -603,27 +698,33 @@ def test_delete_users(self): class GenerateUserAPITest(APITestCase): def setUp(self): self.create_super_admin() - self.url = self.reverse("generate_user_api") + self.url = self.reverse('generate_user_api') self.data = { - "number_from": 100, "number_to": 105, - "prefix": "pre", "suffix": "suf", - "default_email": "test@test.com", - "password_length": 8 + 'number_from': 100, + 'number_to': 105, + 'prefix': 'pre', + 'suffix': 'suf', + 'default_email': 'test@test.com', + 'password_length': 8, } def test_error_case(self): data = deepcopy(self.data) - data["prefix"] = "t" * 16 - data["suffix"] = "s" * 14 + data['prefix'] = 't' * 16 + data['suffix'] = 's' * 14 resp = self.client.post(self.url, data=data) - self.assertEqual(resp.data["data"], "Username should not more than 32 characters") + self.assertEqual( + resp.data['data'], 'Username should not more than 32 characters' + ) data2 = deepcopy(self.data) - data2["number_from"] = 106 + data2['number_from'] = 106 resp = self.client.post(self.url, data=data2) - self.assertEqual(resp.data["data"], "Start number must be lower than end number") + self.assertEqual( + resp.data['data'], 'Start number must be lower than end number' + ) - @mock.patch("account.views.admin.xlsxwriter.Workbook") + @mock.patch('account.views.admin.xlsxwriter.Workbook') def test_generate_user_success(self, mock_workbook): resp = self.client.post(self.url, data=self.data) self.assertSuccess(resp) @@ -633,7 +734,7 @@ def test_generate_user_success(self, mock_workbook): class OpenAPIAppkeyAPITest(APITestCase): def setUp(self): self.user = self.create_super_admin() - self.url = self.reverse("open_api_appkey_api") + self.url = self.reverse('open_api_appkey_api') def test_reset_appkey(self): resp = self.client.post(self.url, data={}) @@ -643,4 +744,7 @@ def test_reset_appkey(self): self.user.save() resp = self.client.post(self.url, data={}) self.assertSuccess(resp) - self.assertEqual(resp.data["data"]["appkey"], User.objects.get(username=self.user.username).open_api_appkey) + self.assertEqual( + resp.data['data']['appkey'], + User.objects.get(username=self.user.username).open_api_appkey, + ) diff --git a/account/urls/admin.py b/account/urls/admin.py index 5826ae2e6..07a503331 100644 --- a/account/urls/admin.py +++ b/account/urls/admin.py @@ -3,6 +3,6 @@ from ..views.admin import UserAdminAPI, GenerateUserAPI urlpatterns = [ - url(r"^user/?$", UserAdminAPI.as_view(), name="user_admin_api"), - url(r"^generate_user/?$", GenerateUserAPI.as_view(), name="generate_user_api"), + url(r'^user/?$', UserAdminAPI.as_view(), name='user_admin_api'), + url(r'^generate_user/?$', GenerateUserAPI.as_view(), name='generate_user_api'), ] diff --git a/account/urls/oj.py b/account/urls/oj.py index a2cebe024..eabedaccc 100644 --- a/account/urls/oj.py +++ b/account/urls/oj.py @@ -1,31 +1,62 @@ from django.conf.urls import url -from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI, - UserChangePasswordAPI, UserRegisterAPI, UserChangeEmailAPI, - UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck, - AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, - UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI, - ProfileProblemDisplayIDRefreshAPI, OpenAPIAppkeyAPI, SSOAPI) +from ..views.oj import ( + ApplyResetPasswordAPI, + ResetPasswordAPI, + UserChangePasswordAPI, + UserRegisterAPI, + UserChangeEmailAPI, + UserLoginAPI, + UserLogoutAPI, + UsernameOrEmailCheck, + AvatarUploadAPI, + TwoFactorAuthAPI, + UserProfileAPI, + UserRankAPI, + CheckTFARequiredAPI, + SessionManagementAPI, + ProfileProblemDisplayIDRefreshAPI, + OpenAPIAppkeyAPI, + SSOAPI, +) from utils.captcha.views import CaptchaAPIView urlpatterns = [ - url(r"^login/?$", UserLoginAPI.as_view(), name="user_login_api"), - url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"), - url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"), - url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"), - url(r"^change_email/?$", UserChangeEmailAPI.as_view(), name="user_change_email_api"), - url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"), - url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"), - url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"), - url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"), - url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"), - url(r"^profile/fresh_display_id", ProfileProblemDisplayIDRefreshAPI.as_view(), name="display_id_fresh"), - url(r"^upload_avatar/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"), - url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"), - url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"), - url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"), - url(r"^sessions/?$", SessionManagementAPI.as_view(), name="session_management_api"), - url(r"^open_api_appkey/?$", OpenAPIAppkeyAPI.as_view(), name="open_api_appkey_api"), - url(r"^sso?$", SSOAPI.as_view(), name="sso_api") + url(r'^login/?$', UserLoginAPI.as_view(), name='user_login_api'), + url(r'^logout/?$', UserLogoutAPI.as_view(), name='user_logout_api'), + url(r'^register/?$', UserRegisterAPI.as_view(), name='user_register_api'), + url( + r'^change_password/?$', + UserChangePasswordAPI.as_view(), + name='user_change_password_api', + ), + url( + r'^change_email/?$', UserChangeEmailAPI.as_view(), name='user_change_email_api' + ), + url( + r'^apply_reset_password/?$', + ApplyResetPasswordAPI.as_view(), + name='apply_reset_password_api', + ), + url(r'^reset_password/?$', ResetPasswordAPI.as_view(), name='reset_password_api'), + url(r'^captcha/?$', CaptchaAPIView.as_view(), name='show_captcha'), + url( + r'^check_username_or_email', + UsernameOrEmailCheck.as_view(), + name='check_username_or_email', + ), + url(r'^profile/?$', UserProfileAPI.as_view(), name='user_profile_api'), + url( + r'^profile/fresh_display_id', + ProfileProblemDisplayIDRefreshAPI.as_view(), + name='display_id_fresh', + ), + url(r'^upload_avatar/?$', AvatarUploadAPI.as_view(), name='avatar_upload_api'), + url(r'^tfa_required/?$', CheckTFARequiredAPI.as_view(), name='tfa_required_check'), + url(r'^two_factor_auth/?$', TwoFactorAuthAPI.as_view(), name='two_factor_auth_api'), + url(r'^user_rank/?$', UserRankAPI.as_view(), name='user_rank_api'), + url(r'^sessions/?$', SessionManagementAPI.as_view(), name='session_management_api'), + url(r'^open_api_appkey/?$', OpenAPIAppkeyAPI.as_view(), name='open_api_appkey_api'), + url(r'^sso?$', SSOAPI.as_view(), name='sso_api'), ] diff --git a/account/views/admin.py b/account/views/admin.py index b207db929..b65b8e2d1 100644 --- a/account/views/admin.py +++ b/account/views/admin.py @@ -13,7 +13,11 @@ from ..decorators import super_admin_required from ..models import AdminType, ProblemPermission, User, UserProfile -from ..serializers import EditUserSerializer, UserAdminSerializer, GenerateUserSerializer +from ..serializers import ( + EditUserSerializer, + UserAdminSerializer, + GenerateUserSerializer, +) from ..serializers import ImportUserSeralizer @@ -24,24 +28,35 @@ def post(self, request): """ Import User """ - data = request.data["users"] + data = request.data['users'] user_list = [] for user_data in data: if len(user_data) != 4 or len(user_data[0]) > 32: return self.error(f"Error occurred while processing data '{user_data}'") - user_list.append(User(username=user_data[0], password=make_password(user_data[1]), email=user_data[2])) + user_list.append( + User( + username=user_data[0], + password=make_password(user_data[1]), + email=user_data[2], + ) + ) try: with transaction.atomic(): ret = User.objects.bulk_create(user_list) - UserProfile.objects.bulk_create([UserProfile(user=ret[i], real_name=data[i][3]) for i in range(len(ret))]) + UserProfile.objects.bulk_create( + [ + UserProfile(user=ret[i], real_name=data[i][3]) + for i in range(len(ret)) + ] + ) return self.success() except IntegrityError as e: # Extract detail from exception message # duplicate key value violates unique constraint "user_username_key" # DETAIL: Key (username)=(root11) already exists. - return self.error(str(e).split("\n")[1]) + return self.error(str(e).split('\n')[1]) @validate_serializer(EditUserSerializer) @super_admin_required @@ -51,52 +66,62 @@ def put(self, request): """ data = request.data try: - user = User.objects.get(id=data["id"]) + user = User.objects.get(id=data['id']) except User.DoesNotExist: - return self.error("User does not exist") - if User.objects.filter(username=data["username"].lower()).exclude(id=user.id).exists(): - return self.error("Username already exists") - if User.objects.filter(email=data["email"].lower()).exclude(id=user.id).exists(): - return self.error("Email already exists") + return self.error('User does not exist') + if ( + User.objects.filter(username=data['username'].lower()) + .exclude(id=user.id) + .exists() + ): + return self.error('Username already exists') + if ( + User.objects.filter(email=data['email'].lower()) + .exclude(id=user.id) + .exists() + ): + return self.error('Email already exists') pre_username = user.username - user.username = data["username"].lower() - user.email = data["email"].lower() - user.admin_type = data["admin_type"] - user.is_disabled = data["is_disabled"] - - if data["admin_type"] == AdminType.ADMIN: - user.problem_permission = data["problem_permission"] - elif data["admin_type"] == AdminType.SUPER_ADMIN: + user.username = data['username'].lower() + user.email = data['email'].lower() + user.admin_type = data['admin_type'] + user.is_disabled = data['is_disabled'] + + if data['admin_type'] == AdminType.ADMIN: + user.problem_permission = data['problem_permission'] + elif data['admin_type'] == AdminType.SUPER_ADMIN: user.problem_permission = ProblemPermission.ALL else: user.problem_permission = ProblemPermission.NONE - if data["password"]: - user.set_password(data["password"]) + if data['password']: + user.set_password(data['password']) - if data["open_api"]: + if data['open_api']: # Avoid reset user appkey after saving changes if not user.open_api: user.open_api_appkey = rand_str() else: user.open_api_appkey = None - user.open_api = data["open_api"] + user.open_api = data['open_api'] - if data["two_factor_auth"]: + if data['two_factor_auth']: # Avoid reset user tfa_token after saving changes if not user.two_factor_auth: user.tfa_token = rand_str() else: user.tfa_token = None - user.two_factor_auth = data["two_factor_auth"] + user.two_factor_auth = data['two_factor_auth'] user.save() if pre_username != user.username: - Submission.objects.filter(username=pre_username).update(username=user.username) + Submission.objects.filter(username=pre_username).update( + username=user.username + ) - UserProfile.objects.filter(user=user).update(real_name=data["real_name"]) + UserProfile.objects.filter(user=user).update(real_name=data['real_name']) return self.success(UserAdminSerializer(user).data) @super_admin_required @@ -104,31 +129,33 @@ def get(self, request): """ User list api / Get user by id """ - user_id = request.GET.get("id") + user_id = request.GET.get('id') if user_id: try: user = User.objects.get(id=user_id) except User.DoesNotExist: - return self.error("User does not exist") + return self.error('User does not exist') return self.success(UserAdminSerializer(user).data) - user = User.objects.all().order_by("-create_time") + user = User.objects.all().order_by('-create_time') - keyword = request.GET.get("keyword", None) + keyword = request.GET.get('keyword', None) if keyword: - user = user.filter(Q(username__icontains=keyword) | - Q(userprofile__real_name__icontains=keyword) | - Q(email__icontains=keyword)) + user = user.filter( + Q(username__icontains=keyword) + | Q(userprofile__real_name__icontains=keyword) + | Q(email__icontains=keyword) + ) return self.success(self.paginate_data(request, user, UserAdminSerializer)) @super_admin_required def delete(self, request): - id = request.GET.get("id") + id = request.GET.get('id') if not id: - return self.error("Invalid Parameter, id is required") - ids = id.split(",") + return self.error('Invalid Parameter, id is required') + ids = id.split(',') if str(request.user.id) in ids: - return self.error("Current user can not be deleted") + return self.error('Current user can not be deleted') User.objects.filter(id__in=ids).delete() return self.success() @@ -139,20 +166,20 @@ def get(self, request): """ download users excel """ - file_id = request.GET.get("file_id") + file_id = request.GET.get('file_id') if not file_id: - return self.error("Invalid Parameter, file_id is required") - if not re.match(r"^[a-zA-Z0-9]+$", file_id): - return self.error("Illegal file_id") - file_path = f"/tmp/{file_id}.xlsx" + return self.error('Invalid Parameter, file_id is required') + if not re.match(r'^[a-zA-Z0-9]+$', file_id): + return self.error('Illegal file_id') + file_path = f'/tmp/{file_id}.xlsx' if not os.path.isfile(file_path): - return self.error("File does not exist") - with open(file_path, "rb") as f: + return self.error('File does not exist') + with open(file_path, 'rb') as f: raw_data = f.read() os.remove(file_path) response = HttpResponse(raw_data) - response["Content-Disposition"] = "attachment; filename=users.xlsx" - response["Content-Type"] = "application/xlsx" + response['Content-Disposition'] = 'attachment; filename=users.xlsx' + response['Content-Type'] = 'application/xlsx' return response @validate_serializer(GenerateUserSerializer) @@ -162,41 +189,47 @@ def post(self, request): Generate User """ data = request.data - number_max_length = max(len(str(data["number_from"])), len(str(data["number_to"]))) - if number_max_length + len(data["prefix"]) + len(data["suffix"]) > 32: - return self.error("Username should not more than 32 characters") - if data["number_from"] > data["number_to"]: - return self.error("Start number must be lower than end number") + number_max_length = max( + len(str(data['number_from'])), len(str(data['number_to'])) + ) + if number_max_length + len(data['prefix']) + len(data['suffix']) > 32: + return self.error('Username should not more than 32 characters') + if data['number_from'] > data['number_to']: + return self.error('Start number must be lower than end number') file_id = rand_str(8) - filename = f"/tmp/{file_id}.xlsx" + filename = f'/tmp/{file_id}.xlsx' workbook = xlsxwriter.Workbook(filename) worksheet = workbook.add_worksheet() - worksheet.set_column("A:B", 20) - worksheet.write("A1", "Username") - worksheet.write("B1", "Password") + worksheet.set_column('A:B', 20) + worksheet.write('A1', 'Username') + worksheet.write('B1', 'Password') i = 1 user_list = [] - for number in range(data["number_from"], data["number_to"] + 1): - raw_password = rand_str(data["password_length"]) - user = User(username=f"{data['prefix']}{number}{data['suffix']}", password=make_password(raw_password)) + for number in range(data['number_from'], data['number_to'] + 1): + raw_password = rand_str(data['password_length']) + user = User( + username=f"{data['prefix']}{number}{data['suffix']}", + password=make_password(raw_password), + ) user.raw_password = raw_password user_list.append(user) try: with transaction.atomic(): - ret = User.objects.bulk_create(user_list) - UserProfile.objects.bulk_create([UserProfile(user=user) for user in ret]) + UserProfile.objects.bulk_create( + [UserProfile(user=user) for user in ret] + ) for item in user_list: worksheet.write_string(i, 0, item.username) worksheet.write_string(i, 1, item.raw_password) i += 1 workbook.close() - return self.success({"file_id": file_id}) + return self.success({'file_id': file_id}) except IntegrityError as e: # Extract detail from exception message # duplicate key value violates unique constraint "user_username_key" # DETAIL: Key (username)=(root11) already exists. - return self.error(str(e).split("\n")[1]) + return self.error(str(e).split('\n')[1]) diff --git a/account/views/oj.py b/account/views/oj.py index 2a3c2e09f..6b1e43e24 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -19,12 +19,23 @@ from utils.shortcuts import rand_str, img2base64, datetime2str from ..decorators import login_required from ..models import User, UserProfile, AdminType -from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer, - UserChangePasswordSerializer, UserLoginSerializer, - UserRegisterSerializer, UsernameOrEmailCheckSerializer, - RankInfoSerializer, UserChangeEmailSerializer, SSOSerializer) -from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer, - EditUserProfileSerializer, ImageUploadForm) +from ..serializers import ( + ApplyResetPasswordSerializer, + ResetPasswordSerializer, + UserChangePasswordSerializer, + UserLoginSerializer, + UserRegisterSerializer, + UsernameOrEmailCheckSerializer, + RankInfoSerializer, + UserChangeEmailSerializer, + SSOSerializer, +) +from ..serializers import ( + TwoFactorAuthCodeSerializer, + UserProfileSerializer, + EditUserProfileSerializer, + ImageUploadForm, +) from ..tasks import send_email_async @@ -38,7 +49,7 @@ def get(self, request, **kwargs): if not user.is_authenticated: return self.success() show_real_name = False - username = request.GET.get("username") + username = request.GET.get('username') try: if username: user = User.objects.get(username=username, is_disabled=False) @@ -47,8 +58,10 @@ def get(self, request, **kwargs): # api返回的是自己的信息,可以返real_name show_real_name = True except User.DoesNotExist: - return self.error("User does not exist") - return self.success(UserProfileSerializer(user.userprofile, show_real_name=show_real_name).data) + return self.error('User does not exist') + return self.success( + UserProfileSerializer(user.userprofile, show_real_name=show_real_name).data + ) @validate_serializer(EditUserProfileSerializer) @login_required @@ -58,7 +71,9 @@ def put(self, request): for k, v in data.items(): setattr(user_profile, k, v) user_profile.save() - return self.success(UserProfileSerializer(user_profile, show_real_name=True).data) + return self.success( + UserProfileSerializer(user_profile, show_real_name=True).data + ) class AvatarUploadAPI(APIView): @@ -68,24 +83,24 @@ class AvatarUploadAPI(APIView): def post(self, request): form = ImageUploadForm(request.POST, request.FILES) if form.is_valid(): - avatar = form.cleaned_data["image"] + avatar = form.cleaned_data['image'] else: - return self.error("Invalid file content") + return self.error('Invalid file content') if avatar.size > 2 * 1024 * 1024: - return self.error("Picture is too large") + return self.error('Picture is too large') suffix = os.path.splitext(avatar.name)[-1].lower() - if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]: - return self.error("Unsupported file format") + if suffix not in ['.gif', '.jpg', '.jpeg', '.bmp', '.png']: + return self.error('Unsupported file format') name = rand_str(10) + suffix - with open(os.path.join(settings.AVATAR_UPLOAD_DIR, name), "wb") as img: + with open(os.path.join(settings.AVATAR_UPLOAD_DIR, name), 'wb') as img: for chunk in avatar: img.write(chunk) user_profile = request.user.userprofile - user_profile.avatar = f"{settings.AVATAR_URI_PREFIX}/{name}" + user_profile.avatar = f'{settings.AVATAR_URI_PREFIX}/{name}' user_profile.save() - return self.success("Succeeded") + return self.success('Succeeded') class TwoFactorAuthAPI(APIView): @@ -96,13 +111,17 @@ def get(self, request): """ user = request.user if user.two_factor_auth: - return self.error("2FA is already turned on") + return self.error('2FA is already turned on') token = rand_str() user.tfa_token = token user.save() - label = f"{SysOptions.website_name_shortcut}:{user.username}" - image = qrcode.make(OtpAuth(token).to_uri("totp", label, SysOptions.website_name.replace(" ", ""))) + label = f'{SysOptions.website_name_shortcut}:{user.username}' + image = qrcode.make( + OtpAuth(token).to_uri( + 'totp', label, SysOptions.website_name.replace(' ', '') + ) + ) return self.success(img2base64(image)) @login_required @@ -111,28 +130,28 @@ def post(self, request): """ Open 2FA """ - code = request.data["code"] + code = request.data['code'] user = request.user if OtpAuth(user.tfa_token).valid_totp(code): user.two_factor_auth = True user.save() - return self.success("Succeeded") + return self.success('Succeeded') else: - return self.error("Invalid code") + return self.error('Invalid code') @login_required @validate_serializer(TwoFactorAuthCodeSerializer) def put(self, request): - code = request.data["code"] + code = request.data['code'] user = request.user if not user.two_factor_auth: - return self.error("2FA is already turned off") + return self.error('2FA is already turned off') if OtpAuth(user.tfa_token).valid_totp(code): user.two_factor_auth = False user.save() - return self.success("Succeeded") + return self.success('Succeeded') else: - return self.error("Invalid code") + return self.error('Invalid code') class CheckTFARequiredAPI(APIView): @@ -143,13 +162,13 @@ def post(self, request): """ data = request.data result = False - if data.get("username"): + if data.get('username'): try: - user = User.objects.get(username=data["username"]) + user = User.objects.get(username=data['username']) result = user.two_factor_auth except User.DoesNotExist: pass - return self.success({"result": result}) + return self.success({'result': result}) class UserLoginAPI(APIView): @@ -159,26 +178,26 @@ def post(self, request): User login api """ data = request.data - user = auth.authenticate(username=data["username"], password=data["password"]) + user = auth.authenticate(username=data['username'], password=data['password']) # None is returned if username or password is wrong if user: if user.is_disabled: - return self.error("Your account has been disabled") + return self.error('Your account has been disabled') if not user.two_factor_auth: auth.login(request, user) - return self.success("Succeeded") + return self.success('Succeeded') # `tfa_code` not in post data - if user.two_factor_auth and "tfa_code" not in data: - return self.error("tfa_required") + if user.two_factor_auth and 'tfa_code' not in data: + return self.error('tfa_required') - if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): + if OtpAuth(user.tfa_token).valid_totp(data['tfa_code']): auth.login(request, user) - return self.success("Succeeded") + return self.success('Succeeded') else: - return self.error("Invalid two factor verification code") + return self.error('Invalid two factor verification code') else: - return self.error("Invalid username or password") + return self.error('Invalid username or password') class UserLogoutAPI(APIView): @@ -195,14 +214,13 @@ def post(self, request): """ data = request.data # True means already exist. - result = { - "username": False, - "email": False - } - if data.get("username"): - result["username"] = User.objects.filter(username=data["username"].lower()).exists() - if data.get("email"): - result["email"] = User.objects.filter(email=data["email"].lower()).exists() + result = {'username': False, 'email': False} + if data.get('username'): + result['username'] = User.objects.filter( + username=data['username'].lower() + ).exists() + if data.get('email'): + result['email'] = User.objects.filter(email=data['email'].lower()).exists() return self.success(result) @@ -213,23 +231,23 @@ def post(self, request): User register api """ if not SysOptions.allow_register: - return self.error("Register function has been disabled by admin") + return self.error('Register function has been disabled by admin') data = request.data - data["username"] = data["username"].lower() - data["email"] = data["email"].lower() + data['username'] = data['username'].lower() + data['email'] = data['email'].lower() captcha = Captcha(request) - if not captcha.check(data["captcha"]): - return self.error("Invalid captcha") - if User.objects.filter(username=data["username"]).exists(): - return self.error("Username already exists") - if User.objects.filter(email=data["email"]).exists(): - return self.error("Email already exists") - user = User.objects.create(username=data["username"], email=data["email"]) - user.set_password(data["password"]) + if not captcha.check(data['captcha']): + return self.error('Invalid captcha') + if User.objects.filter(username=data['username']).exists(): + return self.error('Username already exists') + if User.objects.filter(email=data['email']).exists(): + return self.error('Email already exists') + user = User.objects.create(username=data['username'], email=data['email']) + user.set_password(data['password']) user.save() UserProfile.objects.create(user=user) - return self.success("Succeeded") + return self.success('Succeeded') class UserChangeEmailAPI(APIView): @@ -237,21 +255,23 @@ class UserChangeEmailAPI(APIView): @login_required def post(self, request): data = request.data - user = auth.authenticate(username=request.user.username, password=data["password"]) + user = auth.authenticate( + username=request.user.username, password=data['password'] + ) if user: if user.two_factor_auth: - if "tfa_code" not in data: - return self.error("tfa_required") - if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): - return self.error("Invalid two factor verification code") - data["new_email"] = data["new_email"].lower() - if User.objects.filter(email=data["new_email"]).exists(): - return self.error("The email is owned by other account") - user.email = data["new_email"] + if 'tfa_code' not in data: + return self.error('tfa_required') + if not OtpAuth(user.tfa_token).valid_totp(data['tfa_code']): + return self.error('Invalid two factor verification code') + data['new_email'] = data['new_email'].lower() + if User.objects.filter(email=data['new_email']).exists(): + return self.error('The email is owned by other account') + user.email = data['new_email'] user.save() - return self.success("Succeeded") + return self.success('Succeeded') else: - return self.error("Wrong password") + return self.error('Wrong password') class UserChangePasswordAPI(APIView): @@ -263,51 +283,57 @@ def post(self, request): """ data = request.data username = request.user.username - user = auth.authenticate(username=username, password=data["old_password"]) + user = auth.authenticate(username=username, password=data['old_password']) if user: if user.two_factor_auth: - if "tfa_code" not in data: - return self.error("tfa_required") - if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): - return self.error("Invalid two factor verification code") - user.set_password(data["new_password"]) + if 'tfa_code' not in data: + return self.error('tfa_required') + if not OtpAuth(user.tfa_token).valid_totp(data['tfa_code']): + return self.error('Invalid two factor verification code') + user.set_password(data['new_password']) user.save() - return self.success("Succeeded") + return self.success('Succeeded') else: - return self.error("Invalid old password") + return self.error('Invalid old password') class ApplyResetPasswordAPI(APIView): @validate_serializer(ApplyResetPasswordSerializer) def post(self, request): if request.user.is_authenticated: - return self.error("You have already logged in, are you kidding me? ") + return self.error('You have already logged in, are you kidding me? ') data = request.data captcha = Captcha(request) - if not captcha.check(data["captcha"]): - return self.error("Invalid captcha") + if not captcha.check(data['captcha']): + return self.error('Invalid captcha') try: - user = User.objects.get(email__iexact=data["email"]) + user = User.objects.get(email__iexact=data['email']) except User.DoesNotExist: - return self.error("User does not exist") - if user.reset_password_token_expire_time and 0 < int( - (user.reset_password_token_expire_time - now()).total_seconds()) < 20 * 60: - return self.error("You can only reset password once per 20 minutes") + return self.error('User does not exist') + if ( + user.reset_password_token_expire_time + and 0 + < int((user.reset_password_token_expire_time - now()).total_seconds()) + < 20 * 60 + ): + return self.error('You can only reset password once per 20 minutes') user.reset_password_token = rand_str() user.reset_password_token_expire_time = now() + timedelta(minutes=20) user.save() render_data = { - "username": user.username, - "website_name": SysOptions.website_name, - "link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}" + 'username': user.username, + 'website_name': SysOptions.website_name, + 'link': f'{SysOptions.website_base_url}/reset-password/{user.reset_password_token}', } - email_html = render_to_string("reset_password_email.html", render_data) - send_email_async.send(from_name=SysOptions.website_name_shortcut, - to_email=user.email, - to_name=user.username, - subject="Reset your password", - content=email_html) - return self.success("Succeeded") + email_html = render_to_string('reset_password_email.html', render_data) + send_email_async.send( + from_name=SysOptions.website_name_shortcut, + to_email=user.email, + to_name=user.username, + subject='Reset your password', + content=email_html, + ) + return self.success('Succeeded') class ResetPasswordAPI(APIView): @@ -315,19 +341,19 @@ class ResetPasswordAPI(APIView): def post(self, request): data = request.data captcha = Captcha(request) - if not captcha.check(data["captcha"]): - return self.error("Invalid captcha") + if not captcha.check(data['captcha']): + return self.error('Invalid captcha') try: - user = User.objects.get(reset_password_token=data["token"]) + user = User.objects.get(reset_password_token=data['token']) except User.DoesNotExist: - return self.error("Token does not exist") + return self.error('Token does not exist') if user.reset_password_token_expire_time < now(): - return self.error("Token has expired") + return self.error('Token has expired') user.reset_password_token = None user.two_factor_auth = False - user.set_password(data["password"]) + user.set_password(data['password']) user.save() - return self.success("Succeeded") + return self.success('Succeeded') class SessionManagementAPI(APIView): @@ -349,11 +375,11 @@ def get(self, request): s = {} if current_session == key: - s["current_session"] = True - s["ip"] = session["ip"] - s["user_agent"] = session["user_agent"] - s["last_activity"] = datetime2str(session["last_activity"]) - s["session_key"] = key + s['current_session'] = True + s['ip'] = session['ip'] + s['user_agent'] = session['user_agent'] + s['last_activity'] = datetime2str(session['last_activity']) + s['session_key'] = key result.append(s) if modified: request.user.save() @@ -361,29 +387,32 @@ def get(self, request): @login_required def delete(self, request): - session_key = request.GET.get("session_key") + session_key = request.GET.get('session_key') if not session_key: - return self.error("Parameter Error") + return self.error('Parameter Error') request.session.delete(session_key) if session_key in request.user.session_keys: request.user.session_keys.remove(session_key) request.user.save() - return self.success("Succeeded") + return self.success('Succeeded') else: - return self.error("Invalid session_key") + return self.error('Invalid session_key') class UserRankAPI(APIView): def get(self, request): - rule_type = request.GET.get("rule") + rule_type = request.GET.get('rule') if rule_type not in ContestRuleType.choices(): rule_type = ContestRuleType.ACM - profiles = UserProfile.objects.filter(user__admin_type=AdminType.REGULAR_USER, user__is_disabled=False) \ - .select_related("user") + profiles = UserProfile.objects.filter( + user__admin_type=AdminType.REGULAR_USER, user__is_disabled=False + ).select_related('user') if rule_type == ContestRuleType.ACM: - profiles = profiles.filter(submission_number__gt=0).order_by("-accepted_number", "submission_number") + profiles = profiles.filter(submission_number__gt=0).order_by( + '-accepted_number', 'submission_number' + ) else: - profiles = profiles.filter(total_score__gt=0).order_by("-total_score") + profiles = profiles.filter(total_score__gt=0).order_by('-total_score') return self.success(self.paginate_data(request, profiles, RankInfoSerializer)) @@ -391,18 +420,20 @@ class ProfileProblemDisplayIDRefreshAPI(APIView): @login_required def get(self, request): profile = request.user.userprofile - acm_problems = profile.acm_problems_status.get("problems", {}) - oi_problems = profile.oi_problems_status.get("problems", {}) + acm_problems = profile.acm_problems_status.get('problems', {}) + oi_problems = profile.oi_problems_status.get('problems', {}) ids = list(acm_problems.keys()) + list(oi_problems.keys()) if not ids: return self.success() - display_ids = Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True) + display_ids = Problem.objects.filter(id__in=ids, visible=True).values_list( + '_id', flat=True + ) id_map = dict(zip(ids, display_ids)) for k, v in acm_problems.items(): - v["_id"] = id_map[k] + v['_id'] = id_map[k] for k, v in oi_problems.items(): - v["_id"] = id_map[k] - profile.save(update_fields=["acm_problems_status", "oi_problems_status"]) + v['_id'] = id_map[k] + profile.save(update_fields=['acm_problems_status', 'oi_problems_status']) return self.success() @@ -411,11 +442,11 @@ class OpenAPIAppkeyAPI(APIView): def post(self, request): user = request.user if not user.open_api: - return self.error("OpenAPI function is truned off for you") + return self.error('OpenAPI function is truned off for you') api_appkey = rand_str() user.open_api_appkey = api_appkey user.save() - return self.success({"appkey": api_appkey}) + return self.success({'appkey': api_appkey}) class SSOAPI(CSRFExemptAPIView): @@ -424,13 +455,19 @@ def get(self, request): token = rand_str() request.user.auth_token = token request.user.save() - return self.success({"token": token}) + return self.success({'token': token}) @method_decorator(csrf_exempt) @validate_serializer(SSOSerializer) def post(self, request): try: - user = User.objects.get(auth_token=request.data["token"]) + user = User.objects.get(auth_token=request.data['token']) except User.DoesNotExist: - return self.error("User does not exist") - return self.success({"username": user.username, "avatar": user.userprofile.avatar, "admin_type": user.admin_type}) + return self.error('User does not exist') + return self.success( + { + 'username': user.username, + 'avatar': user.userprofile.avatar, + 'admin_type': user.admin_type, + } + ) diff --git a/announcement/migrations/0001_initial.py b/announcement/migrations/0001_initial.py index cd92e223e..9d9665043 100644 --- a/announcement/migrations/0001_initial.py +++ b/announcement/migrations/0001_initial.py @@ -10,7 +10,6 @@ class Migration(migrations.Migration): - initial = True dependencies = [ @@ -21,13 +20,27 @@ class Migration(migrations.Migration): migrations.CreateModel( name='Announcement', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('title', models.CharField(max_length=50)), ('content', utils.models.RichTextField()), ('create_time', models.DateTimeField(auto_now_add=True)), ('last_update_time', models.DateTimeField(auto_now=True)), ('visible', models.BooleanField(default=True)), - ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ( + 'created_by', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), ], options={ 'db_table': 'announcement', diff --git a/announcement/migrations/0002_auto_20171011_1214.py b/announcement/migrations/0002_auto_20171011_1214.py index e2d5abe32..8ea97ca00 100644 --- a/announcement/migrations/0002_auto_20171011_1214.py +++ b/announcement/migrations/0002_auto_20171011_1214.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('announcement', '0001_initial'), ] diff --git a/announcement/migrations/0003_auto_20180501_0436.py b/announcement/migrations/0003_auto_20180501_0436.py index 6ecd492f5..ddc4cb65b 100644 --- a/announcement/migrations/0003_auto_20180501_0436.py +++ b/announcement/migrations/0003_auto_20180501_0436.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('announcement', '0002_auto_20171011_1214'), ] diff --git a/announcement/models.py b/announcement/models.py index 441c4c184..75db80ad5 100644 --- a/announcement/models.py +++ b/announcement/models.py @@ -14,5 +14,5 @@ class Announcement(models.Model): visible = models.BooleanField(default=True) class Meta: - db_table = "announcement" - ordering = ("-create_time",) + db_table = 'announcement' + ordering = ('-create_time',) diff --git a/announcement/serializers.py b/announcement/serializers.py index 1edd1d9cc..22f258073 100644 --- a/announcement/serializers.py +++ b/announcement/serializers.py @@ -15,7 +15,7 @@ class AnnouncementSerializer(serializers.ModelSerializer): class Meta: model = Announcement - fields = "__all__" + fields = '__all__' class EditAnnouncementSerializer(serializers.Serializer): diff --git a/announcement/tests.py b/announcement/tests.py index 98caa1c89..ab25fc72c 100644 --- a/announcement/tests.py +++ b/announcement/tests.py @@ -6,14 +6,16 @@ class AnnouncementAdminTest(APITestCase): def setUp(self): self.user = self.create_super_admin() - self.url = self.reverse("announcement_admin_api") + self.url = self.reverse('announcement_admin_api') def test_announcement_list(self): response = self.client.get(self.url) self.assertSuccess(response) def create_announcement(self): - return self.client.post(self.url, data={"title": "test", "content": "test", "visible": True}) + return self.client.post( + self.url, data={'title': 'test', 'content': 'test', 'visible': True} + ) def test_create_announcement(self): resp = self.create_announcement() @@ -21,18 +23,22 @@ def test_create_announcement(self): return resp def test_edit_announcement(self): - data = {"id": self.create_announcement().data["data"]["id"], "title": "ahaha", "content": "test content", - "visible": False} + data = { + 'id': self.create_announcement().data['data']['id'], + 'title': 'ahaha', + 'content': 'test content', + 'visible': False, + } resp = self.client.put(self.url, data=data) self.assertSuccess(resp) - resp_data = resp.data["data"] - self.assertEqual(resp_data["title"], "ahaha") - self.assertEqual(resp_data["content"], "test content") - self.assertEqual(resp_data["visible"], False) + resp_data = resp.data['data'] + self.assertEqual(resp_data['title'], 'ahaha') + self.assertEqual(resp_data['content'], 'test content') + self.assertEqual(resp_data['visible'], False) def test_delete_announcement(self): - id = self.test_create_announcement().data["data"]["id"] - resp = self.client.delete(self.url + "?id=" + str(id)) + id = self.test_create_announcement().data['data']['id'] + resp = self.client.delete(self.url + '?id=' + str(id)) self.assertSuccess(resp) self.assertFalse(Announcement.objects.filter(id=id).exists()) @@ -40,8 +46,10 @@ def test_delete_announcement(self): class AnnouncementAPITest(APITestCase): def setUp(self): self.user = self.create_super_admin() - Announcement.objects.create(title="title", content="content", visible=True, created_by=self.user) - self.url = self.reverse("announcement_api") + Announcement.objects.create( + title='title', content='content', visible=True, created_by=self.user + ) + self.url = self.reverse('announcement_api') def test_get_announcement_list(self): resp = self.client.get(self.url) diff --git a/announcement/urls/admin.py b/announcement/urls/admin.py index 09673e638..6aa00e3a5 100644 --- a/announcement/urls/admin.py +++ b/announcement/urls/admin.py @@ -3,5 +3,9 @@ from ..views.admin import AnnouncementAdminAPI urlpatterns = [ - url(r"^announcement/?$", AnnouncementAdminAPI.as_view(), name="announcement_admin_api"), + url( + r'^announcement/?$', + AnnouncementAdminAPI.as_view(), + name='announcement_admin_api', + ), ] diff --git a/announcement/urls/oj.py b/announcement/urls/oj.py index 67178b012..b4b0db322 100644 --- a/announcement/urls/oj.py +++ b/announcement/urls/oj.py @@ -3,5 +3,5 @@ from ..views.oj import AnnouncementAPI urlpatterns = [ - url(r"^announcement/?$", AnnouncementAPI.as_view(), name="announcement_api"), + url(r'^announcement/?$', AnnouncementAPI.as_view(), name='announcement_api'), ] diff --git a/announcement/views/admin.py b/announcement/views/admin.py index 58d35780f..930c7beb9 100644 --- a/announcement/views/admin.py +++ b/announcement/views/admin.py @@ -2,8 +2,11 @@ from utils.api import APIView, validate_serializer from announcement.models import Announcement -from announcement.serializers import (AnnouncementSerializer, CreateAnnouncementSerializer, - EditAnnouncementSerializer) +from announcement.serializers import ( + AnnouncementSerializer, + CreateAnnouncementSerializer, + EditAnnouncementSerializer, +) class AnnouncementAdminAPI(APIView): @@ -14,10 +17,12 @@ def post(self, request): publish announcement """ data = request.data - announcement = Announcement.objects.create(title=data["title"], - content=data["content"], - created_by=request.user, - visible=data["visible"]) + announcement = Announcement.objects.create( + title=data['title'], + content=data['content'], + created_by=request.user, + visible=data['visible'], + ) return self.success(AnnouncementSerializer(announcement).data) @validate_serializer(EditAnnouncementSerializer) @@ -28,9 +33,9 @@ def put(self, request): """ data = request.data try: - announcement = Announcement.objects.get(id=data.pop("id")) + announcement = Announcement.objects.get(id=data.pop('id')) except Announcement.DoesNotExist: - return self.error("Announcement does not exist") + return self.error('Announcement does not exist') for k, v in data.items(): setattr(announcement, k, v) @@ -43,20 +48,22 @@ def get(self, request): """ get announcement list / get one announcement """ - announcement_id = request.GET.get("id") + announcement_id = request.GET.get('id') if announcement_id: try: announcement = Announcement.objects.get(id=announcement_id) return self.success(AnnouncementSerializer(announcement).data) except Announcement.DoesNotExist: - return self.error("Announcement does not exist") - announcement = Announcement.objects.all().order_by("-create_time") - if request.GET.get("visible") == "true": + return self.error('Announcement does not exist') + announcement = Announcement.objects.all().order_by('-create_time') + if request.GET.get('visible') == 'true': announcement = announcement.filter(visible=True) - return self.success(self.paginate_data(request, announcement, AnnouncementSerializer)) + return self.success( + self.paginate_data(request, announcement, AnnouncementSerializer) + ) @super_admin_required def delete(self, request): - if request.GET.get("id"): - Announcement.objects.filter(id=request.GET["id"]).delete() + if request.GET.get('id'): + Announcement.objects.filter(id=request.GET['id']).delete() return self.success() diff --git a/announcement/views/oj.py b/announcement/views/oj.py index 1176c368b..f5eff95fc 100644 --- a/announcement/views/oj.py +++ b/announcement/views/oj.py @@ -7,4 +7,6 @@ class AnnouncementAPI(APIView): def get(self, request): announcements = Announcement.objects.filter(visible=True) - return self.success(self.paginate_data(request, announcements, AnnouncementSerializer)) + return self.success( + self.paginate_data(request, announcements, AnnouncementSerializer) + ) diff --git a/conf/migrations/0001_initial.py b/conf/migrations/0001_initial.py index 46c7dc39b..f254ef292 100644 --- a/conf/migrations/0001_initial.py +++ b/conf/migrations/0001_initial.py @@ -6,17 +6,23 @@ class Migration(migrations.Migration): - initial = True - dependencies = [ - ] + dependencies = [] operations = [ migrations.CreateModel( name='JudgeServer', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('hostname', models.CharField(max_length=64)), ('ip', models.CharField(blank=True, max_length=32, null=True)), ('judger_version', models.CharField(max_length=24)), @@ -26,7 +32,10 @@ class Migration(migrations.Migration): ('last_heartbeat', models.DateTimeField()), ('create_time', models.DateTimeField(auto_now_add=True)), ('task_number', models.IntegerField(default=0)), - ('service_url', models.CharField(blank=True, max_length=128, null=True)), + ( + 'service_url', + models.CharField(blank=True, max_length=128, null=True), + ), ], options={ 'db_table': 'judge_server', @@ -35,7 +44,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name='JudgeServerToken', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('token', models.CharField(max_length=32)), ], options={ @@ -45,7 +62,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name='SMTPConfig', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('server', models.CharField(max_length=128)), ('port', models.IntegerField(default=25)), ('email', models.CharField(max_length=128)), @@ -59,8 +84,19 @@ class Migration(migrations.Migration): migrations.CreateModel( name='WebsiteConfig', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('base_url', models.CharField(default='http://127.0.0.1', max_length=128)), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), + ( + 'base_url', + models.CharField(default='http://127.0.0.1', max_length=128), + ), ('name', models.CharField(default='Online Judge', max_length=32)), ('name_shortcut', models.CharField(default='oj', max_length=32)), ('footer', models.TextField(default='Online Judge Footer')), diff --git a/conf/migrations/0002_auto_20171011_1214.py b/conf/migrations/0002_auto_20171011_1214.py index ef355b508..ad8bd12f9 100644 --- a/conf/migrations/0002_auto_20171011_1214.py +++ b/conf/migrations/0002_auto_20171011_1214.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('conf', '0001_initial'), ] diff --git a/conf/migrations/0003_judgeserver_is_disabled.py b/conf/migrations/0003_judgeserver_is_disabled.py index 6a571c010..33741c470 100644 --- a/conf/migrations/0003_judgeserver_is_disabled.py +++ b/conf/migrations/0003_judgeserver_is_disabled.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('conf', '0002_auto_20171011_1214'), ] diff --git a/conf/migrations/0004_auto_20180501_0436.py b/conf/migrations/0004_auto_20180501_0436.py index 00ede1c23..71e32d919 100644 --- a/conf/migrations/0004_auto_20180501_0436.py +++ b/conf/migrations/0004_auto_20180501_0436.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('conf', '0003_judgeserver_is_disabled'), ] diff --git a/conf/models.py b/conf/models.py index e0d1c93c5..927df9311 100644 --- a/conf/models.py +++ b/conf/models.py @@ -19,8 +19,8 @@ class JudgeServer(models.Model): def status(self): # 增加一秒延时,提高对网络环境的适应性 if (timezone.now() - self.last_heartbeat).total_seconds() > 6: - return "abnormal" - return "normal" + return 'abnormal' + return 'normal' class Meta: - db_table = "judge_server" + db_table = 'judge_server' diff --git a/conf/serializers.py b/conf/serializers.py index e50dfa21e..4c8257454 100644 --- a/conf/serializers.py +++ b/conf/serializers.py @@ -7,7 +7,9 @@ class EditSMTPConfigSerializer(serializers.Serializer): server = serializers.CharField(max_length=128) port = serializers.IntegerField(default=25) email = serializers.CharField(max_length=256) - password = serializers.CharField(max_length=128, required=False, allow_null=True, allow_blank=True) + password = serializers.CharField( + max_length=128, required=False, allow_null=True, allow_blank=True + ) tls = serializers.BooleanField() @@ -33,7 +35,7 @@ class JudgeServerSerializer(serializers.ModelSerializer): class Meta: model = JudgeServer - fields = "__all__" + fields = '__all__' class JudgeServerHeartbeatSerializer(serializers.Serializer): @@ -42,7 +44,7 @@ class JudgeServerHeartbeatSerializer(serializers.Serializer): cpu_core = serializers.IntegerField(min_value=1) memory = serializers.FloatField(min_value=0, max_value=100) cpu = serializers.FloatField(min_value=0, max_value=100) - action = serializers.ChoiceField(choices=("heartbeat", )) + action = serializers.ChoiceField(choices=('heartbeat',)) service_url = serializers.CharField(max_length=256) diff --git a/conf/tests.py b/conf/tests.py index 50d23f732..de51337f3 100644 --- a/conf/tests.py +++ b/conf/tests.py @@ -1,7 +1,6 @@ import hashlib from unittest import mock -from django.conf import settings from django.utils import timezone from options.options import SysOptions @@ -12,43 +11,62 @@ class SMTPConfigTest(APITestCase): def setUp(self): self.user = self.create_super_admin() - self.url = self.reverse("smtp_admin_api") - self.password = "testtest" + self.url = self.reverse('smtp_admin_api') + self.password = 'testtest' def test_create_smtp_config(self): - data = {"server": "smtp.test.com", "email": "test@test.com", "port": 465, - "tls": True, "password": self.password} + data = { + 'server': 'smtp.test.com', + 'email': 'test@test.com', + 'port': 465, + 'tls': True, + 'password': self.password, + } resp = self.client.post(self.url, data=data) self.assertSuccess(resp) - self.assertTrue("password" not in resp.data) + self.assertTrue('password' not in resp.data) return resp def test_edit_without_password(self): self.test_create_smtp_config() - data = {"server": "smtp1.test.com", "email": "test2@test.com", "port": 465, - "tls": True} + data = { + 'server': 'smtp1.test.com', + 'email': 'test2@test.com', + 'port': 465, + 'tls': True, + } resp = self.client.put(self.url, data=data) self.assertSuccess(resp) def test_edit_without_password1(self): self.test_create_smtp_config() - data = {"server": "smtp.test.com", "email": "test@test.com", "port": 465, - "tls": True, "password": ""} + data = { + 'server': 'smtp.test.com', + 'email': 'test@test.com', + 'port': 465, + 'tls': True, + 'password': '', + } resp = self.client.put(self.url, data=data) self.assertSuccess(resp) def test_edit_with_password(self): self.test_create_smtp_config() - data = {"server": "smtp1.test.com", "email": "test2@test.com", "port": 465, - "tls": True, "password": "newpassword"} + data = { + 'server': 'smtp1.test.com', + 'email': 'test2@test.com', + 'port': 465, + 'tls': True, + 'password': 'newpassword', + } resp = self.client.put(self.url, data=data) self.assertSuccess(resp) - @mock.patch("conf.views.send_email") + @mock.patch('conf.views.send_email') def test_test_smtp(self, mocked_send_email): - url = self.reverse("smtp_test_api") + url = self.reverse('smtp_test_api') self.test_create_smtp_config() - resp = self.client.post(url, data={"email": "test@test.com"}) + resp = self.client.post(url, data={'email': 'test@test.com'}) self.assertSuccess(resp) mocked_send_email.assert_called_once() @@ -56,101 +74,135 @@ def test_test_smtp(self, mocked_send_email): class WebsiteConfigAPITest(APITestCase): def test_create_website_config(self): self.create_super_admin() - url = self.reverse("website_config_api") - data = {"website_base_url": "http://test.com", "website_name": "test name", - "website_name_shortcut": "test oj", "website_footer": "test", - "allow_register": True, "submission_list_show_all": False} + url = self.reverse('website_config_api') + data = { + 'website_base_url': 'http://test.com', + 'website_name': 'test name', + 'website_name_shortcut': 'test oj', + 'website_footer': 'test', + 'allow_register': True, + 'submission_list_show_all': False, + } resp = self.client.post(url, data=data) self.assertSuccess(resp) def test_edit_website_config(self): self.create_super_admin() - url = self.reverse("website_config_api") - data = {"website_base_url": "http://test.com", "website_name": "test name", - "website_name_shortcut": "test oj", "website_footer": "", - "allow_register": True, "submission_list_show_all": False} + url = self.reverse('website_config_api') + data = { + 'website_base_url': 'http://test.com', + 'website_name': 'test name', + 'website_name_shortcut': 'test oj', + 'website_footer': '', + 'allow_register': True, + 'submission_list_show_all': False, + } resp = self.client.post(url, data=data) self.assertSuccess(resp) self.assertEqual(SysOptions.website_footer, '') def test_get_website_config(self): # do not need to login - url = self.reverse("website_info_api") + url = self.reverse('website_info_api') resp = self.client.get(url) self.assertSuccess(resp) class JudgeServerHeartbeatTest(APITestCase): def setUp(self): - self.url = self.reverse("judge_server_heartbeat_api") - self.data = {"hostname": "testhostname", "judger_version": "1.0.4", "cpu_core": 4, - "cpu": 90.5, "memory": 80.3, "action": "heartbeat", "service_url": "http://127.0.0.1"} - self.token = "test" - self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest() + self.url = self.reverse('judge_server_heartbeat_api') + self.data = { + 'hostname': 'testhostname', + 'judger_version': '1.0.4', + 'cpu_core': 4, + 'cpu': 90.5, + 'memory': 80.3, + 'action': 'heartbeat', + 'service_url': 'http://127.0.0.1', + } + self.token = 'test' + self.hashed_token = hashlib.sha256(self.token.encode('utf-8')).hexdigest() SysOptions.judge_server_token = self.token - self.headers = {"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token, 'X-Forwarded-For': "1.2.3.4"} + self.headers = { + 'HTTP_X_JUDGE_SERVER_TOKEN': self.hashed_token, + 'X-Forwarded-For': '1.2.3.4', + } def test_new_heartbeat(self): resp = self.client.post(self.url, data=self.data, **self.headers) self.assertSuccess(resp) server = JudgeServer.objects.first() - self.assertEqual(server.ip, "127.0.0.1") + self.assertEqual(server.ip, '127.0.0.1') def test_update_heartbeat(self): self.test_new_heartbeat() data = self.data - data["judger_version"] = "2.0.0" + data['judger_version'] = '2.0.0' resp = self.client.post(self.url, data=data, **self.headers) self.assertSuccess(resp) - self.assertEqual(JudgeServer.objects.get(hostname=self.data["hostname"]).judger_version, data["judger_version"]) + self.assertEqual( + JudgeServer.objects.get(hostname=self.data['hostname']).judger_version, + data['judger_version'], + ) class JudgeServerAPITest(APITestCase): def setUp(self): - self.server = JudgeServer.objects.create(**{"hostname": "testhostname", "judger_version": "1.0.4", - "cpu_core": 4, "cpu_usage": 90.5, "memory_usage": 80.3, - "last_heartbeat": timezone.now()}) - self.url = self.reverse("judge_server_api") + self.server = JudgeServer.objects.create( + **{ + 'hostname': 'testhostname', + 'judger_version': '1.0.4', + 'cpu_core': 4, + 'cpu_usage': 90.5, + 'memory_usage': 80.3, + 'last_heartbeat': timezone.now(), + } + ) + self.url = self.reverse('judge_server_api') self.create_super_admin() def test_get_judge_server(self): resp = self.client.get(self.url) self.assertSuccess(resp) - self.assertEqual(len(resp.data["data"]["servers"]), 1) + self.assertEqual(len(resp.data['data']['servers']), 1) def test_delete_judge_server(self): - resp = self.client.delete(self.url + "?hostname=testhostname") + resp = self.client.delete(self.url + '?hostname=testhostname') self.assertSuccess(resp) - self.assertFalse(JudgeServer.objects.filter(hostname="testhostname").exists()) + self.assertFalse(JudgeServer.objects.filter(hostname='testhostname').exists()) def test_disabled_judge_server(self): - resp = self.client.put(self.url, data={"is_disabled": True, "id": self.server.id}) + resp = self.client.put( + self.url, data={'is_disabled': True, 'id': self.server.id} + ) self.assertSuccess(resp) self.assertTrue(JudgeServer.objects.get(id=self.server.id).is_disabled) class LanguageListAPITest(APITestCase): def test_get_languages(self): - resp = self.client.get(self.reverse("language_list_api")) + resp = self.client.get(self.reverse('language_list_api')) self.assertSuccess(resp) class TestCasePruneAPITest(APITestCase): def setUp(self): - self.url = self.reverse("prune_test_case_api") + self.url = self.reverse('prune_test_case_api') self.create_super_admin() def test_get_isolated_test_case(self): resp = self.client.get(self.url) self.assertSuccess(resp) - @mock.patch("conf.views.TestCasePruneAPI.delete_one") - @mock.patch("conf.views.os.listdir") - @mock.patch("conf.views.Problem") + @mock.patch('conf.views.TestCasePruneAPI.delete_one') + @mock.patch('conf.views.os.listdir') + @mock.patch('conf.views.Problem') def test_delete_test_case(self, mocked_problem, mocked_listdir, mocked_delete_one): - valid_id = "1172980672983b2b49820be3a741b109" - mocked_problem.return_value = [valid_id, ] - mocked_listdir.return_value = [valid_id, ".test", "aaa"] + valid_id = '1172980672983b2b49820be3a741b109' + mocked_problem.return_value = [ + valid_id, + ] + mocked_listdir.return_value = [valid_id, '.test', 'aaa'] resp = self.client.delete(self.url) self.assertSuccess(resp) mocked_delete_one.assert_called_once_with(valid_id) @@ -158,16 +210,20 @@ def test_delete_test_case(self, mocked_problem, mocked_listdir, mocked_delete_on class ReleaseNoteAPITest(APITestCase): def setUp(self): - self.url = self.reverse("get_release_notes_api") + self.url = self.reverse('get_release_notes_api') self.create_super_admin() - self.latest_data = {"update": [ - { - "version": "2099-12-25", - "level": 1, - "title": "Update at 2099-12-25", - "details": ["test get", ] - } - ]} + self.latest_data = { + 'update': [ + { + 'version': '2099-12-25', + 'level': 1, + 'title': 'Update at 2099-12-25', + 'details': [ + 'test get', + ], + } + ] + } def test_get_versions(self): resp = self.client.get(self.url) @@ -176,10 +232,10 @@ def test_get_versions(self): class DashboardInfoAPITest(APITestCase): def setUp(self): - self.url = self.reverse("dashboard_info_api") + self.url = self.reverse('dashboard_info_api') self.create_admin() def test_get_info(self): resp = self.client.get(self.url) self.assertSuccess(resp) - self.assertEqual(resp.data["data"]["user_count"], 1) + self.assertEqual(resp.data['data']['user_count'], 1) diff --git a/conf/urls/admin.py b/conf/urls/admin.py index 8e6d293f6..7b3594c93 100644 --- a/conf/urls/admin.py +++ b/conf/urls/admin.py @@ -1,14 +1,20 @@ from django.conf.urls import url -from ..views import SMTPAPI, JudgeServerAPI, WebsiteConfigAPI, TestCasePruneAPI, SMTPTestAPI +from ..views import ( + SMTPAPI, + JudgeServerAPI, + WebsiteConfigAPI, + TestCasePruneAPI, + SMTPTestAPI, +) from ..views import ReleaseNotesAPI, DashboardInfoAPI urlpatterns = [ - url(r"^smtp/?$", SMTPAPI.as_view(), name="smtp_admin_api"), - url(r"^smtp_test/?$", SMTPTestAPI.as_view(), name="smtp_test_api"), - url(r"^website/?$", WebsiteConfigAPI.as_view(), name="website_config_api"), - url(r"^judge_server/?$", JudgeServerAPI.as_view(), name="judge_server_api"), - url(r"^prune_test_case/?$", TestCasePruneAPI.as_view(), name="prune_test_case_api"), - url(r"^versions/?$", ReleaseNotesAPI.as_view(), name="get_release_notes_api"), - url(r"^dashboard_info", DashboardInfoAPI.as_view(), name="dashboard_info_api"), + url(r'^smtp/?$', SMTPAPI.as_view(), name='smtp_admin_api'), + url(r'^smtp_test/?$', SMTPTestAPI.as_view(), name='smtp_test_api'), + url(r'^website/?$', WebsiteConfigAPI.as_view(), name='website_config_api'), + url(r'^judge_server/?$', JudgeServerAPI.as_view(), name='judge_server_api'), + url(r'^prune_test_case/?$', TestCasePruneAPI.as_view(), name='prune_test_case_api'), + url(r'^versions/?$', ReleaseNotesAPI.as_view(), name='get_release_notes_api'), + url(r'^dashboard_info', DashboardInfoAPI.as_view(), name='dashboard_info_api'), ] diff --git a/conf/urls/oj.py b/conf/urls/oj.py index 3e6d7578f..9b5289104 100644 --- a/conf/urls/oj.py +++ b/conf/urls/oj.py @@ -3,7 +3,11 @@ from ..views import JudgeServerHeartbeatAPI, LanguagesAPI, WebsiteConfigAPI urlpatterns = [ - url(r"^website/?$", WebsiteConfigAPI.as_view(), name="website_info_api"), - url(r"^judge_server_heartbeat/?$", JudgeServerHeartbeatAPI.as_view(), name="judge_server_heartbeat_api"), - url(r"^languages/?$", LanguagesAPI.as_view(), name="language_list_api") + url(r'^website/?$', WebsiteConfigAPI.as_view(), name='website_info_api'), + url( + r'^judge_server_heartbeat/?$', + JudgeServerHeartbeatAPI.as_view(), + name='judge_server_heartbeat_api', + ), + url(r'^languages/?$', LanguagesAPI.as_view(), name='language_list_api'), ] diff --git a/conf/views.py b/conf/views.py index be7374428..77d05a7c7 100644 --- a/conf/views.py +++ b/conf/views.py @@ -24,10 +24,15 @@ from utils.shortcuts import send_email, get_env from utils.xss_filter import XSSHtml from .models import JudgeServer -from .serializers import (CreateEditWebsiteConfigSerializer, - CreateSMTPConfigSerializer, EditSMTPConfigSerializer, - JudgeServerHeartbeatSerializer, - JudgeServerSerializer, TestSMTPConfigSerializer, EditJudgeServerSerializer) +from .serializers import ( + CreateEditWebsiteConfigSerializer, + CreateSMTPConfigSerializer, + EditSMTPConfigSerializer, + JudgeServerHeartbeatSerializer, + JudgeServerSerializer, + TestSMTPConfigSerializer, + EditJudgeServerSerializer, +) class SMTPAPI(APIView): @@ -36,7 +41,7 @@ def get(self, request): smtp = SysOptions.smtp_config if not smtp: return self.success(None) - smtp.pop("password") + smtp.pop('password') return self.success(smtp) @super_admin_required @@ -50,10 +55,10 @@ def post(self, request): def put(self, request): smtp = SysOptions.smtp_config data = request.data - for item in ["server", "port", "email", "tls"]: + for item in ['server', 'port', 'email', 'tls']: smtp[item] = data[item] - if "password" in data: - smtp["password"] = data["password"] + if 'password' in data: + smtp['password'] = data['password'] SysOptions.smtp_config = smtp return self.success() @@ -63,23 +68,25 @@ class SMTPTestAPI(APIView): @validate_serializer(TestSMTPConfigSerializer) def post(self, request): if not SysOptions.smtp_config: - return self.error("Please setup SMTP config at first") + return self.error('Please setup SMTP config at first') try: - send_email(smtp_config=SysOptions.smtp_config, - from_name=SysOptions.website_name_shortcut, - to_name=request.user.username, - to_email=request.data["email"], - subject="You have successfully configured SMTP", - content="You have successfully configured SMTP") + send_email( + smtp_config=SysOptions.smtp_config, + from_name=SysOptions.website_name_shortcut, + to_name=request.user.username, + to_email=request.data['email'], + subject='You have successfully configured SMTP', + content='You have successfully configured SMTP', + ) except smtplib.SMTPResponseException as e: # guess error message encoding - msg = b"Failed to send email" + msg = b'Failed to send email' try: msg = e.smtp_error # qq mail - msg = msg.decode("gbk") + msg = msg.decode('gbk') except Exception: - msg = msg.decode("utf-8", "ignore") + msg = msg.decode('utf-8', 'ignore') return self.error(msg) except Exception as e: msg = str(e) @@ -89,16 +96,24 @@ def post(self, request): class WebsiteConfigAPI(APIView): def get(self, request): - ret = {key: getattr(SysOptions, key) for key in - ["website_base_url", "website_name", "website_name_shortcut", - "website_footer", "allow_register", "submission_list_show_all"]} + ret = { + key: getattr(SysOptions, key) + for key in [ + 'website_base_url', + 'website_name', + 'website_name_shortcut', + 'website_footer', + 'allow_register', + 'submission_list_show_all', + ] + } return self.success(ret) @super_admin_required @validate_serializer(CreateEditWebsiteConfigSerializer) def post(self, request): for k, v in request.data.items(): - if k == "website_footer": + if k == 'website_footer': with XSSHtml() as parser: v = parser.clean(v) setattr(SysOptions, k, v) @@ -108,13 +123,17 @@ def post(self, request): class JudgeServerAPI(APIView): @super_admin_required def get(self, request): - servers = JudgeServer.objects.all().order_by("-last_heartbeat") - return self.success({"token": SysOptions.judge_server_token, - "servers": JudgeServerSerializer(servers, many=True).data}) + servers = JudgeServer.objects.all().order_by('-last_heartbeat') + return self.success( + { + 'token': SysOptions.judge_server_token, + 'servers': JudgeServerSerializer(servers, many=True).data, + } + ) @super_admin_required def delete(self, request): - hostname = request.GET.get("hostname") + hostname = request.GET.get('hostname') if hostname: JudgeServer.objects.filter(hostname=hostname).delete() return self.success() @@ -122,8 +141,10 @@ def delete(self, request): @validate_serializer(EditJudgeServerSerializer) @super_admin_required def put(self, request): - is_disabled = request.data.get("is_disabled", False) - JudgeServer.objects.filter(id=request.data["id"]).update(is_disabled=is_disabled) + is_disabled = request.data.get('is_disabled', False) + JudgeServer.objects.filter(id=request.data['id']).update( + is_disabled=is_disabled + ) if not is_disabled: process_pending_task() return self.success() @@ -133,30 +154,43 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView): @validate_serializer(JudgeServerHeartbeatSerializer) def post(self, request): data = request.data - client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN") - if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token: - return self.error("Invalid token") + client_token = request.META.get('HTTP_X_JUDGE_SERVER_TOKEN') + if ( + hashlib.sha256(SysOptions.judge_server_token.encode('utf-8')).hexdigest() + != client_token + ): + return self.error('Invalid token') try: - server = JudgeServer.objects.get(hostname=data["hostname"]) - server.judger_version = data["judger_version"] - server.cpu_core = data["cpu_core"] - server.memory_usage = data["memory"] - server.cpu_usage = data["cpu"] - server.service_url = data["service_url"] + server = JudgeServer.objects.get(hostname=data['hostname']) + server.judger_version = data['judger_version'] + server.cpu_core = data['cpu_core'] + server.memory_usage = data['memory'] + server.cpu_usage = data['cpu'] + server.service_url = data['service_url'] server.ip = request.ip server.last_heartbeat = timezone.now() - server.save(update_fields=["judger_version", "cpu_core", "memory_usage", "service_url", "ip", "last_heartbeat"]) + server.save( + update_fields=[ + 'judger_version', + 'cpu_core', + 'memory_usage', + 'service_url', + 'ip', + 'last_heartbeat', + ] + ) except JudgeServer.DoesNotExist: - JudgeServer.objects.create(hostname=data["hostname"], - judger_version=data["judger_version"], - cpu_core=data["cpu_core"], - memory_usage=data["memory"], - cpu_usage=data["cpu"], - ip=request.META["REMOTE_ADDR"], - service_url=data["service_url"], - last_heartbeat=timezone.now(), - ) + JudgeServer.objects.create( + hostname=data['hostname'], + judger_version=data['judger_version'], + cpu_core=data['cpu_core'], + memory_usage=data['memory'], + cpu_usage=data['cpu'], + ip=request.META['REMOTE_ADDR'], + service_url=data['service_url'], + last_heartbeat=timezone.now(), + ) # 新server上线 处理队列中的,防止没有新的提交而导致一直waiting process_pending_task() @@ -165,7 +199,12 @@ def post(self, request): class LanguagesAPI(APIView): def get(self, request): - return self.success({"languages": SysOptions.languages, "spj_languages": SysOptions.spj_languages}) + return self.success( + { + 'languages': SysOptions.languages, + 'spj_languages': SysOptions.spj_languages, + } + ) class TestCasePruneAPI(APIView): @@ -180,12 +219,12 @@ def get(self, request): # return an iterator for d in os.scandir(settings.TEST_CASE_DIR): if d.name in dir_to_be_removed: - ret_data.append({"id": d.name, "create_time": d.stat().st_mtime}) + ret_data.append({'id': d.name, 'create_time': d.stat().st_mtime}) return self.success(ret_data) @super_admin_required def delete(self, request): - test_case_id = request.GET.get("id") + test_case_id = request.GET.get('id') if test_case_id: self.delete_one(test_case_id) return self.success() @@ -195,9 +234,9 @@ def delete(self, request): @staticmethod def get_orphan_ids(): - db_ids = Problem.objects.all().values_list("test_case_id", flat=True) + db_ids = Problem.objects.all().values_list('test_case_id', flat=True) disk_ids = os.listdir(settings.TEST_CASE_DIR) - test_case_re = re.compile(r"^[a-zA-Z0-9]{32}$") + test_case_re = re.compile(r'^[a-zA-Z0-9]{32}$') disk_ids = filter(lambda f: test_case_re.match(f), disk_ids) return list(set(disk_ids) - set(db_ids)) @@ -211,14 +250,17 @@ def delete_one(id): class ReleaseNotesAPI(APIView): def get(self, request): try: - resp = requests.get("https://raw.githubusercontent.com/QingdaoU/OnlineJudge/master/docs/data.json?_=" + str(time.time()), - timeout=3) + resp = requests.get( + 'https://raw.githubusercontent.com/QingdaoU/OnlineJudge/master/docs/data.json?_=' + + str(time.time()), + timeout=3, + ) releases = resp.json() except (RequestException, ValueError): return self.success() - with open("docs/data.json", "r") as f: - local_version = json.load(f)["update"][0]["version"] - releases["local_version"] = local_version + with open('docs/data.json', 'r') as f: + local_version = json.load(f)['update'][0]['version'] + releases['local_version'] = local_version return self.success(releases) @@ -226,16 +268,25 @@ class DashboardInfoAPI(APIView): def get(self, request): today = datetime.today() today_submission_count = Submission.objects.filter( - create_time__gte=datetime(today.year, today.month, today.day, 0, 0, tzinfo=pytz.UTC)).count() - recent_contest_count = Contest.objects.exclude(end_time__lt=timezone.now()).count() - judge_server_count = len(list(filter(lambda x: x.status == "normal", JudgeServer.objects.all()))) - return self.success({ - "user_count": User.objects.count(), - "recent_contest_count": recent_contest_count, - "today_submission_count": today_submission_count, - "judge_server_count": judge_server_count, - "env": { - "FORCE_HTTPS": get_env("FORCE_HTTPS", default=False), - "STATIC_CDN_HOST": get_env("STATIC_CDN_HOST", default="") + create_time__gte=datetime( + today.year, today.month, today.day, 0, 0, tzinfo=pytz.UTC + ) + ).count() + recent_contest_count = Contest.objects.exclude( + end_time__lt=timezone.now() + ).count() + judge_server_count = len( + list(filter(lambda x: x.status == 'normal', JudgeServer.objects.all())) + ) + return self.success( + { + 'user_count': User.objects.count(), + 'recent_contest_count': recent_contest_count, + 'today_submission_count': today_submission_count, + 'judge_server_count': judge_server_count, + 'env': { + 'FORCE_HTTPS': get_env('FORCE_HTTPS', default=False), + 'STATIC_CDN_HOST': get_env('STATIC_CDN_HOST', default=''), + }, } - }) + ) diff --git a/contest/migrations/0001_initial.py b/contest/migrations/0001_initial.py index 255b8d0ea..d8a861dda 100644 --- a/contest/migrations/0001_initial.py +++ b/contest/migrations/0001_initial.py @@ -10,7 +10,6 @@ class Migration(migrations.Migration): - initial = True dependencies = [ @@ -22,7 +21,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name='ACMContestRank', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('total_submission_number', models.IntegerField(default=0)), ('total_ac_number', models.IntegerField(default=0)), ('total_time', models.IntegerField(default=0)), @@ -35,7 +42,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name='Contest', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('title', models.CharField(max_length=40)), ('description', utils.models.RichTextField()), ('real_time_rank', models.BooleanField()), @@ -46,7 +61,13 @@ class Migration(migrations.Migration): ('create_time', models.DateTimeField(auto_now_add=True)), ('last_update_time', models.DateTimeField(auto_now=True)), ('visible', models.BooleanField(default=True)), - ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ( + 'created_by', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), ], options={ 'db_table': 'contest', @@ -55,12 +76,32 @@ class Migration(migrations.Migration): migrations.CreateModel( name='ContestAnnouncement', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('title', models.CharField(max_length=128)), ('content', utils.models.RichTextField()), ('create_time', models.DateTimeField(auto_now_add=True)), - ('contest', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='contest.Contest')), - ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ( + 'contest', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to='contest.Contest', + ), + ), + ( + 'created_by', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), ], options={ 'db_table': 'contest_announcement', @@ -69,7 +110,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name='ContestProblem', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('title', models.CharField(max_length=128)), ('description', utils.models.RichTextField()), ('input_description', utils.models.RichTextField()), @@ -85,7 +134,10 @@ class Migration(migrations.Migration): ('time_limit', models.IntegerField()), ('memory_limit', models.IntegerField()), ('spj', models.BooleanField(default=False)), - ('spj_language', models.CharField(blank=True, max_length=32, null=True)), + ( + 'spj_language', + models.CharField(blank=True, max_length=32, null=True), + ), ('spj_code', models.TextField(blank=True, null=True)), ('spj_version', models.CharField(blank=True, max_length=32, null=True)), ('rule_type', models.CharField(max_length=32)), @@ -96,8 +148,20 @@ class Migration(migrations.Migration): ('total_accepted_number', models.IntegerField(default=0)), ('sort_index', models.CharField(max_length=30)), ('is_public', models.BooleanField(default=False)), - ('contest', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='contest.Contest')), - ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ( + 'contest', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to='contest.Contest', + ), + ), + ( + 'created_by', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), ('tags', models.ManyToManyField(to='problem.ProblemTag')), ], options={ @@ -107,12 +171,32 @@ class Migration(migrations.Migration): migrations.CreateModel( name='OIContestRank', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('total_submission_number', models.IntegerField(default=0)), ('total_score', models.IntegerField(default=0)), ('submission_info', jsonfield.fields.JSONField(default={})), - ('contest', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='contest.Contest')), - ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ( + 'contest', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to='contest.Contest', + ), + ), + ( + 'user', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), ], options={ 'db_table': 'oi_contest_rank', @@ -121,11 +205,15 @@ class Migration(migrations.Migration): migrations.AddField( model_name='acmcontestrank', name='contest', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='contest.Contest'), + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to='contest.Contest' + ), ), migrations.AddField( model_name='acmcontestrank', name='user', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL + ), ), ] diff --git a/contest/migrations/0002_auto_20170209_0845.py b/contest/migrations/0002_auto_20170209_0845.py index 447545f14..301118319 100644 --- a/contest/migrations/0002_auto_20170209_0845.py +++ b/contest/migrations/0002_auto_20170209_0845.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0001_initial'), ] diff --git a/contest/migrations/0003_auto_20170217_0820.py b/contest/migrations/0003_auto_20170217_0820.py index 4776df901..467565db8 100644 --- a/contest/migrations/0003_auto_20170217_0820.py +++ b/contest/migrations/0003_auto_20170217_0820.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0002_auto_20170209_0845'), ] diff --git a/contest/migrations/0004_auto_20170717_1324.py b/contest/migrations/0004_auto_20170717_1324.py index 617790abb..8dc8bfe1a 100644 --- a/contest/migrations/0004_auto_20170717_1324.py +++ b/contest/migrations/0004_auto_20170717_1324.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0003_auto_20170217_0820'), ] diff --git a/contest/migrations/0005_auto_20170823_0918.py b/contest/migrations/0005_auto_20170823_0918.py index dbf12c66b..6b055be3a 100644 --- a/contest/migrations/0005_auto_20170823_0918.py +++ b/contest/migrations/0005_auto_20170823_0918.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0004_auto_20170717_1324'), ] diff --git a/contest/migrations/0006_auto_20171011_1214.py b/contest/migrations/0006_auto_20171011_1214.py index 0134a5bf3..4f73b257c 100644 --- a/contest/migrations/0006_auto_20171011_1214.py +++ b/contest/migrations/0006_auto_20171011_1214.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0005_auto_20170823_0918'), ] diff --git a/contest/migrations/0007_contestannouncement_visible.py b/contest/migrations/0007_contestannouncement_visible.py index 679874f74..5808eb980 100644 --- a/contest/migrations/0007_contestannouncement_visible.py +++ b/contest/migrations/0007_contestannouncement_visible.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0006_auto_20171011_1214'), ] diff --git a/contest/migrations/0008_contest_allowed_ip_ranges.py b/contest/migrations/0008_contest_allowed_ip_ranges.py index fd6c6ff71..357e89240 100644 --- a/contest/migrations/0008_contest_allowed_ip_ranges.py +++ b/contest/migrations/0008_contest_allowed_ip_ranges.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0007_contestannouncement_visible'), ] diff --git a/contest/migrations/0009_auto_20180501_0436.py b/contest/migrations/0009_auto_20180501_0436.py index eab2add97..d1b01b2d4 100644 --- a/contest/migrations/0009_auto_20180501_0436.py +++ b/contest/migrations/0009_auto_20180501_0436.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0008_contest_allowed_ip_ranges'), ] diff --git a/contest/migrations/0010_auto_20190326_0201.py b/contest/migrations/0010_auto_20190326_0201.py index 8700e1f16..f7b0fbdf1 100644 --- a/contest/migrations/0010_auto_20190326_0201.py +++ b/contest/migrations/0010_auto_20190326_0201.py @@ -5,7 +5,6 @@ class Migration(migrations.Migration): - dependencies = [ migrations.swappable_dependency(settings.AUTH_USER_MODEL), ('contest', '0009_auto_20180501_0436'), diff --git a/contest/models.py b/contest/models.py index 4616bef96..3698d5f87 100644 --- a/contest/models.py +++ b/contest/models.py @@ -45,14 +45,17 @@ def contest_type(self): # 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等 def problem_details_permission(self, user): - return self.rule_type == ContestRuleType.ACM or \ - self.status == ContestStatus.CONTEST_ENDED or \ - user.is_authenticated and user.is_contest_admin(self) or \ - self.real_time_rank + return ( + self.rule_type == ContestRuleType.ACM + or self.status == ContestStatus.CONTEST_ENDED + or user.is_authenticated + and user.is_contest_admin(self) + or self.real_time_rank + ) class Meta: - db_table = "contest" - ordering = ("-start_time",) + db_table = 'contest' + ordering = ('-start_time',) class AbstractContestRank(models.Model): @@ -73,8 +76,8 @@ class ACMContestRank(AbstractContestRank): submission_info = JSONField(default=dict) class Meta: - db_table = "acm_contest_rank" - unique_together = (("user", "contest"),) + db_table = 'acm_contest_rank' + unique_together = (('user', 'contest'),) class OIContestRank(AbstractContestRank): @@ -84,8 +87,8 @@ class OIContestRank(AbstractContestRank): submission_info = JSONField(default=dict) class Meta: - db_table = "oi_contest_rank" - unique_together = (("user", "contest"),) + db_table = 'oi_contest_rank' + unique_together = (('user', 'contest'),) class ContestAnnouncement(models.Model): @@ -97,5 +100,5 @@ class ContestAnnouncement(models.Model): create_time = models.DateTimeField(auto_now_add=True) class Meta: - db_table = "contest_announcement" - ordering = ("-create_time",) + db_table = 'contest_announcement' + ordering = ('-create_time',) diff --git a/contest/serializers.py b/contest/serializers.py index 356cddaee..2b94b31c7 100644 --- a/contest/serializers.py +++ b/contest/serializers.py @@ -9,11 +9,15 @@ class CreateConetestSeriaizer(serializers.Serializer): description = serializers.CharField() start_time = serializers.DateTimeField() end_time = serializers.DateTimeField() - rule_type = serializers.ChoiceField(choices=[ContestRuleType.ACM, ContestRuleType.OI]) + rule_type = serializers.ChoiceField( + choices=[ContestRuleType.ACM, ContestRuleType.OI] + ) password = serializers.CharField(allow_blank=True, max_length=32) visible = serializers.BooleanField() real_time_rank = serializers.BooleanField() - allowed_ip_ranges = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=True) + allowed_ip_ranges = serializers.ListField( + child=serializers.CharField(max_length=32), allow_empty=True + ) class EditConetestSeriaizer(serializers.Serializer): @@ -25,7 +29,9 @@ class EditConetestSeriaizer(serializers.Serializer): password = serializers.CharField(allow_blank=True, allow_null=True, max_length=32) visible = serializers.BooleanField() real_time_rank = serializers.BooleanField() - allowed_ip_ranges = serializers.ListField(child=serializers.CharField(max_length=32)) + allowed_ip_ranges = serializers.ListField( + child=serializers.CharField(max_length=32) + ) class ContestAdminSerializer(serializers.ModelSerializer): @@ -35,13 +41,13 @@ class ContestAdminSerializer(serializers.ModelSerializer): class Meta: model = Contest - fields = "__all__" + fields = '__all__' class ContestSerializer(ContestAdminSerializer): class Meta: model = Contest - exclude = ("password", "visible", "allowed_ip_ranges") + exclude = ('password', 'visible', 'allowed_ip_ranges') class ContestAnnouncementSerializer(serializers.ModelSerializer): @@ -49,7 +55,7 @@ class ContestAnnouncementSerializer(serializers.ModelSerializer): class Meta: model = ContestAnnouncement - fields = "__all__" + fields = '__all__' class CreateContestAnnouncementSerializer(serializers.Serializer): @@ -76,10 +82,10 @@ class ACMContestRankSerializer(serializers.ModelSerializer): class Meta: model = ACMContestRank - fields = "__all__" + fields = '__all__' def __init__(self, *args, **kwargs): - self.is_contest_admin = kwargs.pop("is_contest_admin", False) + self.is_contest_admin = kwargs.pop('is_contest_admin', False) super().__init__(*args, **kwargs) def get_user(self, obj): @@ -91,10 +97,10 @@ class OIContestRankSerializer(serializers.ModelSerializer): class Meta: model = OIContestRank - fields = "__all__" + fields = '__all__' def __init__(self, *args, **kwargs): - self.is_contest_admin = kwargs.pop("is_contest_admin", False) + self.is_contest_admin = kwargs.pop('is_contest_admin', False) super().__init__(*args, **kwargs) def get_user(self, obj): diff --git a/contest/tests.py b/contest/tests.py index a92e09ae5..2fe2608f1 100644 --- a/contest/tests.py +++ b/contest/tests.py @@ -7,19 +7,23 @@ from .models import ContestAnnouncement, ContestRuleType, Contest -DEFAULT_CONTEST_DATA = {"title": "test title", "description": "test description", - "start_time": timezone.localtime(timezone.now()), - "end_time": timezone.localtime(timezone.now()) + timedelta(days=1), - "rule_type": ContestRuleType.ACM, - "password": "123", - "allowed_ip_ranges": [], - "visible": True, "real_time_rank": True} +DEFAULT_CONTEST_DATA = { + 'title': 'test title', + 'description': 'test description', + 'start_time': timezone.localtime(timezone.now()), + 'end_time': timezone.localtime(timezone.now()) + timedelta(days=1), + 'rule_type': ContestRuleType.ACM, + 'password': '123', + 'allowed_ip_ranges': [], + 'visible': True, + 'real_time_rank': True, +} class ContestAdminAPITest(APITestCase): def setUp(self): self.create_super_admin() - self.url = self.reverse("contest_admin_api") + self.url = self.reverse('contest_admin_api') self.data = copy.deepcopy(DEFAULT_CONTEST_DATA) def test_create_contest(self): @@ -28,21 +32,25 @@ def test_create_contest(self): return response def test_create_contest_with_invalid_cidr(self): - self.data["allowed_ip_ranges"] = ["127.0.0"] + self.data['allowed_ip_ranges'] = ['127.0.0'] resp = self.client.post(self.url, data=self.data) - self.assertTrue(resp.data["data"].endswith("is not a valid cidr network")) + self.assertTrue(resp.data['data'].endswith('is not a valid cidr network')) def test_update_contest(self): - id = self.test_create_contest().data["data"]["id"] - update_data = {"id": id, "title": "update title", - "description": "update description", - "password": "12345", - "visible": False, "real_time_rank": False} + id = self.test_create_contest().data['data']['id'] + update_data = { + 'id': id, + 'title': 'update title', + 'description': 'update description', + 'password': '12345', + 'visible': False, + 'real_time_rank': False, + } data = copy.deepcopy(self.data) data.update(update_data) response = self.client.put(self.url, data=data) self.assertSuccess(response) - response_data = response.data["data"] + response_data = response.data['data'] for k in data.keys(): if isinstance(data[k], datetime): continue @@ -54,8 +62,8 @@ def test_get_contests(self): self.assertSuccess(response) def test_get_one_contest(self): - id = self.test_create_contest().data["data"]["id"] - response = self.client.get("{}?id={}".format(self.url, id)) + id = self.test_create_contest().data['data']['id'] + response = self.client.get('{}?id={}'.format(self.url, id)) self.assertSuccess(response) @@ -63,36 +71,51 @@ class ContestAPITest(APITestCase): def setUp(self): user = self.create_admin() self.contest = Contest.objects.create(created_by=user, **DEFAULT_CONTEST_DATA) - self.url = self.reverse("contest_api") + "?id=" + str(self.contest.id) + self.url = self.reverse('contest_api') + '?id=' + str(self.contest.id) def test_get_contest_list(self): - url = self.reverse("contest_list_api") - response = self.client.get(url + "?limit=10") + url = self.reverse('contest_list_api') + response = self.client.get(url + '?limit=10') self.assertSuccess(response) - self.assertEqual(len(response.data["data"]["results"]), 1) + self.assertEqual(len(response.data['data']['results']), 1) def test_get_one_contest(self): resp = self.client.get(self.url) self.assertSuccess(resp) def test_regular_user_validate_contest_password(self): - self.create_user("test", "test123") - url = self.reverse("contest_password_api") - resp = self.client.post(url, {"contest_id": self.contest.id, "password": "error_password"}) - self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password or password expired"}) - - resp = self.client.post(url, {"contest_id": self.contest.id, "password": DEFAULT_CONTEST_DATA["password"]}) + self.create_user('test', 'test123') + url = self.reverse('contest_password_api') + resp = self.client.post( + url, {'contest_id': self.contest.id, 'password': 'error_password'} + ) + self.assertDictEqual( + resp.data, {'error': 'error', 'data': 'Wrong password or password expired'} + ) + + resp = self.client.post( + url, + { + 'contest_id': self.contest.id, + 'password': DEFAULT_CONTEST_DATA['password'], + }, + ) self.assertSuccess(resp) def test_regular_user_access_contest(self): - self.create_user("test", "test123") - url = self.reverse("contest_access_api") - resp = self.client.get(url + "?contest_id=" + str(self.contest.id)) - self.assertFalse(resp.data["data"]["access"]) - - password_url = self.reverse("contest_password_api") - resp = self.client.post(password_url, - {"contest_id": self.contest.id, "password": DEFAULT_CONTEST_DATA["password"]}) + self.create_user('test', 'test123') + url = self.reverse('contest_access_api') + resp = self.client.get(url + '?contest_id=' + str(self.contest.id)) + self.assertFalse(resp.data['data']['access']) + + password_url = self.reverse('contest_password_api') + resp = self.client.post( + password_url, + { + 'contest_id': self.contest.id, + 'password': DEFAULT_CONTEST_DATA['password'], + }, + ) self.assertSuccess(resp) resp = self.client.get(self.url) self.assertSuccess(resp) @@ -101,12 +124,17 @@ def test_regular_user_access_contest(self): class ContestAnnouncementAdminAPITest(APITestCase): def setUp(self): self.create_super_admin() - self.url = self.reverse("contest_announcement_admin_api") - contest_id = self.create_contest().data["data"]["id"] - self.data = {"title": "test title", "content": "test content", "contest_id": contest_id, "visible": True} + self.url = self.reverse('contest_announcement_admin_api') + contest_id = self.create_contest().data['data']['id'] + self.data = { + 'title': 'test title', + 'content': 'test content', + 'contest_id': contest_id, + 'visible': True, + } def create_contest(self): - url = self.reverse("contest_admin_api") + url = self.reverse('contest_admin_api') data = DEFAULT_CONTEST_DATA return self.client.post(url, data=data) @@ -116,47 +144,67 @@ def test_create_contest_announcement(self): return response def test_delete_contest_announcement(self): - id = self.test_create_contest_announcement().data["data"]["id"] - response = self.client.delete("{}?id={}".format(self.url, id)) + id = self.test_create_contest_announcement().data['data']['id'] + response = self.client.delete('{}?id={}'.format(self.url, id)) self.assertSuccess(response) self.assertFalse(ContestAnnouncement.objects.filter(id=id).exists()) def test_get_contest_announcements(self): self.test_create_contest_announcement() - response = self.client.get(self.url + "?contest_id=" + str(self.data["contest_id"])) + response = self.client.get( + self.url + '?contest_id=' + str(self.data['contest_id']) + ) self.assertSuccess(response) def test_get_one_contest_announcement(self): - id = self.test_create_contest_announcement().data["data"]["id"] - response = self.client.get("{}?id={}".format(self.url, id)) + id = self.test_create_contest_announcement().data['data']['id'] + response = self.client.get('{}?id={}'.format(self.url, id)) self.assertSuccess(response) class ContestAnnouncementListAPITest(APITestCase): def setUp(self): self.create_super_admin() - self.url = self.reverse("contest_announcement_api") + self.url = self.reverse('contest_announcement_api') def create_contest_announcements(self): - contest_id = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"]["id"] - url = self.reverse("contest_announcement_admin_api") - self.client.post(url, data={"title": "test title1", "content": "test content1", "contest_id": contest_id}) - self.client.post(url, data={"title": "test title2", "content": "test content2", "contest_id": contest_id}) + contest_id = self.client.post( + self.reverse('contest_admin_api'), data=DEFAULT_CONTEST_DATA + ).data['data']['id'] + url = self.reverse('contest_announcement_admin_api') + self.client.post( + url, + data={ + 'title': 'test title1', + 'content': 'test content1', + 'contest_id': contest_id, + }, + ) + self.client.post( + url, + data={ + 'title': 'test title2', + 'content': 'test content2', + 'contest_id': contest_id, + }, + ) return contest_id def test_get_contest_announcement_list(self): contest_id = self.create_contest_announcements() - response = self.client.get(self.url, data={"contest_id": contest_id}) + response = self.client.get(self.url, data={'contest_id': contest_id}) self.assertSuccess(response) class ContestRankAPITest(APITestCase): def setUp(self): user = self.create_admin() - self.acm_contest = Contest.objects.create(created_by=user, **DEFAULT_CONTEST_DATA) - self.create_user("test", "test123") - self.url = self.reverse("contest_rank_api") + self.acm_contest = Contest.objects.create( + created_by=user, **DEFAULT_CONTEST_DATA + ) + self.create_user('test', 'test123') + self.url = self.reverse('contest_rank_api') def get_contest_rank(self): - resp = self.client.get(self.url + "?contest_id=" + self.acm_contest.id) + resp = self.client.get(self.url + '?contest_id=' + self.acm_contest.id) self.assertSuccess(resp) diff --git a/contest/urls/admin.py b/contest/urls/admin.py index 0b017cc80..bb4f53461 100644 --- a/contest/urls/admin.py +++ b/contest/urls/admin.py @@ -1,10 +1,25 @@ from django.conf.urls import url -from ..views.admin import ContestAnnouncementAPI, ContestAPI, ACMContestHelper, DownloadContestSubmissions +from ..views.admin import ( + ContestAnnouncementAPI, + ContestAPI, + ACMContestHelper, + DownloadContestSubmissions, +) urlpatterns = [ - url(r"^contest/?$", ContestAPI.as_view(), name="contest_admin_api"), - url(r"^contest/announcement/?$", ContestAnnouncementAPI.as_view(), name="contest_announcement_admin_api"), - url(r"^contest/acm_helper/?$", ACMContestHelper.as_view(), name="acm_contest_helper"), - url(r"^download_submissions/?$", DownloadContestSubmissions.as_view(), name="acm_contest_helper"), + url(r'^contest/?$', ContestAPI.as_view(), name='contest_admin_api'), + url( + r'^contest/announcement/?$', + ContestAnnouncementAPI.as_view(), + name='contest_announcement_admin_api', + ), + url( + r'^contest/acm_helper/?$', ACMContestHelper.as_view(), name='acm_contest_helper' + ), + url( + r'^download_submissions/?$', + DownloadContestSubmissions.as_view(), + name='acm_contest_helper', + ), ] diff --git a/contest/urls/oj.py b/contest/urls/oj.py index 9e94fa58f..c85122d64 100644 --- a/contest/urls/oj.py +++ b/contest/urls/oj.py @@ -6,10 +6,18 @@ from ..views.oj import ContestRankAPI urlpatterns = [ - url(r"^contests/?$", ContestListAPI.as_view(), name="contest_list_api"), - url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"), - url(r"^contest/password/?$", ContestPasswordVerifyAPI.as_view(), name="contest_password_api"), - url(r"^contest/announcement/?$", ContestAnnouncementListAPI.as_view(), name="contest_announcement_api"), - url(r"^contest/access/?$", ContestAccessAPI.as_view(), name="contest_access_api"), - url(r"^contest_rank/?$", ContestRankAPI.as_view(), name="contest_rank_api"), + url(r'^contests/?$', ContestListAPI.as_view(), name='contest_list_api'), + url(r'^contest/?$', ContestAPI.as_view(), name='contest_api'), + url( + r'^contest/password/?$', + ContestPasswordVerifyAPI.as_view(), + name='contest_password_api', + ), + url( + r'^contest/announcement/?$', + ContestAnnouncementListAPI.as_view(), + name='contest_announcement_api', + ), + url(r'^contest/access/?$', ContestAccessAPI.as_view(), name='contest_access_api'), + url(r'^contest_rank/?$', ContestRankAPI.as_view(), name='contest_rank_api'), ] diff --git a/contest/views/admin.py b/contest/views/admin.py index 66addb158..ba2d58bb3 100644 --- a/contest/views/admin.py +++ b/contest/views/admin.py @@ -15,28 +15,33 @@ from utils.shortcuts import rand_str from utils.tasks import delete_files from ..models import Contest, ContestAnnouncement, ACMContestRank -from ..serializers import (ContestAnnouncementSerializer, ContestAdminSerializer, - CreateConetestSeriaizer, CreateContestAnnouncementSerializer, - EditConetestSeriaizer, EditContestAnnouncementSerializer, - ACMContesHelperSerializer, ) +from ..serializers import ( + ContestAnnouncementSerializer, + ContestAdminSerializer, + CreateConetestSeriaizer, + CreateContestAnnouncementSerializer, + EditConetestSeriaizer, + EditContestAnnouncementSerializer, + ACMContesHelperSerializer, +) class ContestAPI(APIView): @validate_serializer(CreateConetestSeriaizer) def post(self, request): data = request.data - data["start_time"] = dateutil.parser.parse(data["start_time"]) - data["end_time"] = dateutil.parser.parse(data["end_time"]) - data["created_by"] = request.user - if data["end_time"] <= data["start_time"]: - return self.error("Start time must occur earlier than end time") - if data.get("password") and data["password"] == "": - data["password"] = None - for ip_range in data["allowed_ip_ranges"]: + data['start_time'] = dateutil.parser.parse(data['start_time']) + data['end_time'] = dateutil.parser.parse(data['end_time']) + data['created_by'] = request.user + if data['end_time'] <= data['start_time']: + return self.error('Start time must occur earlier than end time') + if data.get('password') and data['password'] == '': + data['password'] = None + for ip_range in data['allowed_ip_ranges']: try: ip_network(ip_range, strict=False) except ValueError: - return self.error(f"{ip_range} is not a valid cidr network") + return self.error(f'{ip_range} is not a valid cidr network') contest = Contest.objects.create(**data) return self.success(ContestAdminSerializer(contest).data) @@ -44,23 +49,23 @@ def post(self, request): def put(self, request): data = request.data try: - contest = Contest.objects.get(id=data.pop("id")) + contest = Contest.objects.get(id=data.pop('id')) ensure_created_by(contest, request.user) except Contest.DoesNotExist: - return self.error("Contest does not exist") - data["start_time"] = dateutil.parser.parse(data["start_time"]) - data["end_time"] = dateutil.parser.parse(data["end_time"]) - if data["end_time"] <= data["start_time"]: - return self.error("Start time must occur earlier than end time") - if not data["password"]: - data["password"] = None - for ip_range in data["allowed_ip_ranges"]: + return self.error('Contest does not exist') + data['start_time'] = dateutil.parser.parse(data['start_time']) + data['end_time'] = dateutil.parser.parse(data['end_time']) + if data['end_time'] <= data['start_time']: + return self.error('Start time must occur earlier than end time') + if not data['password']: + data['password'] = None + for ip_range in data['allowed_ip_ranges']: try: ip_network(ip_range, strict=False) except ValueError: - return self.error(f"{ip_range} is not a valid cidr network") - if not contest.real_time_rank and data.get("real_time_rank"): - cache_key = f"{CacheKey.contest_rank_cache}:{contest.id}" + return self.error(f'{ip_range} is not a valid cidr network') + if not contest.real_time_rank and data.get('real_time_rank'): + cache_key = f'{CacheKey.contest_rank_cache}:{contest.id}' cache.delete(cache_key) for k, v in data.items(): @@ -69,23 +74,25 @@ def put(self, request): return self.success(ContestAdminSerializer(contest).data) def get(self, request): - contest_id = request.GET.get("id") + contest_id = request.GET.get('id') if contest_id: try: contest = Contest.objects.get(id=contest_id) ensure_created_by(contest, request.user) return self.success(ContestAdminSerializer(contest).data) except Contest.DoesNotExist: - return self.error("Contest does not exist") + return self.error('Contest does not exist') - contests = Contest.objects.all().order_by("-create_time") + contests = Contest.objects.all().order_by('-create_time') if request.user.is_admin(): contests = contests.filter(created_by=request.user) - keyword = request.GET.get("keyword") + keyword = request.GET.get('keyword') if keyword: contests = contests.filter(title__contains=keyword) - return self.success(self.paginate_data(request, contests, ContestAdminSerializer)) + return self.success( + self.paginate_data(request, contests, ContestAdminSerializer) + ) class ContestAnnouncementAPI(APIView): @@ -96,12 +103,12 @@ def post(self, request): """ data = request.data try: - contest = Contest.objects.get(id=data.pop("contest_id")) + contest = Contest.objects.get(id=data.pop('contest_id')) ensure_created_by(contest, request.user) - data["contest"] = contest - data["created_by"] = request.user + data['contest'] = contest + data['created_by'] = request.user except Contest.DoesNotExist: - return self.error("Contest does not exist") + return self.error('Contest does not exist') announcement = ContestAnnouncement.objects.create(**data) return self.success(ContestAnnouncementSerializer(announcement).data) @@ -112,10 +119,10 @@ def put(self, request): """ data = request.data try: - contest_announcement = ContestAnnouncement.objects.get(id=data.pop("id")) + contest_announcement = ContestAnnouncement.objects.get(id=data.pop('id')) ensure_created_by(contest_announcement, request.user) except ContestAnnouncement.DoesNotExist: - return self.error("Contest announcement does not exist") + return self.error('Contest announcement does not exist') for k, v in data.items(): setattr(contest_announcement, k, v) contest_announcement.save() @@ -125,11 +132,12 @@ def delete(self, request): """ Delete one contest_announcement. """ - contest_announcement_id = request.GET.get("id") + contest_announcement_id = request.GET.get('id') if contest_announcement_id: if request.user.is_admin(): - ContestAnnouncement.objects.filter(id=contest_announcement_id, - contest__created_by=request.user).delete() + ContestAnnouncement.objects.filter( + id=contest_announcement_id, contest__created_by=request.user + ).delete() else: ContestAnnouncement.objects.filter(id=contest_announcement_id).delete() return self.success() @@ -138,73 +146,92 @@ def get(self, request): """ Get one contest_announcement or contest_announcement list. """ - contest_announcement_id = request.GET.get("id") + contest_announcement_id = request.GET.get('id') if contest_announcement_id: try: - contest_announcement = ContestAnnouncement.objects.get(id=contest_announcement_id) + contest_announcement = ContestAnnouncement.objects.get( + id=contest_announcement_id + ) ensure_created_by(contest_announcement, request.user) - return self.success(ContestAnnouncementSerializer(contest_announcement).data) + return self.success( + ContestAnnouncementSerializer(contest_announcement).data + ) except ContestAnnouncement.DoesNotExist: - return self.error("Contest announcement does not exist") + return self.error('Contest announcement does not exist') - contest_id = request.GET.get("contest_id") + contest_id = request.GET.get('contest_id') if not contest_id: - return self.error("Parameter error") - contest_announcements = ContestAnnouncement.objects.filter(contest_id=contest_id) + return self.error('Parameter error') + contest_announcements = ContestAnnouncement.objects.filter( + contest_id=contest_id + ) if request.user.is_admin(): - contest_announcements = contest_announcements.filter(created_by=request.user) - keyword = request.GET.get("keyword") + contest_announcements = contest_announcements.filter( + created_by=request.user + ) + keyword = request.GET.get('keyword') if keyword: - contest_announcements = contest_announcements.filter(title__contains=keyword) - return self.success(ContestAnnouncementSerializer(contest_announcements, many=True).data) + contest_announcements = contest_announcements.filter( + title__contains=keyword + ) + return self.success( + ContestAnnouncementSerializer(contest_announcements, many=True).data + ) class ACMContestHelper(APIView): - @check_contest_permission(check_type="ranks") + @check_contest_permission(check_type='ranks') def get(self, request): - ranks = ACMContestRank.objects.filter(contest=self.contest, accepted_number__gt=0) \ - .values("id", "user__username", "user__userprofile__real_name", "submission_info") + ranks = ACMContestRank.objects.filter( + contest=self.contest, accepted_number__gt=0 + ).values( + 'id', 'user__username', 'user__userprofile__real_name', 'submission_info' + ) results = [] for rank in ranks: - for problem_id, info in rank["submission_info"].items(): - if info["is_ac"]: - results.append({ - "id": rank["id"], - "username": rank["user__username"], - "real_name": rank["user__userprofile__real_name"], - "problem_id": problem_id, - "ac_info": info, - "checked": info.get("checked", False) - }) - results.sort(key=lambda x: -x["ac_info"]["ac_time"]) + for problem_id, info in rank['submission_info'].items(): + if info['is_ac']: + results.append( + { + 'id': rank['id'], + 'username': rank['user__username'], + 'real_name': rank['user__userprofile__real_name'], + 'problem_id': problem_id, + 'ac_info': info, + 'checked': info.get('checked', False), + } + ) + results.sort(key=lambda x: -x['ac_info']['ac_time']) return self.success(results) - @check_contest_permission(check_type="ranks") + @check_contest_permission(check_type='ranks') @validate_serializer(ACMContesHelperSerializer) def put(self, request): data = request.data try: - rank = ACMContestRank.objects.get(pk=data["rank_id"]) + rank = ACMContestRank.objects.get(pk=data['rank_id']) except ACMContestRank.DoesNotExist: - return self.error("Rank id does not exist") - problem_rank_status = rank.submission_info.get(data["problem_id"]) + return self.error('Rank id does not exist') + problem_rank_status = rank.submission_info.get(data['problem_id']) if not problem_rank_status: - return self.error("Problem id does not exist") - problem_rank_status["checked"] = data["checked"] - rank.save(update_fields=("submission_info",)) + return self.error('Problem id does not exist') + problem_rank_status['checked'] = data['checked'] + rank.save(update_fields=('submission_info',)) return self.success() class DownloadContestSubmissions(APIView): def _dump_submissions(self, contest, exclude_admin=True): - problem_ids = contest.problem_set.all().values_list("id", "_id") + problem_ids = contest.problem_set.all().values_list('id', '_id') id2display_id = {k[0]: k[1] for k in problem_ids} ac_map = {k[0]: False for k in problem_ids} - submissions = Submission.objects.filter(contest=contest, result=JudgeStatus.ACCEPTED).order_by("-create_time") - user_ids = submissions.values_list("user_id", flat=True) + submissions = Submission.objects.filter( + contest=contest, result=JudgeStatus.ACCEPTED + ).order_by('-create_time') + user_ids = submissions.values_list('user_id', flat=True) users = User.objects.filter(id__in=user_ids) - path = f"/tmp/{rand_str()}.zip" - with zipfile.ZipFile(path, "w") as zip_file: + path = f'/tmp/{rand_str()}.zip' + with zipfile.ZipFile(path, 'w') as zip_file: for user in users: if user.is_admin_role() and exclude_admin: continue @@ -214,28 +241,34 @@ def _dump_submissions(self, contest, exclude_admin=True): problem_id = submission.problem_id if user_ac_map[problem_id]: continue - file_name = f"{user.username}_{id2display_id[submission.problem_id]}.txt" + file_name = ( + f'{user.username}_{id2display_id[submission.problem_id]}.txt' + ) compression = zipfile.ZIP_DEFLATED - zip_file.writestr(zinfo_or_arcname=f"{file_name}", - data=submission.code, - compress_type=compression) + zip_file.writestr( + zinfo_or_arcname=f'{file_name}', + data=submission.code, + compress_type=compression, + ) user_ac_map[problem_id] = True return path def get(self, request): - contest_id = request.GET.get("contest_id") + contest_id = request.GET.get('contest_id') if not contest_id: - return self.error("Parameter error") + return self.error('Parameter error') try: contest = Contest.objects.get(id=contest_id) ensure_created_by(contest, request.user) except Contest.DoesNotExist: - return self.error("Contest does not exist") + return self.error('Contest does not exist') - exclude_admin = request.GET.get("exclude_admin") == "1" + exclude_admin = request.GET.get('exclude_admin') == '1' zip_path = self._dump_submissions(contest, exclude_admin) delete_files.send_with_options(args=(zip_path,), delay=300_000) - resp = FileResponse(open(zip_path, "rb")) - resp["Content-Type"] = "application/zip" - resp["Content-Disposition"] = f"attachment;filename={os.path.basename(zip_path)}" + resp = FileResponse(open(zip_path, 'rb')) + resp['Content-Type'] = 'application/zip' + resp[ + 'Content-Disposition' + ] = f'attachment;filename={os.path.basename(zip_path)}' return resp diff --git a/contest/views/oj.py b/contest/views/oj.py index 4164ec332..4f9b307fc 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -10,7 +10,11 @@ from utils.constants import CacheKey, CONTEST_PASSWORD_SESSION_KEY from utils.shortcuts import datetime2str, check_is_id from account.models import AdminType -from account.decorators import login_required, check_contest_permission, check_contest_password +from account.decorators import ( + login_required, + check_contest_permission, + check_contest_password, +) from utils.constants import ContestRuleType, ContestStatus from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank @@ -20,13 +24,15 @@ class ContestAnnouncementListAPI(APIView): - @check_contest_permission(check_type="announcements") + @check_contest_permission(check_type='announcements') def get(self, request): - contest_id = request.GET.get("contest_id") + contest_id = request.GET.get('contest_id') if not contest_id: - return self.error("Invalid parameter, contest_id is required") - data = ContestAnnouncement.objects.select_related("created_by").filter(contest_id=contest_id, visible=True) - max_id = request.GET.get("max_id") + return self.error('Invalid parameter, contest_id is required') + data = ContestAnnouncement.objects.select_related('created_by').filter( + contest_id=contest_id, visible=True + ) + max_id = request.GET.get('max_id') if max_id: data = data.filter(id__gt=max_id) return self.success(ContestAnnouncementSerializer(data, many=True).data) @@ -34,24 +40,24 @@ def get(self, request): class ContestAPI(APIView): def get(self, request): - id = request.GET.get("id") + id = request.GET.get('id') if not id or not check_is_id(id): - return self.error("Invalid parameter, id is required") + return self.error('Invalid parameter, id is required') try: contest = Contest.objects.get(id=id, visible=True) except Contest.DoesNotExist: - return self.error("Contest does not exist") + return self.error('Contest does not exist') data = ContestSerializer(contest).data - data["now"] = datetime2str(now()) + data['now'] = datetime2str(now()) return self.success(data) class ContestListAPI(APIView): def get(self, request): - contests = Contest.objects.select_related("created_by").filter(visible=True) - keyword = request.GET.get("keyword") - rule_type = request.GET.get("rule_type") - status = request.GET.get("status") + contests = Contest.objects.select_related('created_by').filter(visible=True) + keyword = request.GET.get('keyword') + rule_type = request.GET.get('rule_type') + status = request.GET.get('status') if keyword: contests = contests.filter(title__contains=keyword) if rule_type: @@ -73,16 +79,18 @@ class ContestPasswordVerifyAPI(APIView): def post(self, request): data = request.data try: - contest = Contest.objects.get(id=data["contest_id"], visible=True, password__isnull=False) + contest = Contest.objects.get( + id=data['contest_id'], visible=True, password__isnull=False + ) except Contest.DoesNotExist: - return self.error("Contest does not exist") - if not check_contest_password(data["password"], contest.password): - return self.error("Wrong password or password expired") + return self.error('Contest does not exist') + if not check_contest_password(data['password'], contest.password): + return self.error('Wrong password or password expired') # password verify OK. if CONTEST_PASSWORD_SESSION_KEY not in request.session: request.session[CONTEST_PASSWORD_SESSION_KEY] = {} - request.session[CONTEST_PASSWORD_SESSION_KEY][contest.id] = data["password"] + request.session[CONTEST_PASSWORD_SESSION_KEY][contest.id] = data['password'] # https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved request.session.modified = True return self.success(True) @@ -91,51 +99,70 @@ def post(self, request): class ContestAccessAPI(APIView): @login_required def get(self, request): - contest_id = request.GET.get("contest_id") + contest_id = request.GET.get('contest_id') if not contest_id: return self.error() try: - contest = Contest.objects.get(id=contest_id, visible=True, password__isnull=False) + contest = Contest.objects.get( + id=contest_id, visible=True, password__isnull=False + ) except Contest.DoesNotExist: - return self.error("Contest does not exist") - session_pass = request.session.get(CONTEST_PASSWORD_SESSION_KEY, {}).get(contest.id) - return self.success({"access": check_contest_password(session_pass, contest.password)}) + return self.error('Contest does not exist') + session_pass = request.session.get(CONTEST_PASSWORD_SESSION_KEY, {}).get( + contest.id + ) + return self.success( + {'access': check_contest_password(session_pass, contest.password)} + ) class ContestRankAPI(APIView): def get_rank(self): if self.contest.rule_type == ContestRuleType.ACM: - return ACMContestRank.objects.filter(contest=self.contest, - user__admin_type=AdminType.REGULAR_USER, - user__is_disabled=False).\ - select_related("user").order_by("-accepted_number", "total_time") + return ( + ACMContestRank.objects.filter( + contest=self.contest, + user__admin_type=AdminType.REGULAR_USER, + user__is_disabled=False, + ) + .select_related('user') + .order_by('-accepted_number', 'total_time') + ) else: - return OIContestRank.objects.filter(contest=self.contest, - user__admin_type=AdminType.REGULAR_USER, - user__is_disabled=False). \ - select_related("user").order_by("-total_score") + return ( + OIContestRank.objects.filter( + contest=self.contest, + user__admin_type=AdminType.REGULAR_USER, + user__is_disabled=False, + ) + .select_related('user') + .order_by('-total_score') + ) def column_string(self, n): - string = "" + string = '' while n > 0: n, remainder = divmod(n - 1, 26) string = chr(65 + remainder) + string return string - @check_contest_permission(check_type="ranks") + @check_contest_permission(check_type='ranks') def get(self, request): - download_csv = request.GET.get("download_csv") - force_refresh = request.GET.get("force_refresh") - is_contest_admin = request.user.is_authenticated and request.user.is_contest_admin(self.contest) + download_csv = request.GET.get('download_csv') + force_refresh = request.GET.get('force_refresh') + is_contest_admin = ( + request.user.is_authenticated + and request.user.is_contest_admin(self.contest) + ) if self.contest.rule_type == ContestRuleType.OI: serializer = OIContestRankSerializer else: serializer = ACMContestRankSerializer - if force_refresh == "1" and is_contest_admin: + if force_refresh == '1' and is_contest_admin: qs = self.get_rank() else: - cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}" + cache_key = f'{CacheKey.contest_rank_cache}:{self.contest.id}' qs = cache.get(cache_key) if not qs: qs = self.get_rank() @@ -143,50 +170,70 @@ def get(self, request): if download_csv: data = serializer(qs, many=True, is_contest_admin=is_contest_admin).data - contest_problems = Problem.objects.filter(contest=self.contest, visible=True).order_by("_id") + contest_problems = Problem.objects.filter( + contest=self.contest, visible=True + ).order_by('_id') problem_ids = [item.id for item in contest_problems] f = io.BytesIO() workbook = xlsxwriter.Workbook(f) worksheet = workbook.add_worksheet() - worksheet.write("A1", "User ID") - worksheet.write("B1", "Username") - worksheet.write("C1", "Real Name") + worksheet.write('A1', 'User ID') + worksheet.write('B1', 'Username') + worksheet.write('C1', 'Real Name') if self.contest.rule_type == ContestRuleType.OI: - worksheet.write("D1", "Total Score") + worksheet.write('D1', 'Total Score') for item in range(contest_problems.count()): - worksheet.write(self.column_string(5 + item) + "1", f"{contest_problems[item].title}") + worksheet.write( + self.column_string(5 + item) + '1', + f'{contest_problems[item].title}', + ) for index, item in enumerate(data): - worksheet.write_string(index + 1, 0, str(item["user"]["id"])) - worksheet.write_string(index + 1, 1, item["user"]["username"]) - worksheet.write_string(index + 1, 2, item["user"]["real_name"] or "") - worksheet.write_string(index + 1, 3, str(item["total_score"])) - for k, v in item["submission_info"].items(): - worksheet.write_string(index + 1, 4 + problem_ids.index(int(k)), str(v)) + worksheet.write_string(index + 1, 0, str(item['user']['id'])) + worksheet.write_string(index + 1, 1, item['user']['username']) + worksheet.write_string( + index + 1, 2, item['user']['real_name'] or '' + ) + worksheet.write_string(index + 1, 3, str(item['total_score'])) + for k, v in item['submission_info'].items(): + worksheet.write_string( + index + 1, 4 + problem_ids.index(int(k)), str(v) + ) else: - worksheet.write("D1", "AC") - worksheet.write("E1", "Total Submission") - worksheet.write("F1", "Total Time") + worksheet.write('D1', 'AC') + worksheet.write('E1', 'Total Submission') + worksheet.write('F1', 'Total Time') for item in range(contest_problems.count()): - worksheet.write(self.column_string(7 + item) + "1", f"{contest_problems[item].title}") + worksheet.write( + self.column_string(7 + item) + '1', + f'{contest_problems[item].title}', + ) for index, item in enumerate(data): - worksheet.write_string(index + 1, 0, str(item["user"]["id"])) - worksheet.write_string(index + 1, 1, item["user"]["username"]) - worksheet.write_string(index + 1, 2, item["user"]["real_name"] or "") - worksheet.write_string(index + 1, 3, str(item["accepted_number"])) - worksheet.write_string(index + 1, 4, str(item["submission_number"])) - worksheet.write_string(index + 1, 5, str(item["total_time"])) - for k, v in item["submission_info"].items(): - worksheet.write_string(index + 1, 6 + problem_ids.index(int(k)), str(v["is_ac"])) + worksheet.write_string(index + 1, 0, str(item['user']['id'])) + worksheet.write_string(index + 1, 1, item['user']['username']) + worksheet.write_string( + index + 1, 2, item['user']['real_name'] or '' + ) + worksheet.write_string(index + 1, 3, str(item['accepted_number'])) + worksheet.write_string(index + 1, 4, str(item['submission_number'])) + worksheet.write_string(index + 1, 5, str(item['total_time'])) + for k, v in item['submission_info'].items(): + worksheet.write_string( + index + 1, 6 + problem_ids.index(int(k)), str(v['is_ac']) + ) workbook.close() f.seek(0) response = HttpResponse(f.read()) - response["Content-Disposition"] = f"attachment; filename=content-{self.contest.id}-rank.xlsx" - response["Content-Type"] = "application/xlsx" + response[ + 'Content-Disposition' + ] = f'attachment; filename=content-{self.contest.id}-rank.xlsx' + response['Content-Type'] = 'application/xlsx' return response page_qs = self.paginate_data(request, qs) - page_qs["results"] = serializer(page_qs["results"], many=True, is_contest_admin=is_contest_admin).data + page_qs['results'] = serializer( + page_qs['results'], many=True, is_contest_admin=is_contest_admin + ).data return self.success(page_qs) diff --git a/fps/parser.py b/fps/parser.py index 63d3559e5..55ffa94d4 100644 --- a/fps/parser.py +++ b/fps/parser.py @@ -16,9 +16,11 @@ def __init__(self, fps_path=None, string_data=None): elif string_data: self._ertree = ET.fromstring(string_data).getroot() else: - raise ValueError("You must tell me the file path or directly give me the data for the file") - version = self._etree.attrib.get("version", "No Version") - if version not in ["1.1", "1.2"]: + raise ValueError( + 'You must tell me the file path or directly give me the data for the file' + ) + version = self._etree.attrib.get('version', 'No Version') + if version not in ['1.1', '1.2']: raise ValueError("Unsupported version '" + version + "'") @property @@ -28,79 +30,89 @@ def etree(self): def parse(self): ret = [] for node in self._etree: - if node.tag == "item": + if node.tag == 'item': ret.append(self._parse_one_problem(node)) return ret def _parse_one_problem(self, node): sample_start = True test_case_start = True - problem = {"title": "No Title", "description": "No Description", - "input": "No Input Description", - "output": "No Output Description", - "memory_limit": {"unit": None, "value": None}, - "time_limit": {"unit": None, "value": None}, - "samples": [], "images": [], "append": [], - "template": [], "prepend": [], "test_cases": [], - "hint": None, "source": None, "spj": None, "solution": []} + problem = { + 'title': 'No Title', + 'description': 'No Description', + 'input': 'No Input Description', + 'output': 'No Output Description', + 'memory_limit': {'unit': None, 'value': None}, + 'time_limit': {'unit': None, 'value': None}, + 'samples': [], + 'images': [], + 'append': [], + 'template': [], + 'prepend': [], + 'test_cases': [], + 'hint': None, + 'source': None, + 'spj': None, + 'solution': [], + } for item in node: tag = item.tag - if tag in ["title", "description", "input", "output", "hint", "source"]: + if tag in ['title', 'description', 'input', 'output', 'hint', 'source']: problem[item.tag] = item.text - elif tag == "time_limit": - unit = item.attrib.get("unit", "s") - if unit not in ["s", "ms"]: - raise ValueError("Invalid time limit unit") - problem["time_limit"]["unit"] = item.attrib.get("unit", "s") + elif tag == 'time_limit': + unit = item.attrib.get('unit', 's') + if unit not in ['s', 'ms']: + raise ValueError('Invalid time limit unit') + problem['time_limit']['unit'] = item.attrib.get('unit', 's') value = int(item.text) if value <= 0: - raise ValueError("Invalid time limit value") - problem["time_limit"]["value"] = value - elif tag == "memory_limit": - unit = item.attrib.get("unit", "MB") - if unit not in ["MB", "KB", "mb", "kb"]: - raise ValueError("Invalid memory limit unit") - problem["memory_limit"]["unit"] = unit.upper() + raise ValueError('Invalid time limit value') + problem['time_limit']['value'] = value + elif tag == 'memory_limit': + unit = item.attrib.get('unit', 'MB') + if unit not in ['MB', 'KB', 'mb', 'kb']: + raise ValueError('Invalid memory limit unit') + problem['memory_limit']['unit'] = unit.upper() value = int(item.text) if value <= 0: - raise ValueError("Invalid memory limit value") - problem["memory_limit"]["value"] = value - elif tag in ["template", "append", "prepend", "solution"]: - lang = item.attrib.get("language") + raise ValueError('Invalid memory limit value') + problem['memory_limit']['value'] = value + elif tag in ['template', 'append', 'prepend', 'solution']: + lang = item.attrib.get('language') if not lang: - raise ValueError("Invalid " + tag + ", language name is missed") - problem[tag].append({"language": lang, "code": item.text}) - elif tag == "spj": - lang = item.attrib.get("language") + raise ValueError('Invalid ' + tag + ', language name is missed') + problem[tag].append({'language': lang, 'code': item.text}) + elif tag == 'spj': + lang = item.attrib.get('language') if not lang: - raise ValueError("Invalid spj, language name if missed") - problem["spj"] = {"language": lang, "code": item.text} - elif tag == "img": - problem["images"].append({"src": None, "blob": None}) + raise ValueError('Invalid spj, language name if missed') + problem['spj'] = {'language': lang, 'code': item.text} + elif tag == 'img': + problem['images'].append({'src': None, 'blob': None}) for child in item: - if child.tag == "src": - problem["images"][-1]["src"] = child.text - elif child.tag == "base64": - problem["images"][-1]["blob"] = base64.b64decode(child.text) - elif tag == "sample_input": + if child.tag == 'src': + problem['images'][-1]['src'] = child.text + elif child.tag == 'base64': + problem['images'][-1]['blob'] = base64.b64decode(child.text) + elif tag == 'sample_input': if not sample_start: raise ValueError("Invalid xml, error 'sample_input' tag order") - problem["samples"].append({"input": item.text, "output": None}) + problem['samples'].append({'input': item.text, 'output': None}) sample_start = False - elif tag == "sample_output": + elif tag == 'sample_output': if sample_start: raise ValueError("Invalid xml, error 'sample_output' tag order") - problem["samples"][-1]["output"] = item.text + problem['samples'][-1]['output'] = item.text sample_start = True - elif tag == "test_input": + elif tag == 'test_input': if not test_case_start: raise ValueError("Invalid xml, error 'test_input' tag order") - problem["test_cases"].append({"input": item.text, "output": None}) + problem['test_cases'].append({'input': item.text, 'output': None}) test_case_start = False - elif tag == "test_output": + elif tag == 'test_output': if test_case_start: raise ValueError("Invalid xml, error 'test_output' tag order") - problem["test_cases"][-1]["output"] = item.text + problem['test_cases'][-1]['output'] = item.text test_case_start = True return problem @@ -109,14 +121,18 @@ def _parse_one_problem(self, node): class FPSHelper(object): def save_image(self, problem, base_dir, base_url): _problem = copy.deepcopy(problem) - for img in _problem["images"]: - name = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(12)) - ext = os.path.splitext(img["src"])[1] + for img in _problem['images']: + name = ''.join( + random.choice(string.ascii_lowercase + string.digits) for _ in range(12) + ) + ext = os.path.splitext(img['src'])[1] file_name = name + ext - with open(os.path.join(base_dir, file_name), "wb") as f: - f.write(img["blob"]) - for item in ["description", "input", "output"]: - _problem[item] = _problem[item].replace(img["src"], os.path.join(base_url, file_name)) + with open(os.path.join(base_dir, file_name), 'wb') as f: + f.write(img['blob']) + for item in ['description', 'input', 'output']: + _problem[item] = _problem[item].replace( + img['src'], os.path.join(base_url, file_name) + ) return _problem # { @@ -132,49 +148,56 @@ def save_image(self, problem, base_dir, base_url): # } # } def save_test_case(self, problem, base_dir): - spj = problem.get("spj", {}) + spj = problem.get('spj', {}) test_cases = {} - for index, item in enumerate(problem["test_cases"]): - input_content = item.get("input") - output_content = item.get("output") + for index, item in enumerate(problem['test_cases']): + input_content = item.get('input') + output_content = item.get('output') if input_content: - with open(os.path.join(base_dir, str(index + 1) + ".in"), "w", encoding="utf-8") as f: + with open( + os.path.join(base_dir, str(index + 1) + '.in'), + 'w', + encoding='utf-8', + ) as f: f.write(input_content) if output_content: - with open(os.path.join(base_dir, str(index + 1) + ".out"), "w", encoding="utf-8") as f: + with open( + os.path.join(base_dir, str(index + 1) + '.out'), + 'w', + encoding='utf-8', + ) as f: f.write(output_content) if spj: one_info = { - "input_size": len(input_content), - "input_name": f"{index + 1}.in" + 'input_size': len(input_content), + 'input_name': f'{index + 1}.in', } else: one_info = { - "input_size": len(input_content), - "input_name": f"{index + 1}.in", - "output_size": len(output_content), - "output_name": f"{index + 1}.out", - "stripped_output_md5": hashlib.md5(output_content.rstrip().encode("utf-8")).hexdigest() + 'input_size': len(input_content), + 'input_name': f'{index + 1}.in', + 'output_size': len(output_content), + 'output_name': f'{index + 1}.out', + 'stripped_output_md5': hashlib.md5( + output_content.rstrip().encode('utf-8') + ).hexdigest(), } test_cases[index] = one_info - info = { - "spj": True if spj else False, - "test_cases": test_cases - } - with open(os.path.join(base_dir, "info"), "w", encoding="utf-8") as f: + info = {'spj': True if spj else False, 'test_cases': test_cases} + with open(os.path.join(base_dir, 'info'), 'w', encoding='utf-8') as f: f.write(json.dumps(info, indent=4)) return info -if __name__ == "__main__": +if __name__ == '__main__': import pprint - parser = FPSParser("fps.xml") + parser = FPSParser('fps.xml') helper = FPSHelper() problems = parser.parse() for index, problem in enumerate(problems): - path = os.path.join("/tmp/", str(index + 1)) + path = os.path.join('/tmp/', str(index + 1)) os.mkdir(path) helper.save_test_case(problem, path) - pprint.pprint(helper.save_image(problem, "/tmp", "/static/img")) + pprint.pprint(helper.save_image(problem, '/tmp', '/static/img')) diff --git a/judge/dispatcher.py b/judge/dispatcher.py index bd099ac2e..da2d96866 100644 --- a/judge/dispatcher.py +++ b/judge/dispatcher.py @@ -25,9 +25,10 @@ def process_pending_task(): if cache.llen(CacheKey.waiting_queue): # 防止循环引入 from judge.tasks import judge_task + tmp_data = cache.rpop(CacheKey.waiting_queue) if tmp_data: - data = json.loads(tmp_data.decode("utf-8")) + data = json.loads(tmp_data.decode('utf-8')) judge_task.send(**data) @@ -37,29 +38,37 @@ def __init__(self): def __enter__(self) -> [JudgeServer, None]: with transaction.atomic(): - servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number") - servers = [s for s in servers if s.status == "normal"] + servers = ( + JudgeServer.objects.select_for_update() + .filter(is_disabled=False) + .order_by('task_number') + ) + servers = [s for s in servers if s.status == 'normal'] for server in servers: if server.task_number <= server.cpu_core * 2: - server.task_number = F("task_number") + 1 - server.save(update_fields=["task_number"]) + server.task_number = F('task_number') + 1 + server.save(update_fields=['task_number']) self.server = server return server return None def __exit__(self, exc_type, exc_val, exc_tb): if self.server: - JudgeServer.objects.filter(id=self.server.id).update(task_number=F("task_number") - 1) + JudgeServer.objects.filter(id=self.server.id).update( + task_number=F('task_number') - 1 + ) class DispatcherBase(object): def __init__(self): - self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() + self.token = hashlib.sha256( + SysOptions.judge_server_token.encode('utf-8') + ).hexdigest() def _request(self, url, data=None): - kwargs = {"headers": {"X-Judge-Server-Token": self.token}} + kwargs = {'headers': {'X-Judge-Server-Token': self.token}} if data: - kwargs["json"] = data + kwargs['json'] = data try: return requests.post(url, **kwargs).json() except Exception as e: @@ -69,23 +78,28 @@ def _request(self, url, data=None): class SPJCompiler(DispatcherBase): def __init__(self, spj_code, spj_version, spj_language): super().__init__() - spj_compile_config = list(filter(lambda config: spj_language == config["name"], SysOptions.spj_languages))[0]["spj"][ - "compile"] + spj_compile_config = list( + filter( + lambda config: spj_language == config['name'], SysOptions.spj_languages + ) + )[0]['spj']['compile'] self.data = { - "src": spj_code, - "spj_version": spj_version, - "spj_compile_config": spj_compile_config + 'src': spj_code, + 'spj_version': spj_version, + 'spj_compile_config': spj_compile_config, } def compile_spj(self): with ChooseJudgeServer() as server: if not server: - return "No available judge_server" - result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data) + return 'No available judge_server' + result = self._request( + urljoin(server.service_url, 'compile_spj'), data=self.data + ) if not result: - return "Failed to call judge server" - if result["err"]: - return result["data"] + return 'Failed to call judge server' + if result['err']: + return result['data'] class JudgeDispatcher(DispatcherBase): @@ -96,98 +110,129 @@ def __init__(self, submission_id, problem_id): self.last_result = self.submission.result if self.submission.info else None if self.contest_id: - self.problem = Problem.objects.select_related("contest").get(id=problem_id, contest_id=self.contest_id) + self.problem = Problem.objects.select_related('contest').get( + id=problem_id, contest_id=self.contest_id + ) self.contest = self.problem.contest else: self.problem = Problem.objects.get(id=problem_id) def _compute_statistic_info(self, resp_data): # 用时和内存占用保存为多个测试点中最长的那个 - self.submission.statistic_info["time_cost"] = max([x["cpu_time"] for x in resp_data]) - self.submission.statistic_info["memory_cost"] = max([x["memory"] for x in resp_data]) + self.submission.statistic_info['time_cost'] = max( + [x['cpu_time'] for x in resp_data] + ) + self.submission.statistic_info['memory_cost'] = max( + [x['memory'] for x in resp_data] + ) # sum up the score in OI mode if self.problem.rule_type == ProblemRuleType.OI: score = 0 try: for i in range(len(resp_data)): - if resp_data[i]["result"] == JudgeStatus.ACCEPTED: - resp_data[i]["score"] = self.problem.test_case_score[i]["score"] - score += resp_data[i]["score"] + if resp_data[i]['result'] == JudgeStatus.ACCEPTED: + resp_data[i]['score'] = self.problem.test_case_score[i]['score'] + score += resp_data[i]['score'] else: - resp_data[i]["score"] = 0 + resp_data[i]['score'] = 0 except IndexError: - logger.error(f"Index Error raised when summing up the score in problem {self.problem.id}") - self.submission.statistic_info["score"] = 0 + logger.error( + f'Index Error raised when summing up the score in problem {self.problem.id}' + ) + self.submission.statistic_info['score'] = 0 return - self.submission.statistic_info["score"] = score + self.submission.statistic_info['score'] = score def judge(self): language = self.submission.language - sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0] + sub_config = list( + filter(lambda item: language == item['name'], SysOptions.languages) + )[0] spj_config = {} if self.problem.spj_code: for lang in SysOptions.spj_languages: - if lang["name"] == self.problem.spj_language: - spj_config = lang["spj"] + if lang['name'] == self.problem.spj_language: + spj_config = lang['spj'] break if language in self.problem.template: template = parse_problem_template(self.problem.template[language]) - code = f"{template['prepend']}\n{self.submission.code}\n{template['append']}" + code = ( + f"{template['prepend']}\n{self.submission.code}\n{template['append']}" + ) else: code = self.submission.code data = { - "language_config": sub_config["config"], - "src": code, - "max_cpu_time": self.problem.time_limit, - "max_memory": 1024 * 1024 * self.problem.memory_limit, - "test_case_id": self.problem.test_case_id, - "output": False, - "spj_version": self.problem.spj_version, - "spj_config": spj_config.get("config"), - "spj_compile_config": spj_config.get("compile"), - "spj_src": self.problem.spj_code, - "io_mode": self.problem.io_mode + 'language_config': sub_config['config'], + 'src': code, + 'max_cpu_time': self.problem.time_limit, + 'max_memory': 1024 * 1024 * self.problem.memory_limit, + 'test_case_id': self.problem.test_case_id, + 'output': False, + 'spj_version': self.problem.spj_version, + 'spj_config': spj_config.get('config'), + 'spj_compile_config': spj_config.get('compile'), + 'spj_src': self.problem.spj_code, + 'io_mode': self.problem.io_mode, } with ChooseJudgeServer() as server: if not server: - data = {"submission_id": self.submission.id, "problem_id": self.problem.id} + data = { + 'submission_id': self.submission.id, + 'problem_id': self.problem.id, + } cache.lpush(CacheKey.waiting_queue, json.dumps(data)) return - Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING) - resp = self._request(urljoin(server.service_url, "/judge"), data=data) + Submission.objects.filter(id=self.submission.id).update( + result=JudgeStatus.JUDGING + ) + resp = self._request(urljoin(server.service_url, '/judge'), data=data) if not resp: - Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.SYSTEM_ERROR) + Submission.objects.filter(id=self.submission.id).update( + result=JudgeStatus.SYSTEM_ERROR + ) return - if resp["err"]: + if resp['err']: self.submission.result = JudgeStatus.COMPILE_ERROR - self.submission.statistic_info["err_info"] = resp["data"] - self.submission.statistic_info["score"] = 0 + self.submission.statistic_info['err_info'] = resp['data'] + self.submission.statistic_info['score'] = 0 else: - resp["data"].sort(key=lambda x: int(x["test_case"])) + resp['data'].sort(key=lambda x: int(x['test_case'])) self.submission.info = resp - self._compute_statistic_info(resp["data"]) - error_test_case = list(filter(lambda case: case["result"] != 0, resp["data"])) + self._compute_statistic_info(resp['data']) + error_test_case = list( + filter(lambda case: case['result'] != 0, resp['data']) + ) # ACM模式下,多个测试点全部正确则AC,否则取第一个错误的测试点的状态 # OI模式下, 若多个测试点全部正确则AC, 若全部错误则取第一个错误测试点状态,否则为部分正确 if not error_test_case: self.submission.result = JudgeStatus.ACCEPTED - elif self.problem.rule_type == ProblemRuleType.ACM or len(error_test_case) == len(resp["data"]): - self.submission.result = error_test_case[0]["result"] + elif self.problem.rule_type == ProblemRuleType.ACM or len( + error_test_case + ) == len(resp['data']): + self.submission.result = error_test_case[0]['result'] else: self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission.save() if self.contest_id: - if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \ - User.objects.get(id=self.submission.user_id).is_contest_admin(self.contest): + if ( + self.contest.status != ContestStatus.CONTEST_UNDERWAY + or User.objects.get(id=self.submission.user_id).is_contest_admin( + self.contest + ) + ): logger.info( - "Contest debug mode, id: " + str(self.contest_id) + ", submission id: " + self.submission.id) + 'Contest debug mode, id: ' + + str(self.contest_id) + + ', submission id: ' + + self.submission.id + ) return with transaction.atomic(): self.update_contest_problem_status() @@ -206,88 +251,124 @@ def update_problem_status_rejudge(self): problem_id = str(self.problem.id) with transaction.atomic(): # update problem status - problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id) - if self.last_result != JudgeStatus.ACCEPTED and self.submission.result == JudgeStatus.ACCEPTED: + problem = Problem.objects.select_for_update().get( + contest_id=self.contest_id, id=self.problem.id + ) + if ( + self.last_result != JudgeStatus.ACCEPTED + and self.submission.result == JudgeStatus.ACCEPTED + ): problem.accepted_number += 1 problem_info = problem.statistic_info problem_info[self.last_result] = problem_info.get(self.last_result, 1) - 1 problem_info[result] = problem_info.get(result, 0) + 1 - problem.save(update_fields=["accepted_number", "statistic_info"]) + problem.save(update_fields=['accepted_number', 'statistic_info']) - profile = User.objects.select_for_update().get(id=self.submission.user_id).userprofile + profile = ( + User.objects.select_for_update() + .get(id=self.submission.user_id) + .userprofile + ) if problem.rule_type == ProblemRuleType.ACM: - acm_problems_status = profile.acm_problems_status.get("problems", {}) - if acm_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED: - acm_problems_status[problem_id]["status"] = self.submission.result + acm_problems_status = profile.acm_problems_status.get('problems', {}) + if acm_problems_status[problem_id]['status'] != JudgeStatus.ACCEPTED: + acm_problems_status[problem_id]['status'] = self.submission.result if self.submission.result == JudgeStatus.ACCEPTED: profile.accepted_number += 1 - profile.acm_problems_status["problems"] = acm_problems_status - profile.save(update_fields=["accepted_number", "acm_problems_status"]) + profile.acm_problems_status['problems'] = acm_problems_status + profile.save(update_fields=['accepted_number', 'acm_problems_status']) else: - oi_problems_status = profile.oi_problems_status.get("problems", {}) - score = self.submission.statistic_info["score"] - if oi_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED: + oi_problems_status = profile.oi_problems_status.get('problems', {}) + score = self.submission.statistic_info['score'] + if oi_problems_status[problem_id]['status'] != JudgeStatus.ACCEPTED: # minus last time score, add this tim score - profile.add_score(this_time_score=score, - last_time_score=oi_problems_status[problem_id]["score"]) - oi_problems_status[problem_id]["score"] = score - oi_problems_status[problem_id]["status"] = self.submission.result + profile.add_score( + this_time_score=score, + last_time_score=oi_problems_status[problem_id]['score'], + ) + oi_problems_status[problem_id]['score'] = score + oi_problems_status[problem_id]['status'] = self.submission.result if self.submission.result == JudgeStatus.ACCEPTED: profile.accepted_number += 1 - profile.oi_problems_status["problems"] = oi_problems_status - profile.save(update_fields=["accepted_number", "oi_problems_status"]) + profile.oi_problems_status['problems'] = oi_problems_status + profile.save(update_fields=['accepted_number', 'oi_problems_status']) def update_problem_status(self): result = str(self.submission.result) problem_id = str(self.problem.id) with transaction.atomic(): # update problem status - problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id) + problem = Problem.objects.select_for_update().get( + contest_id=self.contest_id, id=self.problem.id + ) problem.submission_number += 1 if self.submission.result == JudgeStatus.ACCEPTED: problem.accepted_number += 1 problem_info = problem.statistic_info problem_info[result] = problem_info.get(result, 0) + 1 - problem.save(update_fields=["accepted_number", "submission_number", "statistic_info"]) + problem.save( + update_fields=['accepted_number', 'submission_number', 'statistic_info'] + ) # update_userprofile user = User.objects.select_for_update().get(id=self.submission.user_id) user_profile = user.userprofile user_profile.submission_number += 1 if problem.rule_type == ProblemRuleType.ACM: - acm_problems_status = user_profile.acm_problems_status.get("problems", {}) + acm_problems_status = user_profile.acm_problems_status.get( + 'problems', {} + ) if problem_id not in acm_problems_status: - acm_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id} + acm_problems_status[problem_id] = { + 'status': self.submission.result, + '_id': self.problem._id, + } if self.submission.result == JudgeStatus.ACCEPTED: user_profile.accepted_number += 1 - elif acm_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED: - acm_problems_status[problem_id]["status"] = self.submission.result + elif acm_problems_status[problem_id]['status'] != JudgeStatus.ACCEPTED: + acm_problems_status[problem_id]['status'] = self.submission.result if self.submission.result == JudgeStatus.ACCEPTED: user_profile.accepted_number += 1 - user_profile.acm_problems_status["problems"] = acm_problems_status - user_profile.save(update_fields=["submission_number", "accepted_number", "acm_problems_status"]) + user_profile.acm_problems_status['problems'] = acm_problems_status + user_profile.save( + update_fields=[ + 'submission_number', + 'accepted_number', + 'acm_problems_status', + ] + ) else: - oi_problems_status = user_profile.oi_problems_status.get("problems", {}) - score = self.submission.statistic_info["score"] + oi_problems_status = user_profile.oi_problems_status.get('problems', {}) + score = self.submission.statistic_info['score'] if problem_id not in oi_problems_status: user_profile.add_score(score) - oi_problems_status[problem_id] = {"status": self.submission.result, - "_id": self.problem._id, - "score": score} + oi_problems_status[problem_id] = { + 'status': self.submission.result, + '_id': self.problem._id, + 'score': score, + } if self.submission.result == JudgeStatus.ACCEPTED: user_profile.accepted_number += 1 - elif oi_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED: + elif oi_problems_status[problem_id]['status'] != JudgeStatus.ACCEPTED: # minus last time score, add this time score - user_profile.add_score(this_time_score=score, - last_time_score=oi_problems_status[problem_id]["score"]) - oi_problems_status[problem_id]["score"] = score - oi_problems_status[problem_id]["status"] = self.submission.result + user_profile.add_score( + this_time_score=score, + last_time_score=oi_problems_status[problem_id]['score'], + ) + oi_problems_status[problem_id]['score'] = score + oi_problems_status[problem_id]['status'] = self.submission.result if self.submission.result == JudgeStatus.ACCEPTED: user_profile.accepted_number += 1 - user_profile.oi_problems_status["problems"] = oi_problems_status - user_profile.save(update_fields=["submission_number", "accepted_number", "oi_problems_status"]) + user_profile.oi_problems_status['problems'] = oi_problems_status + user_profile.save( + update_fields=[ + 'submission_number', + 'accepted_number', + 'oi_problems_status', + ] + ) def update_contest_problem_status(self): with transaction.atomic(): @@ -295,45 +376,71 @@ def update_contest_problem_status(self): user_profile = user.userprofile problem_id = str(self.problem.id) if self.contest.rule_type == ContestRuleType.ACM: - contest_problems_status = user_profile.acm_problems_status.get("contest_problems", {}) + contest_problems_status = user_profile.acm_problems_status.get( + 'contest_problems', {} + ) if problem_id not in contest_problems_status: - contest_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id} - elif contest_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED: - contest_problems_status[problem_id]["status"] = self.submission.result + contest_problems_status[problem_id] = { + 'status': self.submission.result, + '_id': self.problem._id, + } + elif ( + contest_problems_status[problem_id]['status'] + != JudgeStatus.ACCEPTED + ): + contest_problems_status[problem_id][ + 'status' + ] = self.submission.result else: # 如果已AC, 直接跳过 不计入任何计数器 return - user_profile.acm_problems_status["contest_problems"] = contest_problems_status - user_profile.save(update_fields=["acm_problems_status"]) + user_profile.acm_problems_status[ + 'contest_problems' + ] = contest_problems_status + user_profile.save(update_fields=['acm_problems_status']) elif self.contest.rule_type == ContestRuleType.OI: - contest_problems_status = user_profile.oi_problems_status.get("contest_problems", {}) - score = self.submission.statistic_info["score"] + contest_problems_status = user_profile.oi_problems_status.get( + 'contest_problems', {} + ) + score = self.submission.statistic_info['score'] if problem_id not in contest_problems_status: - contest_problems_status[problem_id] = {"status": self.submission.result, - "_id": self.problem._id, - "score": score} + contest_problems_status[problem_id] = { + 'status': self.submission.result, + '_id': self.problem._id, + 'score': score, + } else: - contest_problems_status[problem_id]["score"] = score - contest_problems_status[problem_id]["status"] = self.submission.result - user_profile.oi_problems_status["contest_problems"] = contest_problems_status - user_profile.save(update_fields=["oi_problems_status"]) - - problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id) + contest_problems_status[problem_id]['score'] = score + contest_problems_status[problem_id][ + 'status' + ] = self.submission.result + user_profile.oi_problems_status[ + 'contest_problems' + ] = contest_problems_status + user_profile.save(update_fields=['oi_problems_status']) + + problem = Problem.objects.select_for_update().get( + contest_id=self.contest_id, id=self.problem.id + ) result = str(self.submission.result) problem_info = problem.statistic_info problem_info[result] = problem_info.get(result, 0) + 1 problem.submission_number += 1 if self.submission.result == JudgeStatus.ACCEPTED: problem.accepted_number += 1 - problem.save(update_fields=["submission_number", "accepted_number", "statistic_info"]) + problem.save( + update_fields=['submission_number', 'accepted_number', 'statistic_info'] + ) def update_contest_rank(self): if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank: - cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}") + cache.delete(f'{CacheKey.contest_rank_cache}:{self.contest.id}') def get_rank(model): - return model.objects.select_for_update().get(user_id=self.submission.user_id, contest=self.contest) + return model.objects.select_for_update().get( + user_id=self.submission.user_id, contest=self.contest + ) if self.contest.rule_type == ContestRuleType.ACM: model = ACMContestRank @@ -346,7 +453,9 @@ def get_rank(model): rank = get_rank(model) except model.DoesNotExist: try: - model.objects.create(user_id=self.submission.user_id, contest=self.contest) + model.objects.create( + user_id=self.submission.user_id, contest=self.contest + ) rank = get_rank(model) except IntegrityError: rank = get_rank(model) @@ -355,45 +464,56 @@ def get_rank(model): def _update_acm_contest_rank(self, rank): info = rank.submission_info.get(str(self.submission.problem_id)) # 因前面更改过,这里需要重新获取 - problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id) + problem = Problem.objects.select_for_update().get( + contest_id=self.contest_id, id=self.problem.id + ) # 此题提交过 if info: - if info["is_ac"]: + if info['is_ac']: return rank.submission_number += 1 if self.submission.result == JudgeStatus.ACCEPTED: rank.accepted_number += 1 - info["is_ac"] = True - info["ac_time"] = (self.submission.create_time - self.contest.start_time).total_seconds() - rank.total_time += info["ac_time"] + info["error_number"] * 20 * 60 + info['is_ac'] = True + info['ac_time'] = ( + self.submission.create_time - self.contest.start_time + ).total_seconds() + rank.total_time += info['ac_time'] + info['error_number'] * 20 * 60 if problem.accepted_number == 1: - info["is_first_ac"] = True + info['is_first_ac'] = True elif self.submission.result != JudgeStatus.COMPILE_ERROR: - info["error_number"] += 1 + info['error_number'] += 1 # 第一次提交 else: rank.submission_number += 1 - info = {"is_ac": False, "ac_time": 0, "error_number": 0, "is_first_ac": False} + info = { + 'is_ac': False, + 'ac_time': 0, + 'error_number': 0, + 'is_first_ac': False, + } if self.submission.result == JudgeStatus.ACCEPTED: rank.accepted_number += 1 - info["is_ac"] = True - info["ac_time"] = (self.submission.create_time - self.contest.start_time).total_seconds() - rank.total_time += info["ac_time"] + info['is_ac'] = True + info['ac_time'] = ( + self.submission.create_time - self.contest.start_time + ).total_seconds() + rank.total_time += info['ac_time'] if problem.accepted_number == 1: - info["is_first_ac"] = True + info['is_first_ac'] = True elif self.submission.result != JudgeStatus.COMPILE_ERROR: - info["error_number"] = 1 + info['error_number'] = 1 rank.submission_info[str(self.submission.problem_id)] = info rank.save() def _update_oi_contest_rank(self, rank): problem_id = str(self.submission.problem_id) - current_score = self.submission.statistic_info["score"] + current_score = self.submission.statistic_info['score'] last_score = rank.submission_info.get(problem_id) if last_score: rank.total_score = rank.total_score - last_score + current_score diff --git a/judge/languages.py b/judge/languages.py index 7e0c69714..9580edfce 100644 --- a/judge/languages.py +++ b/judge/languages.py @@ -1,10 +1,10 @@ from problem.models import ProblemIOMode -default_env = ["LANG=en_US.UTF-8", "LANGUAGE=en_US:en", "LC_ALL=en_US.UTF-8"] +default_env = ['LANG=en_US.UTF-8', 'LANGUAGE=en_US:en', 'LC_ALL=en_US.UTF-8'] _c_lang_config = { - "template": """//PREPEND BEGIN + 'template': """//PREPEND BEGIN #include //PREPEND END @@ -20,38 +20,41 @@ return 0; } //APPEND END""", - "compile": { - "src_name": "main.c", - "exe_name": "main", - "max_cpu_time": 3000, - "max_real_time": 10000, - "max_memory": 256 * 1024 * 1024, - "compile_command": "/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c17 {src_path} -lm -o {exe_path}", + 'compile': { + 'src_name': 'main.c', + 'exe_name': 'main', + 'max_cpu_time': 3000, + 'max_real_time': 10000, + 'max_memory': 256 * 1024 * 1024, + 'compile_command': '/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c17 {src_path} -lm -o {exe_path}', + }, + 'run': { + 'command': '{exe_path}', + 'seccomp_rule': { + ProblemIOMode.standard: 'c_cpp', + ProblemIOMode.file: 'c_cpp_file_io', + }, + 'env': default_env, }, - "run": { - "command": "{exe_path}", - "seccomp_rule": {ProblemIOMode.standard: "c_cpp", ProblemIOMode.file: "c_cpp_file_io"}, - "env": default_env - } } _c_lang_spj_compile = { - "src_name": "spj-{spj_version}.c", - "exe_name": "spj-{spj_version}", - "max_cpu_time": 3000, - "max_real_time": 10000, - "max_memory": 1024 * 1024 * 1024, - "compile_command": "/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c17 {src_path} -lm -o {exe_path}" + 'src_name': 'spj-{spj_version}.c', + 'exe_name': 'spj-{spj_version}', + 'max_cpu_time': 3000, + 'max_real_time': 10000, + 'max_memory': 1024 * 1024 * 1024, + 'compile_command': '/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c17 {src_path} -lm -o {exe_path}', } _c_lang_spj_config = { - "exe_name": "spj-{spj_version}", - "command": "{exe_path} {in_file_path} {user_out_file_path}", - "seccomp_rule": "c_cpp" + 'exe_name': 'spj-{spj_version}', + 'command': '{exe_path} {in_file_path} {user_out_file_path}', + 'seccomp_rule': 'c_cpp', } _cpp_lang_config = { - "template": """//PREPEND BEGIN + 'template': """//PREPEND BEGIN #include //PREPEND END @@ -67,38 +70,41 @@ return 0; } //APPEND END""", - "compile": { - "src_name": "main.cpp", - "exe_name": "main", - "max_cpu_time": 10000, - "max_real_time": 20000, - "max_memory": 1024 * 1024 * 1024, - "compile_command": "/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++20 {src_path} -lm -o {exe_path}", + 'compile': { + 'src_name': 'main.cpp', + 'exe_name': 'main', + 'max_cpu_time': 10000, + 'max_real_time': 20000, + 'max_memory': 1024 * 1024 * 1024, + 'compile_command': '/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++20 {src_path} -lm -o {exe_path}', + }, + 'run': { + 'command': '{exe_path}', + 'seccomp_rule': { + ProblemIOMode.standard: 'c_cpp', + ProblemIOMode.file: 'c_cpp_file_io', + }, + 'env': default_env, }, - "run": { - "command": "{exe_path}", - "seccomp_rule": {ProblemIOMode.standard: "c_cpp", ProblemIOMode.file: "c_cpp_file_io"}, - "env": default_env - } } _cpp_lang_spj_compile = { - "src_name": "spj-{spj_version}.cpp", - "exe_name": "spj-{spj_version}", - "max_cpu_time": 10000, - "max_real_time": 20000, - "max_memory": 1024 * 1024 * 1024, - "compile_command": "/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++20 {src_path} -lm -o {exe_path}" + 'src_name': 'spj-{spj_version}.cpp', + 'exe_name': 'spj-{spj_version}', + 'max_cpu_time': 10000, + 'max_real_time': 20000, + 'max_memory': 1024 * 1024 * 1024, + 'compile_command': '/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++20 {src_path} -lm -o {exe_path}', } _cpp_lang_spj_config = { - "exe_name": "spj-{spj_version}", - "command": "{exe_path} {in_file_path} {user_out_file_path}", - "seccomp_rule": "c_cpp" + 'exe_name': 'spj-{spj_version}', + 'command': '{exe_path} {in_file_path} {user_out_file_path}', + 'seccomp_rule': 'c_cpp', } _java_lang_config = { - "template": """//PREPEND BEGIN + 'template': """//PREPEND BEGIN class Main { //PREPEND END @@ -114,24 +120,24 @@ class Main { } } //APPEND END""", - "compile": { - "src_name": "Main.java", - "exe_name": "Main", - "max_cpu_time": 5000, - "max_real_time": 10000, - "max_memory": -1, - "compile_command": "/usr/bin/javac {src_path} -d {exe_dir}" - }, - "run": { - "command": "/usr/bin/java -cp {exe_dir} -XX:MaxRAM={max_memory}k Main", - "seccomp_rule": None, - "env": default_env, - "memory_limit_check_only": 1 - } + 'compile': { + 'src_name': 'Main.java', + 'exe_name': 'Main', + 'max_cpu_time': 5000, + 'max_real_time': 10000, + 'max_memory': -1, + 'compile_command': '/usr/bin/javac {src_path} -d {exe_dir}', + }, + 'run': { + 'command': '/usr/bin/java -cp {exe_dir} -XX:MaxRAM={max_memory}k Main', + 'seccomp_rule': None, + 'env': default_env, + 'memory_limit_check_only': 1, + }, } _py3_lang_config = { - "template": """//PREPEND BEGIN + 'template': """//PREPEND BEGIN //PREPEND END //TEMPLATE BEGIN @@ -143,23 +149,23 @@ def add(a, b): //APPEND BEGIN print(add(1, 2)) //APPEND END""", - "compile": { - "src_name": "solution.py", - "exe_name": "solution.py", - "max_cpu_time": 3000, - "max_real_time": 10000, - "max_memory": 128 * 1024 * 1024, - "compile_command": "/usr/bin/python3 -m py_compile {src_path}", + 'compile': { + 'src_name': 'solution.py', + 'exe_name': 'solution.py', + 'max_cpu_time': 3000, + 'max_real_time': 10000, + 'max_memory': 128 * 1024 * 1024, + 'compile_command': '/usr/bin/python3 -m py_compile {src_path}', + }, + 'run': { + 'command': '/usr/bin/python3 -BS {exe_path}', + 'seccomp_rule': 'general', + 'env': default_env, }, - "run": { - "command": "/usr/bin/python3 -BS {exe_path}", - "seccomp_rule": "general", - "env": default_env - } } _go_lang_config = { - "template": """//PREPEND BEGIN + 'template': """//PREPEND BEGIN package main import "fmt" @@ -176,25 +182,25 @@ def add(a, b): fmt.Println(add(1, 2)) } //APPEND END""", - "compile": { - "src_name": "main.go", - "exe_name": "main", - "max_cpu_time": 3000, - "max_real_time": 5000, - "max_memory": 1024 * 1024 * 1024, - "compile_command": "/usr/bin/go build -o {exe_path} {src_path}", - "env": ["GOCACHE=/tmp", "GOPATH=/tmp", "GOMAXPROCS=1"] + default_env - }, - "run": { - "command": "{exe_path}", - "seccomp_rule": "golang", - "env": ["GOMAXPROCS=1"] + default_env, - "memory_limit_check_only": 1 - } + 'compile': { + 'src_name': 'main.go', + 'exe_name': 'main', + 'max_cpu_time': 3000, + 'max_real_time': 5000, + 'max_memory': 1024 * 1024 * 1024, + 'compile_command': '/usr/bin/go build -o {exe_path} {src_path}', + 'env': ['GOCACHE=/tmp', 'GOPATH=/tmp', 'GOMAXPROCS=1'] + default_env, + }, + 'run': { + 'command': '{exe_path}', + 'seccomp_rule': 'golang', + 'env': ['GOMAXPROCS=1'] + default_env, + 'memory_limit_check_only': 1, + }, } _node_lang_config = { - "template": """//PREPEND BEGIN + 'template': """//PREPEND BEGIN //PREPEND END //TEMPLATE BEGIN @@ -206,30 +212,60 @@ def add(a, b): //APPEND BEGIN console.log(add(1, 2)) //APPEND END""", - "compile": { - "src_name": "main.js", - "exe_name": "main.js", - "max_cpu_time": 3000, - "max_real_time": 5000, - "max_memory": 1024 * 1024 * 1024, - "compile_command": "/usr/bin/node --check {src_path}", - "env": default_env - }, - "run": { - "command": "/usr/bin/node {exe_path}", - "seccomp_rule": "node", - "env": default_env, - "memory_limit_check_only": 1 - } + 'compile': { + 'src_name': 'main.js', + 'exe_name': 'main.js', + 'max_cpu_time': 3000, + 'max_real_time': 5000, + 'max_memory': 1024 * 1024 * 1024, + 'compile_command': '/usr/bin/node --check {src_path}', + 'env': default_env, + }, + 'run': { + 'command': '/usr/bin/node {exe_path}', + 'seccomp_rule': 'node', + 'env': default_env, + 'memory_limit_check_only': 1, + }, } languages = [ - {"config": _c_lang_config, "name": "C", "description": "GCC 13", "content_type": "text/x-csrc", - "spj": {"compile": _c_lang_spj_compile, "config": _c_lang_spj_config}}, - {"config": _cpp_lang_config, "name": "C++", "description": "GCC 13", "content_type": "text/x-c++src", - "spj": {"compile": _cpp_lang_spj_compile, "config": _cpp_lang_spj_config}}, - {"config": _java_lang_config, "name": "Java", "description": "Temurin 21", "content_type": "text/x-java"}, - {"config": _py3_lang_config, "name": "Python3", "description": "Python 3.12", "content_type": "text/x-python"}, - {"config": _go_lang_config, "name": "Golang", "description": "Golang 1.22", "content_type": "text/x-go"}, - {"config": _node_lang_config, "name": "JavaScript", "description": "Node.js 20", "content_type": "text/javascript"}, + { + 'config': _c_lang_config, + 'name': 'C', + 'description': 'GCC 13', + 'content_type': 'text/x-csrc', + 'spj': {'compile': _c_lang_spj_compile, 'config': _c_lang_spj_config}, + }, + { + 'config': _cpp_lang_config, + 'name': 'C++', + 'description': 'GCC 13', + 'content_type': 'text/x-c++src', + 'spj': {'compile': _cpp_lang_spj_compile, 'config': _cpp_lang_spj_config}, + }, + { + 'config': _java_lang_config, + 'name': 'Java', + 'description': 'Temurin 21', + 'content_type': 'text/x-java', + }, + { + 'config': _py3_lang_config, + 'name': 'Python3', + 'description': 'Python 3.12', + 'content_type': 'text/x-python', + }, + { + 'config': _go_lang_config, + 'name': 'Golang', + 'description': 'Golang 1.22', + 'content_type': 'text/x-go', + }, + { + 'config': _node_lang_config, + 'name': 'JavaScript', + 'description': 'Node.js 20', + 'content_type': 'text/javascript', + }, ] diff --git a/manage.py b/manage.py index 16b46a54b..23ff0ca76 100755 --- a/manage.py +++ b/manage.py @@ -2,11 +2,12 @@ import os import sys -if __name__ == "__main__": - os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings") +if __name__ == '__main__': + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'oj.settings') from django.core.management import execute_from_command_line import django - sys.stdout.write("Django VERSION " + str(django.VERSION) + "\n") + + sys.stdout.write('Django VERSION ' + str(django.VERSION) + '\n') execute_from_command_line(sys.argv) diff --git a/oj/dev_settings.py b/oj/dev_settings.py index f79cc8763..7e52a374c 100644 --- a/oj/dev_settings.py +++ b/oj/dev_settings.py @@ -11,18 +11,18 @@ 'PORT': get_env('POSTGRES_PORT', '5435'), 'NAME': get_env('POSTGRES_DB', 'onlinejudge'), 'USER': get_env('POSTGRES_USER', 'onlinejudge'), - 'PASSWORD': get_env('POSTGRES_PASSWORD', 'onlinejudge') + 'PASSWORD': get_env('POSTGRES_PASSWORD', 'onlinejudge'), } } REDIS_CONF = { 'host': get_env('REDIS_HOST', '127.0.0.1'), - 'port': get_env('REDIS_PORT', '6380') + 'port': get_env('REDIS_PORT', '6380'), } DEBUG = True -ALLOWED_HOSTS = ["*"] +ALLOWED_HOSTS = ['*'] -DATA_DIR = f"{BASE_DIR}/data" +DATA_DIR = f'{BASE_DIR}/data' diff --git a/oj/production_settings.py b/oj/production_settings.py index e1bcde80f..e06810d45 100644 --- a/oj/production_settings.py +++ b/oj/production_settings.py @@ -3,21 +3,21 @@ DATABASES = { 'default': { 'ENGINE': 'django.db.backends.postgresql_psycopg2', - 'HOST': get_env("POSTGRES_HOST", "oj-postgres"), - 'PORT': get_env("POSTGRES_PORT", "5432"), - 'NAME': get_env("POSTGRES_DB"), - 'USER': get_env("POSTGRES_USER"), - 'PASSWORD': get_env("POSTGRES_PASSWORD") + 'HOST': get_env('POSTGRES_HOST', 'oj-postgres'), + 'PORT': get_env('POSTGRES_PORT', '5432'), + 'NAME': get_env('POSTGRES_DB'), + 'USER': get_env('POSTGRES_USER'), + 'PASSWORD': get_env('POSTGRES_PASSWORD'), } } REDIS_CONF = { - "host": get_env("REDIS_HOST", "oj-redis"), - "port": get_env("REDIS_PORT", "6379") + 'host': get_env('REDIS_HOST', 'oj-redis'), + 'port': get_env('REDIS_PORT', '6379'), } DEBUG = False ALLOWED_HOSTS = ['*'] -DATA_DIR = "/data" +DATA_DIR = '/data' diff --git a/oj/settings.py b/oj/settings.py index 7d3df1607..b09c1f2ed 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -10,17 +10,15 @@ https://docs.djangoproject.com/en/1.8/ref/settings/ """ import os -import raven -from copy import deepcopy from utils.shortcuts import get_env -production_env = get_env("OJ_ENV", "dev") == "production" +production_env = get_env('OJ_ENV', 'dev') == 'production' if production_env: from .production_settings import * else: from .dev_settings import * -with open(os.path.join(DATA_DIR, "config", "secret.key"), "r") as f: +with open(os.path.join(DATA_DIR, 'config', 'secret.key'), 'r') as f: SECRET_KEY = f.read() BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -125,72 +123,70 @@ AUTH_USER_MODEL = 'account.User' -TEST_CASE_DIR = os.path.join(DATA_DIR, "test_case") -LOG_PATH = os.path.join(DATA_DIR, "log") +TEST_CASE_DIR = os.path.join(DATA_DIR, 'test_case') +LOG_PATH = os.path.join(DATA_DIR, 'log') -AVATAR_URI_PREFIX = "/public/avatar" -AVATAR_UPLOAD_DIR = f"{DATA_DIR}{AVATAR_URI_PREFIX}" +AVATAR_URI_PREFIX = '/public/avatar' +AVATAR_UPLOAD_DIR = f'{DATA_DIR}{AVATAR_URI_PREFIX}' -UPLOAD_PREFIX = "/public/upload" -UPLOAD_DIR = f"{DATA_DIR}{UPLOAD_PREFIX}" +UPLOAD_PREFIX = '/public/upload' +UPLOAD_DIR = f'{DATA_DIR}{UPLOAD_PREFIX}' -STATICFILES_DIRS = [os.path.join(DATA_DIR, "public")] +STATICFILES_DIRS = [os.path.join(DATA_DIR, 'public')] LOGGING_HANDLERS = ['console', 'sentry'] if production_env else ['console'] LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'standard': { - 'format': '[%(asctime)s] - [%(levelname)s] - [%(name)s:%(lineno)d] - %(message)s', - 'datefmt': '%Y-%m-%d %H:%M:%S' - } - }, - 'handlers': { - 'console': { - 'level': 'DEBUG', - 'class': 'logging.StreamHandler', - 'formatter': 'standard' - }, - 'sentry': { - 'level': 'ERROR', - 'class': 'raven.contrib.django.raven_compat.handlers.SentryHandler', - 'formatter': 'standard' - } - }, - 'loggers': { - 'django.request': { - 'handlers': LOGGING_HANDLERS, - 'level': 'ERROR', - 'propagate': True, - }, - 'django.db.backends': { - 'handlers': LOGGING_HANDLERS, - 'level': 'ERROR', - 'propagate': True, - }, + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'standard': { + 'format': '[%(asctime)s] - [%(levelname)s] - [%(name)s:%(lineno)d] - %(message)s', + 'datefmt': '%Y-%m-%d %H:%M:%S', + } + }, + 'handlers': { + 'console': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + 'formatter': 'standard', + }, + 'sentry': { + 'level': 'ERROR', + 'class': 'raven.contrib.django.raven_compat.handlers.SentryHandler', + 'formatter': 'standard', + }, + }, + 'loggers': { + 'django.request': { + 'handlers': LOGGING_HANDLERS, + 'level': 'ERROR', + 'propagate': True, + }, + 'django.db.backends': { + 'handlers': LOGGING_HANDLERS, + 'level': 'ERROR', + 'propagate': True, + }, 'dramatiq': { 'handlers': LOGGING_HANDLERS, 'level': 'DEBUG', 'propagate': False, }, - '': { - 'handlers': LOGGING_HANDLERS, - 'level': 'WARNING', - 'propagate': True, - } - }, + '': { + 'handlers': LOGGING_HANDLERS, + 'level': 'WARNING', + 'propagate': True, + }, + }, } REST_FRAMEWORK = { 'TEST_REQUEST_DEFAULT_FORMAT': 'json', - 'DEFAULT_RENDERER_CLASSES': ( - 'rest_framework.renderers.JSONRenderer', - ) + 'DEFAULT_RENDERER_CLASSES': ('rest_framework.renderers.JSONRenderer',), } -REDIS_URL = "redis://%s:%s" % (REDIS_CONF["host"], REDIS_CONF["port"]) +REDIS_URL = 'redis://%s:%s' % (REDIS_CONF['host'], REDIS_CONF['port']) def redis_config(db): @@ -198,49 +194,45 @@ def make_key(key, key_prefix, version): return key return { - "BACKEND": "utils.cache.MyRedisCache", - "LOCATION": f"{REDIS_URL}/{db}", - "TIMEOUT": None, - "KEY_PREFIX": "", - "KEY_FUNCTION": make_key + 'BACKEND': 'utils.cache.MyRedisCache', + 'LOCATION': f'{REDIS_URL}/{db}', + 'TIMEOUT': None, + 'KEY_PREFIX': '', + 'KEY_FUNCTION': make_key, } -CACHES = { - "default": redis_config(db=1) -} +CACHES = {'default': redis_config(db=1)} -SESSION_ENGINE = "django.contrib.sessions.backends.cache" -SESSION_CACHE_ALIAS = "default" +SESSION_ENGINE = 'django.contrib.sessions.backends.cache' +SESSION_CACHE_ALIAS = 'default' DRAMATIQ_BROKER = { - "BROKER": "dramatiq.brokers.redis.RedisBroker", - "OPTIONS": { - "url": f"{REDIS_URL}/4", + 'BROKER': 'dramatiq.brokers.redis.RedisBroker', + 'OPTIONS': { + 'url': f'{REDIS_URL}/4', }, - "MIDDLEWARE": [ + 'MIDDLEWARE': [ # "dramatiq.middleware.Prometheus", - "dramatiq.middleware.AgeLimit", - "dramatiq.middleware.TimeLimit", - "dramatiq.middleware.Callbacks", - "dramatiq.middleware.Retries", + 'dramatiq.middleware.AgeLimit', + 'dramatiq.middleware.TimeLimit', + 'dramatiq.middleware.Callbacks', + 'dramatiq.middleware.Retries', # "django_dramatiq.middleware.AdminMiddleware", - "django_dramatiq.middleware.DbConnectionsMiddleware" - ] + 'django_dramatiq.middleware.DbConnectionsMiddleware', + ], } DRAMATIQ_RESULT_BACKEND = { - "BACKEND": "dramatiq.results.backends.redis.RedisBackend", - "BACKEND_OPTIONS": { - "url": f"{REDIS_URL}/4", + 'BACKEND': 'dramatiq.results.backends.redis.RedisBackend', + 'BACKEND_OPTIONS': { + 'url': f'{REDIS_URL}/4', }, - "MIDDLEWARE_OPTIONS": { - "result_ttl": None - } + 'MIDDLEWARE_OPTIONS': {'result_ttl': None}, } RAVEN_CONFIG = { 'dsn': 'https://b200023b8aed4d708fb593c5e0a6ad3d:1fddaba168f84fcf97e0d549faaeaff0@sentry.io/263057' } -DEFAULT_AUTO_FIELD='django.db.models.AutoField' +DEFAULT_AUTO_FIELD = 'django.db.models.AutoField' diff --git a/oj/urls.py b/oj/urls.py index c626656bd..62522161d 100644 --- a/oj/urls.py +++ b/oj/urls.py @@ -1,17 +1,17 @@ from django.conf.urls import include, url urlpatterns = [ - url(r"^api/", include("account.urls.oj")), - url(r"^api/admin/", include("account.urls.admin")), - url(r"^api/", include("announcement.urls.oj")), - url(r"^api/admin/", include("announcement.urls.admin")), - url(r"^api/", include("conf.urls.oj")), - url(r"^api/admin/", include("conf.urls.admin")), - url(r"^api/", include("problem.urls.oj")), - url(r"^api/admin/", include("problem.urls.admin")), - url(r"^api/", include("contest.urls.oj")), - url(r"^api/admin/", include("contest.urls.admin")), - url(r"^api/", include("submission.urls.oj")), - url(r"^api/admin/", include("submission.urls.admin")), - url(r"^api/admin/", include("utils.urls")), + url(r'^api/', include('account.urls.oj')), + url(r'^api/admin/', include('account.urls.admin')), + url(r'^api/', include('announcement.urls.oj')), + url(r'^api/admin/', include('announcement.urls.admin')), + url(r'^api/', include('conf.urls.oj')), + url(r'^api/admin/', include('conf.urls.admin')), + url(r'^api/', include('problem.urls.oj')), + url(r'^api/admin/', include('problem.urls.admin')), + url(r'^api/', include('contest.urls.oj')), + url(r'^api/admin/', include('contest.urls.admin')), + url(r'^api/', include('submission.urls.oj')), + url(r'^api/admin/', include('submission.urls.admin')), + url(r'^api/admin/', include('utils.urls')), ] diff --git a/oj/wsgi.py b/oj/wsgi.py index c08519117..49df340b7 100644 --- a/oj/wsgi.py +++ b/oj/wsgi.py @@ -11,6 +11,6 @@ from django.core.wsgi import get_wsgi_application -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings") +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'oj.settings') application = get_wsgi_application() diff --git a/options/migrations/0001_initial.py b/options/migrations/0001_initial.py index a109c9155..ee9035b7a 100644 --- a/options/migrations/0001_initial.py +++ b/options/migrations/0001_initial.py @@ -7,17 +7,23 @@ class Migration(migrations.Migration): - initial = True - dependencies = [ - ] + dependencies = [] operations = [ migrations.CreateModel( name='SysOptions', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('key', models.CharField(db_index=True, max_length=128, unique=True)), ('value', django.contrib.postgres.fields.jsonb.JSONField()), ], diff --git a/options/migrations/0002_auto_20180501_0436.py b/options/migrations/0002_auto_20180501_0436.py index 2e76be77a..8fb3bc568 100644 --- a/options/migrations/0002_auto_20180501_0436.py +++ b/options/migrations/0002_auto_20180501_0436.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('options', '0001_initial'), ] diff --git a/options/migrations/0003_migrate_languages_options.py b/options/migrations/0003_migrate_languages_options.py index 0f8f2812a..196d725c2 100644 --- a/options/migrations/0003_migrate_languages_options.py +++ b/options/migrations/0003_migrate_languages_options.py @@ -2,17 +2,18 @@ # Generated by Django 1.11.3 on 2018-05-01 04:36 from __future__ import unicode_literals -from django.db import migrations, models +from django.db import migrations class Migration(migrations.Migration): - dependencies = [ ('options', '0002_auto_20180501_0436'), ] operations = [ - migrations.RunSQL(""" + migrations.RunSQL( + """ DELETE FROM options_sysoptions WHERE key = 'languages'; - """) + """ + ) ] diff --git a/options/options.py b/options/options.py index f1843fd38..7c9111a8a 100644 --- a/options/options.py +++ b/options/options.py @@ -17,6 +17,7 @@ class my_property: 2. ttl is callable,条件缓存 3. 缓存 ttl 秒 """ + def __init__(self, func=None, fset=None, ttl=None): self.fset = fset self.local = threading.local() @@ -32,9 +33,9 @@ def _check_ttl(self, value): def _check_timeout(self, value): if not isinstance(value, int): - raise ValueError(f"Invalid timeout type: {type(value)}") + raise ValueError(f'Invalid timeout type: {type(value)}') if value < 0: - raise ValueError("Invalid timeout value, it must >= 0") + raise ValueError('Invalid timeout value, it must >= 0') def __get__(self, obj, cls): if obj is None: @@ -42,7 +43,7 @@ def __get__(self, obj, cls): now = time.time() if self.ttl: - if hasattr(self.local, "value"): + if hasattr(self.local, 'value'): value, expire_at = self.local.value if now < expire_at: return value @@ -70,14 +71,14 @@ def __set__(self, obj, value): if not self.fset: raise AttributeError("can't set attribute") self.fset(obj, value) - if hasattr(self.local, "value"): + if hasattr(self.local, 'value'): del self.local.value def setter(self, func): self.fset = func return self - def __call__(self, func, *args, **kwargs) -> "my_property": + def __call__(self, func, *args, **kwargs) -> 'my_property': if self.func is None: self.func = func functools.update_wrapper(self, func) @@ -88,41 +89,43 @@ def __call__(self, func, *args, **kwargs) -> "my_property": def default_token(): - token = os.environ.get("JUDGE_SERVER_TOKEN") + token = os.environ.get('JUDGE_SERVER_TOKEN') return token if token else rand_str() class OptionKeys: - website_base_url = "website_base_url" - website_name = "website_name" - website_name_shortcut = "website_name_shortcut" - website_footer = "website_footer" - allow_register = "allow_register" - submission_list_show_all = "submission_list_show_all" - smtp_config = "smtp_config" - judge_server_token = "judge_server_token" - throttling = "throttling" - languages = "languages" + website_base_url = 'website_base_url' + website_name = 'website_name' + website_name_shortcut = 'website_name_shortcut' + website_footer = 'website_footer' + allow_register = 'allow_register' + submission_list_show_all = 'submission_list_show_all' + smtp_config = 'smtp_config' + judge_server_token = 'judge_server_token' + throttling = 'throttling' + languages = 'languages' class OptionDefaultValue: - website_base_url = "http://127.0.0.1" - website_name = "Online Judge" - website_name_shortcut = "oj" - website_footer = "Online Judge Footer" + website_base_url = 'http://127.0.0.1' + website_name = 'Online Judge' + website_name_shortcut = 'oj' + website_footer = 'Online Judge Footer' allow_register = True submission_list_show_all = True smtp_config = {} judge_server_token = default_token - throttling = {"ip": {"capacity": 100, "fill_rate": 0.1, "default_capacity": 50}, - "user": {"capacity": 20, "fill_rate": 0.03, "default_capacity": 10}} + throttling = { + 'ip': {'capacity': 100, 'fill_rate': 0.1, 'default_capacity': 50}, + 'user': {'capacity': 20, 'fill_rate': 0.03, 'default_capacity': 10}, + } languages = languages class _SysOptionsMeta(type): @classmethod def _get_keys(cls): - return [key for key in OptionKeys.__dict__ if not key.startswith("__")] + return [key for key in OptionKeys.__dict__ if not key.startswith('__')] @classmethod def _init_option(mcs): @@ -263,15 +266,15 @@ def languages(cls, value): @my_property(ttl=DEFAULT_SHORT_TTL) def spj_languages(cls): - return [item for item in cls.languages if "spj" in item] + return [item for item in cls.languages if 'spj' in item] @my_property(ttl=DEFAULT_SHORT_TTL) def language_names(cls): - return [item["name"] for item in cls.languages] + return [item['name'] for item in cls.languages] @my_property(ttl=DEFAULT_SHORT_TTL) def spj_language_names(cls): - return [item["name"] for item in cls.languages if "spj" in item] + return [item['name'] for item in cls.languages if 'spj' in item] def reset_languages(cls): cls.languages = languages diff --git a/problem/migrations/0001_initial.py b/problem/migrations/0001_initial.py index bdcf2ee7c..5323d0365 100644 --- a/problem/migrations/0001_initial.py +++ b/problem/migrations/0001_initial.py @@ -10,7 +10,6 @@ class Migration(migrations.Migration): - initial = True dependencies = [ @@ -21,7 +20,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name='Problem', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('title', models.CharField(max_length=128)), ('description', utils.models.RichTextField()), ('input_description', utils.models.RichTextField()), @@ -37,7 +44,10 @@ class Migration(migrations.Migration): ('time_limit', models.IntegerField()), ('memory_limit', models.IntegerField()), ('spj', models.BooleanField(default=False)), - ('spj_language', models.CharField(blank=True, max_length=32, null=True)), + ( + 'spj_language', + models.CharField(blank=True, max_length=32, null=True), + ), ('spj_code', models.TextField(blank=True, null=True)), ('spj_version', models.CharField(blank=True, max_length=32, null=True)), ('rule_type', models.CharField(max_length=32)), @@ -46,7 +56,13 @@ class Migration(migrations.Migration): ('source', models.CharField(blank=True, max_length=200, null=True)), ('total_submit_number', models.IntegerField(default=0)), ('total_accepted_number', models.IntegerField(default=0)), - ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ( + 'created_by', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), ], options={ 'db_table': 'problem', @@ -56,7 +72,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name='ProblemTag', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('name', models.CharField(max_length=30)), ], options={ diff --git a/problem/migrations/0002_problem__id.py b/problem/migrations/0002_problem__id.py index 7c5d61dc0..343a71689 100644 --- a/problem/migrations/0002_problem__id.py +++ b/problem/migrations/0002_problem__id.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('problem', '0001_initial'), ] @@ -15,7 +14,9 @@ class Migration(migrations.Migration): migrations.AddField( model_name='problem', name='_id', - field=models.CharField(db_index=True, default='1', max_length=24, unique=True), + field=models.CharField( + db_index=True, default='1', max_length=24, unique=True + ), preserve_default=False, ), ] diff --git a/problem/migrations/0003_auto_20170217_0820.py b/problem/migrations/0003_auto_20170217_0820.py index 70d66f261..98c49a896 100644 --- a/problem/migrations/0003_auto_20170217_0820.py +++ b/problem/migrations/0003_auto_20170217_0820.py @@ -10,7 +10,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0003_auto_20170217_0820'), migrations.swappable_dependency(settings.AUTH_USER_MODEL), @@ -21,7 +20,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name='ContestProblem', fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ( + 'id', + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), ('title', models.CharField(max_length=128)), ('description', utils.models.RichTextField()), ('input_description', utils.models.RichTextField()), @@ -37,7 +44,10 @@ class Migration(migrations.Migration): ('time_limit', models.IntegerField()), ('memory_limit', models.IntegerField()), ('spj', models.BooleanField(default=False)), - ('spj_language', models.CharField(blank=True, max_length=32, null=True)), + ( + 'spj_language', + models.CharField(blank=True, max_length=32, null=True), + ), ('spj_code', models.TextField(blank=True, null=True)), ('spj_version', models.CharField(blank=True, max_length=32, null=True)), ('rule_type', models.CharField(max_length=32)), @@ -48,8 +58,20 @@ class Migration(migrations.Migration): ('total_accepted_number', models.IntegerField(default=0)), ('_id', models.CharField(db_index=True, max_length=24)), ('is_public', models.BooleanField(default=False)), - ('contest', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='contest.Contest')), - ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ( + 'contest', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to='contest.Contest', + ), + ), + ( + 'created_by', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), ('tags', models.ManyToManyField(to='problem.ProblemTag')), ], options={ diff --git a/problem/migrations/0004_auto_20170501_0637.py b/problem/migrations/0004_auto_20170501_0637.py index 5d55ace8b..c4576d4ae 100644 --- a/problem/migrations/0004_auto_20170501_0637.py +++ b/problem/migrations/0004_auto_20170501_0637.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('problem', '0003_auto_20170217_0820'), ] diff --git a/problem/migrations/0005_auto_20170815_1258.py b/problem/migrations/0005_auto_20170815_1258.py index 194969662..a5812c700 100644 --- a/problem/migrations/0005_auto_20170815_1258.py +++ b/problem/migrations/0005_auto_20170815_1258.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('problem', '0004_auto_20170501_0637'), ] diff --git a/problem/migrations/0006_auto_20170823_0918.py b/problem/migrations/0006_auto_20170823_0918.py index 933070f8e..cebfd688a 100644 --- a/problem/migrations/0006_auto_20170823_0918.py +++ b/problem/migrations/0006_auto_20170823_0918.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('problem', '0005_auto_20170815_1258'), ] diff --git a/problem/migrations/0008_auto_20170923_1318.py b/problem/migrations/0008_auto_20170923_1318.py index 4f5bb9963..ca4be825a 100644 --- a/problem/migrations/0008_auto_20170923_1318.py +++ b/problem/migrations/0008_auto_20170923_1318.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('contest', '0005_auto_20170823_0918'), ('problem', '0006_auto_20170823_0918'), @@ -43,7 +42,12 @@ class Migration(migrations.Migration): migrations.AddField( model_name='problem', name='contest', - field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='contest.Contest'), + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to='contest.Contest', + ), preserve_default=False, ), migrations.AddField( diff --git a/problem/migrations/0009_auto_20171011_1214.py b/problem/migrations/0009_auto_20171011_1214.py index e219ff297..dd6e630a2 100644 --- a/problem/migrations/0009_auto_20171011_1214.py +++ b/problem/migrations/0009_auto_20171011_1214.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('problem', '0008_auto_20170923_1318'), ] diff --git a/problem/migrations/0010_problem_spj_compile_ok.py b/problem/migrations/0010_problem_spj_compile_ok.py index 0df1b36b1..6104be822 100644 --- a/problem/migrations/0010_problem_spj_compile_ok.py +++ b/problem/migrations/0010_problem_spj_compile_ok.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('problem', '0009_auto_20171011_1214'), ] diff --git a/problem/migrations/0011_fix_problem_ac_count.py b/problem/migrations/0011_fix_problem_ac_count.py index 0550f9f0f..57005a73b 100644 --- a/problem/migrations/0011_fix_problem_ac_count.py +++ b/problem/migrations/0011_fix_problem_ac_count.py @@ -6,22 +6,28 @@ def fix_problem_count_bugs(apps, schema_editor): - Submission = apps.get_model("submission", "Submission") - Problem = apps.get_model("problem", "Problem") + Submission = apps.get_model('submission', 'Submission') + Problem = apps.get_model('problem', 'Problem') for item in Problem.objects.filter(contest__isnull=True): submissions = Submission.objects.filter(problem=item) item.submission_number = submissions.count() - results_count = submissions.values('result').annotate(count=Count('result')).order_by('result') + results_count = ( + submissions.values('result') + .annotate(count=Count('result')) + .order_by('result') + ) info = dict() item.accepted_number = 0 for stat in results_count: - result = stat["result"] + result = stat['result'] if result == 0: - item.accepted_number = stat["count"] - info[str(result)] = stat["count"] + item.accepted_number = stat['count'] + info[str(result)] = stat['count'] item.statistic_info = info - item.save(update_fields=["submission_number", "accepted_number", "statistic_info"]) + item.save( + update_fields=['submission_number', 'accepted_number', 'statistic_info'] + ) class Migration(migrations.Migration): @@ -31,5 +37,7 @@ class Migration(migrations.Migration): ] operations = [ - migrations.RunPython(fix_problem_count_bugs, reverse_code=migrations.RunPython.noop) + migrations.RunPython( + fix_problem_count_bugs, reverse_code=migrations.RunPython.noop + ) ] diff --git a/problem/migrations/0012_auto_20180501_0436.py b/problem/migrations/0012_auto_20180501_0436.py index 77b5c7811..5eb8707f9 100644 --- a/problem/migrations/0012_auto_20180501_0436.py +++ b/problem/migrations/0012_auto_20180501_0436.py @@ -8,7 +8,6 @@ class Migration(migrations.Migration): - dependencies = [ ('problem', '0011_fix_problem_ac_count'), ] @@ -22,7 +21,11 @@ class Migration(migrations.Migration): migrations.AlterField( model_name='problem', name='contest', - field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, to='contest.Contest'), + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + to='contest.Contest', + ), ), migrations.AlterField( model_name='problem', diff --git a/problem/migrations/0013_problem_io_mode.py b/problem/migrations/0013_problem_io_mode.py index 17136c876..cc911e92a 100644 --- a/problem/migrations/0013_problem_io_mode.py +++ b/problem/migrations/0013_problem_io_mode.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('problem', '0012_auto_20180501_0436'), ] @@ -15,6 +14,8 @@ class Migration(migrations.Migration): migrations.AddField( model_name='problem', name='io_mode', - field=django.contrib.postgres.fields.jsonb.JSONField(default=problem.models._default_io_mode), + field=django.contrib.postgres.fields.jsonb.JSONField( + default=problem.models._default_io_mode + ), ), ] diff --git a/problem/migrations/0014_problem_share_submission.py b/problem/migrations/0014_problem_share_submission.py index c764d7ce3..2c19727af 100644 --- a/problem/migrations/0014_problem_share_submission.py +++ b/problem/migrations/0014_problem_share_submission.py @@ -4,7 +4,6 @@ class Migration(migrations.Migration): - dependencies = [ ('problem', '0013_problem_io_mode'), ] diff --git a/problem/models.py b/problem/models.py index c16aeadae..77d4cd6ef 100644 --- a/problem/models.py +++ b/problem/models.py @@ -11,27 +11,31 @@ class ProblemTag(models.Model): name = models.TextField() class Meta: - db_table = "problem_tag" + db_table = 'problem_tag' class ProblemRuleType(Choices): - ACM = "ACM" - OI = "OI" + ACM = 'ACM' + OI = 'OI' class ProblemDifficulty(object): - High = "High" - Mid = "Mid" - Low = "Low" + High = 'High' + Mid = 'Mid' + Low = 'Low' class ProblemIOMode(Choices): - standard = "Standard IO" - file = "File IO" + standard = 'Standard IO' + file = 'File IO' def _default_io_mode(): - return {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"} + return { + 'io_mode': ProblemIOMode.standard, + 'input': 'input.txt', + 'output': 'output.txt', + } class Problem(models.Model): @@ -83,14 +87,14 @@ class Problem(models.Model): share_submission = models.BooleanField(default=False) class Meta: - db_table = "problem" - unique_together = (("_id", "contest"),) - ordering = ("create_time",) + db_table = 'problem' + unique_together = (('_id', 'contest'),) + ordering = ('create_time',) def add_submission_number(self): - self.submission_number = models.F("submission_number") + 1 - self.save(update_fields=["submission_number"]) + self.submission_number = models.F('submission_number') + 1 + self.save(update_fields=['submission_number']) def add_ac_number(self): - self.accepted_number = models.F("accepted_number") + 1 - self.save(update_fields=["accepted_number"]) + self.accepted_number = models.F('accepted_number') + 1 + self.save(update_fields=['accepted_number']) diff --git a/problem/serializers.py b/problem/serializers.py index 2c09a6c98..c88824a40 100644 --- a/problem/serializers.py +++ b/problem/serializers.py @@ -5,7 +5,11 @@ from options.options import SysOptions from utils.api import UsernameSerializer, serializers from utils.constants import Difficulty -from utils.serializers import LanguageNameMultiChoiceField, SPJLanguageNameChoiceField, LanguageNameChoiceField +from utils.serializers import ( + LanguageNameMultiChoiceField, + SPJLanguageNameChoiceField, + LanguageNameChoiceField, +) from .models import Problem, ProblemRuleType, ProblemTag, ProblemIOMode from .utils import parse_problem_template @@ -37,11 +41,11 @@ class ProblemIOModeSerializer(serializers.Serializer): output = serializers.CharField() def validate(self, attrs): - if attrs["input"] == attrs["output"]: - raise serializers.ValidationError("Invalid io mode") - for item in (attrs["input"], attrs["output"]): - if not re.match("^[a-zA-Z0-9.]+$", item): - raise serializers.ValidationError("Invalid io file name format") + if attrs['input'] == attrs['output']: + raise serializers.ValidationError('Invalid io mode') + for item in (attrs['input'], attrs['output']): + if not re.match('^[a-zA-Z0-9.]+$', item): + raise serializers.ValidationError('Invalid io file name format') return attrs @@ -53,12 +57,16 @@ class CreateOrEditProblemSerializer(serializers.Serializer): output_description = serializers.CharField() samples = serializers.ListField(child=CreateSampleSerializer(), allow_empty=False) test_case_id = serializers.CharField(max_length=32) - test_case_score = serializers.ListField(child=CreateTestCaseScoreSerializer(), allow_empty=True) + test_case_score = serializers.ListField( + child=CreateTestCaseScoreSerializer(), allow_empty=True + ) time_limit = serializers.IntegerField(min_value=1, max_value=1000 * 60) memory_limit = serializers.IntegerField(min_value=1, max_value=1024) languages = LanguageNameMultiChoiceField() template = serializers.DictField(child=serializers.CharField(min_length=1)) - rule_type = serializers.ChoiceField(choices=[ProblemRuleType.ACM, ProblemRuleType.OI]) + rule_type = serializers.ChoiceField( + choices=[ProblemRuleType.ACM, ProblemRuleType.OI] + ) io_mode = ProblemIOModeSerializer() spj = serializers.BooleanField() spj_language = SPJLanguageNameChoiceField(allow_blank=True, allow_null=True) @@ -66,7 +74,9 @@ class CreateOrEditProblemSerializer(serializers.Serializer): spj_compile_ok = serializers.BooleanField(default=False) visible = serializers.BooleanField() difficulty = serializers.ChoiceField(choices=Difficulty.choices()) - tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False) + tags = serializers.ListField( + child=serializers.CharField(max_length=32), allow_empty=False + ) hint = serializers.CharField(allow_blank=True, allow_null=True) source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True) share_submission = serializers.BooleanField() @@ -92,7 +102,7 @@ class EditContestProblemSerializer(CreateOrEditProblemSerializer): class TagSerializer(serializers.ModelSerializer): class Meta: model = ProblemTag - fields = "__all__" + fields = '__all__' class CompileSPJSerializer(serializers.Serializer): @@ -101,39 +111,56 @@ class CompileSPJSerializer(serializers.Serializer): class BaseProblemSerializer(serializers.ModelSerializer): - tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True) + tags = serializers.SlugRelatedField(many=True, slug_field='name', read_only=True) created_by = UsernameSerializer() def get_public_template(self, obj): ret = {} for lang, code in obj.template.items(): - ret[lang] = parse_problem_template(code)["template"] + ret[lang] = parse_problem_template(code)['template'] return ret class ProblemAdminSerializer(BaseProblemSerializer): class Meta: model = Problem - fields = "__all__" + fields = '__all__' class ProblemSerializer(BaseProblemSerializer): - template = serializers.SerializerMethodField("get_public_template") + template = serializers.SerializerMethodField('get_public_template') class Meta: model = Problem - exclude = ("test_case_score", "test_case_id", "visible", "is_public", - "spj_code", "spj_version", "spj_compile_ok") + exclude = ( + 'test_case_score', + 'test_case_id', + 'visible', + 'is_public', + 'spj_code', + 'spj_version', + 'spj_compile_ok', + ) class ProblemSafeSerializer(BaseProblemSerializer): - template = serializers.SerializerMethodField("get_public_template") + template = serializers.SerializerMethodField('get_public_template') class Meta: model = Problem - exclude = ("test_case_score", "test_case_id", "visible", "is_public", - "spj_code", "spj_version", "spj_compile_ok", - "difficulty", "submission_number", "accepted_number", "statistic_info") + exclude = ( + 'test_case_score', + 'test_case_id', + 'visible', + 'is_public', + 'spj_code', + 'spj_version', + 'spj_compile_ok', + 'difficulty', + 'submission_number', + 'accepted_number', + 'statistic_info', + ) class ContestProblemMakePublicSerializer(serializers.Serializer): @@ -151,13 +178,13 @@ class ExportProblemSerializer(serializers.ModelSerializer): spj = serializers.SerializerMethodField() template = serializers.SerializerMethodField() source = serializers.SerializerMethodField() - tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True) + tags = serializers.SlugRelatedField(many=True, slug_field='name', read_only=True) def get_display_id(self, obj): return obj._id def _html_format_value(self, value): - return {"format": "html", "value": value} + return {'format': 'html', 'value': value} def get_description(self, obj): return self._html_format_value(obj.description) @@ -172,13 +199,17 @@ def get_hint(self, obj): return self._html_format_value(obj.hint) def get_test_case_score(self, obj): - return [{"score": item["score"] if obj.rule_type == ProblemRuleType.OI else 100, - "input_name": item["input_name"], "output_name": item["output_name"]} - for item in obj.test_case_score] + return [ + { + 'score': item['score'] if obj.rule_type == ProblemRuleType.OI else 100, + 'input_name': item['input_name'], + 'output_name': item['output_name'], + } + for item in obj.test_case_score + ] def get_spj(self, obj): - return {"code": obj.spj_code, - "language": obj.spj_language} if obj.spj else None + return {'code': obj.spj_code, 'language': obj.spj_language} if obj.spj else None def get_template(self, obj): ret = {} @@ -187,14 +218,28 @@ def get_template(self, obj): return ret def get_source(self, obj): - return obj.source or f"{SysOptions.website_name} {SysOptions.website_base_url}" + return obj.source or f'{SysOptions.website_name} {SysOptions.website_base_url}' class Meta: model = Problem - fields = ("display_id", "title", "description", "tags", - "input_description", "output_description", - "test_case_score", "hint", "time_limit", "memory_limit", "samples", - "template", "spj", "rule_type", "source", "template") + fields = ( + 'display_id', + 'title', + 'description', + 'tags', + 'input_description', + 'output_description', + 'test_case_score', + 'hint', + 'time_limit', + 'memory_limit', + 'samples', + 'template', + 'spj', + 'rule_type', + 'source', + 'template', + ) class AddContestProblemSerializer(serializers.Serializer): @@ -204,7 +249,9 @@ class AddContestProblemSerializer(serializers.Serializer): class ExportProblemRequestSerialzier(serializers.Serializer): - problem_id = serializers.ListField(child=serializers.IntegerField(), allow_empty=False) + problem_id = serializers.ListField( + child=serializers.IntegerField(), allow_empty=False + ) class UploadProblemForm(forms.Form): @@ -212,7 +259,7 @@ class UploadProblemForm(forms.Form): class FormatValueSerializer(serializers.Serializer): - format = serializers.ChoiceField(choices=["html", "markdown"]) + format = serializers.ChoiceField(choices=['html', 'markdown']) value = serializers.CharField(allow_blank=True) @@ -245,7 +292,9 @@ class ImportProblemSerializer(serializers.Serializer): input_description = FormatValueSerializer() output_description = FormatValueSerializer() hint = FormatValueSerializer() - test_case_score = serializers.ListField(child=TestCaseScoreSerializer(), allow_null=True) + test_case_score = serializers.ListField( + child=TestCaseScoreSerializer(), allow_null=True + ) time_limit = serializers.IntegerField(min_value=1, max_value=60000) memory_limit = serializers.IntegerField(min_value=1, max_value=10240) samples = serializers.ListField(child=CreateSampleSerializer()) @@ -259,7 +308,7 @@ class ImportProblemSerializer(serializers.Serializer): class FPSProblemSerializer(serializers.Serializer): class UnitSerializer(serializers.Serializer): - unit = serializers.ChoiceField(choices=["MB", "s", "ms"]) + unit = serializers.ChoiceField(choices=['MB', 's', 'ms']) value = serializers.IntegerField(min_value=1, max_value=60000) title = serializers.CharField(max_length=128) @@ -272,6 +321,12 @@ class UnitSerializer(serializers.Serializer): samples = serializers.ListField(child=CreateSampleSerializer()) source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True) spj = SPJSerializer(allow_null=True) - template = serializers.ListField(child=serializers.DictField(), allow_empty=True, allow_null=True) - append = serializers.ListField(child=serializers.DictField(), allow_empty=True, allow_null=True) - prepend = serializers.ListField(child=serializers.DictField(), allow_empty=True, allow_null=True) + template = serializers.ListField( + child=serializers.DictField(), allow_empty=True, allow_null=True + ) + append = serializers.ListField( + child=serializers.DictField(), allow_empty=True, allow_null=True + ) + prepend = serializers.ListField( + child=serializers.DictField(), allow_empty=True, allow_null=True + ) diff --git a/problem/tests.py b/problem/tests.py index 7074d2f4d..f52d5f488 100644 --- a/problem/tests.py +++ b/problem/tests.py @@ -17,43 +17,72 @@ from .views.admin import TestCaseAPI from .utils import parse_problem_template -DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "

test

", "input_description": "test", - "output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low", - "visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {}, - "samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C", - "spj_code": "", "spj_compile_ok": True, "test_case_id": "499b26290cc7994e0b497212e842ea85", - "test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0, - "stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e", - "input_size": 0, "score": 0}], - "io_mode": {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"}, - "share_submission": False, - "rule_type": "ACM", "hint": "

test

", "source": "test"} +DEFAULT_PROBLEM_DATA = { + '_id': 'A-110', + 'title': 'test', + 'description': '

test

', + 'input_description': 'test', + 'output_description': 'test', + 'time_limit': 1000, + 'memory_limit': 256, + 'difficulty': 'Low', + 'visible': True, + 'tags': ['test'], + 'languages': ['C', 'C++', 'Java', 'Python2'], + 'template': {}, + 'samples': [{'input': 'test', 'output': 'test'}], + 'spj': False, + 'spj_language': 'C', + 'spj_code': '', + 'spj_compile_ok': True, + 'test_case_id': '499b26290cc7994e0b497212e842ea85', + 'test_case_score': [ + { + 'output_name': '1.out', + 'input_name': '1.in', + 'output_size': 0, + 'stripped_output_md5': 'd41d8cd98f00b204e9800998ecf8427e', + 'input_size': 0, + 'score': 0, + } + ], + 'io_mode': { + 'io_mode': ProblemIOMode.standard, + 'input': 'input.txt', + 'output': 'output.txt', + }, + 'share_submission': False, + 'rule_type': 'ACM', + 'hint': '

test

', + 'source': 'test', +} class ProblemCreateTestBase(APITestCase): @staticmethod def add_problem(problem_data, created_by): data = copy.deepcopy(problem_data) - if data["spj"]: - if not data["spj_language"] or not data["spj_code"]: - raise ValueError("Invalid spj") - data["spj_version"] = hashlib.md5( - (data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest() + if data['spj']: + if not data['spj_language'] or not data['spj_code']: + raise ValueError('Invalid spj') + data['spj_version'] = hashlib.md5( + (data['spj_language'] + ':' + data['spj_code']).encode('utf-8') + ).hexdigest() else: - data["spj_language"] = None - data["spj_code"] = None - if data["rule_type"] == ProblemRuleType.OI: + data['spj_language'] = None + data['spj_code'] = None + if data['rule_type'] == ProblemRuleType.OI: total_score = 0 - for item in data["test_case_score"]: - if item["score"] <= 0: - raise ValueError("invalid score") + for item in data['test_case_score']: + if item['score'] <= 0: + raise ValueError('invalid score') else: - total_score += item["score"] - data["total_score"] = total_score - data["created_by"] = created_by - tags = data.pop("tags") + total_score += item['score'] + data['total_score'] = total_score + data['created_by'] = created_by + tags = data.pop('tags') - data["languages"] = list(data["languages"]) + data['languages'] = list(data['languages']) problem = Problem.objects.create(**data) @@ -68,72 +97,85 @@ def add_problem(problem_data, created_by): class ProblemTagListAPITest(APITestCase): def test_get_tag_list(self): - ProblemTag.objects.create(name="name1") - ProblemTag.objects.create(name="name2") - resp = self.client.get(self.reverse("problem_tag_list_api")) + ProblemTag.objects.create(name='name1') + ProblemTag.objects.create(name='name2') + resp = self.client.get(self.reverse('problem_tag_list_api')) self.assertSuccess(resp) class TestCaseUploadAPITest(APITestCase): def setUp(self): self.api = TestCaseAPI() - self.url = self.reverse("test_case_api") + self.url = self.reverse('test_case_api') self.create_super_admin() def test_filter_file_name(self): - self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False), - ["1.in", "1.out"]) - self.assertEqual(self.api.filter_name_list(["2.in", "2.out"], spj=False), []) - - self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in"], spj=True), ["1.in", "2.in"]) - self.assertEqual(self.api.filter_name_list(["2.in", "3.in"], spj=True), []) + self.assertEqual( + self.api.filter_name_list( + ['1.in', '1.out', '2.in', '.DS_Store'], spj=False + ), + ['1.in', '1.out'], + ) + self.assertEqual(self.api.filter_name_list(['2.in', '2.out'], spj=False), []) + + self.assertEqual( + self.api.filter_name_list(['1.in', '1.out', '2.in'], spj=True), + ['1.in', '2.in'], + ) + self.assertEqual(self.api.filter_name_list(['2.in', '3.in'], spj=True), []) def make_test_case_zip(self): - base_dir = os.path.join("/tmp", "test_case") + base_dir = os.path.join('/tmp', 'test_case') shutil.rmtree(base_dir, ignore_errors=True) os.mkdir(base_dir) - file_names = ["1.in", "1.out", "2.in", ".DS_Store"] + file_names = ['1.in', '1.out', '2.in', '.DS_Store'] for item in file_names: - with open(os.path.join(base_dir, item), "w", encoding="utf-8") as f: - f.write(item + "\n" + item + "\r\n" + "end") - zip_file = os.path.join(base_dir, "test_case.zip") - with ZipFile(os.path.join(base_dir, "test_case.zip"), "w") as f: + with open(os.path.join(base_dir, item), 'w', encoding='utf-8') as f: + f.write(item + '\n' + item + '\r\n' + 'end') + zip_file = os.path.join(base_dir, 'test_case.zip') + with ZipFile(os.path.join(base_dir, 'test_case.zip'), 'w') as f: for item in file_names: f.write(os.path.join(base_dir, item), item) return zip_file def test_upload_spj_test_case_zip(self): - with open(self.make_test_case_zip(), "rb") as f: - resp = self.client.post(self.url, - data={"spj": "true", "file": f}, format="multipart") + with open(self.make_test_case_zip(), 'rb') as f: + resp = self.client.post( + self.url, data={'spj': 'true', 'file': f}, format='multipart' + ) self.assertSuccess(resp) - data = resp.data["data"] - self.assertEqual(data["spj"], True) - test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"]) + data = resp.data['data'] + self.assertEqual(data['spj'], True) + test_case_dir = os.path.join(settings.TEST_CASE_DIR, data['id']) self.assertTrue(os.path.exists(test_case_dir)) - for item in data["info"]: - name = item["input_name"] - with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f: - self.assertEqual(f.read(), name + "\n" + name + "\n" + "end") + for item in data['info']: + name = item['input_name'] + with open( + os.path.join(test_case_dir, name), 'r', encoding='utf-8' + ) as f: + self.assertEqual(f.read(), name + '\n' + name + '\n' + 'end') def test_upload_test_case_zip(self): - with open(self.make_test_case_zip(), "rb") as f: - resp = self.client.post(self.url, - data={"spj": "false", "file": f}, format="multipart") + with open(self.make_test_case_zip(), 'rb') as f: + resp = self.client.post( + self.url, data={'spj': 'false', 'file': f}, format='multipart' + ) self.assertSuccess(resp) - data = resp.data["data"] - self.assertEqual(data["spj"], False) - test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"]) + data = resp.data['data'] + self.assertEqual(data['spj'], False) + test_case_dir = os.path.join(settings.TEST_CASE_DIR, data['id']) self.assertTrue(os.path.exists(test_case_dir)) - for item in data["info"]: - name = item["input_name"] - with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f: - self.assertEqual(f.read(), name + "\n" + name + "\n" + "end") + for item in data['info']: + name = item['input_name'] + with open( + os.path.join(test_case_dir, name), 'r', encoding='utf-8' + ) as f: + self.assertEqual(f.read(), name + '\n' + name + '\n' + 'end') class ProblemAdminAPITest(APITestCase): def setUp(self): - self.url = self.reverse("problem_admin_api") + self.url = self.reverse('problem_admin_api') self.create_super_admin() self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA) @@ -146,16 +188,16 @@ def test_duplicate_display_id(self): self.test_create_problem() resp = self.client.post(self.url, data=self.data) - self.assertFailed(resp, "Display ID already exists") + self.assertFailed(resp, 'Display ID already exists') def test_spj(self): data = copy.deepcopy(self.data) - data["spj"] = True + data['spj'] = True resp = self.client.post(self.url, data) - self.assertFailed(resp, "Invalid spj") + self.assertFailed(resp, 'Invalid spj') - data["spj_code"] = "test" + data['spj_code'] = 'test' resp = self.client.post(self.url, data=data) self.assertSuccess(resp) @@ -165,122 +207,128 @@ def test_get_problem(self): self.assertSuccess(resp) def test_get_one_problem(self): - problem_id = self.test_create_problem().data["data"]["id"] - resp = self.client.get(self.url + "?id=" + str(problem_id)) + problem_id = self.test_create_problem().data['data']['id'] + resp = self.client.get(self.url + '?id=' + str(problem_id)) self.assertSuccess(resp) def test_edit_problem(self): - problem_id = self.test_create_problem().data["data"]["id"] + problem_id = self.test_create_problem().data['data']['id'] data = copy.deepcopy(self.data) - data["id"] = problem_id + data['id'] = problem_id resp = self.client.put(self.url, data=data) self.assertSuccess(resp) class ProblemAPITest(ProblemCreateTestBase): def setUp(self): - self.url = self.reverse("problem_api") + self.url = self.reverse('problem_api') admin = self.create_admin(login=False) self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) - self.create_user("test", "test123") + self.create_user('test', 'test123') def test_get_problem_list(self): - resp = self.client.get(f"{self.url}?limit=10") + resp = self.client.get(f'{self.url}?limit=10') self.assertSuccess(resp) def get_one_problem(self): - resp = self.client.get(self.url + "?id=" + self.problem._id) + resp = self.client.get(self.url + '?id=' + self.problem._id) self.assertSuccess(resp) class ContestProblemAdminTest(APITestCase): def setUp(self): - self.url = self.reverse("contest_problem_admin_api") + self.url = self.reverse('contest_problem_admin_api') self.create_admin() - self.contest = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"] + self.contest = self.client.post( + self.reverse('contest_admin_api'), data=DEFAULT_CONTEST_DATA + ).data['data'] def test_create_contest_problem(self): data = copy.deepcopy(DEFAULT_PROBLEM_DATA) - data["contest_id"] = self.contest["id"] + data['contest_id'] = self.contest['id'] resp = self.client.post(self.url, data=data) self.assertSuccess(resp) - return resp.data["data"] + return resp.data['data'] def test_get_contest_problem(self): self.test_create_contest_problem() - contest_id = self.contest["id"] - resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) + contest_id = self.contest['id'] + resp = self.client.get(self.url + '?contest_id=' + str(contest_id)) self.assertSuccess(resp) - self.assertEqual(len(resp.data["data"]["results"]), 1) + self.assertEqual(len(resp.data['data']['results']), 1) def test_get_one_contest_problem(self): contest_problem = self.test_create_contest_problem() - contest_id = self.contest["id"] - problem_id = contest_problem["id"] - resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}") + contest_id = self.contest['id'] + problem_id = contest_problem['id'] + resp = self.client.get(f'{self.url}?contest_id={contest_id}&id={problem_id}') self.assertSuccess(resp) class ContestProblemTest(ProblemCreateTestBase): def setUp(self): admin = self.create_admin() - url = self.reverse("contest_admin_api") + url = self.reverse('contest_admin_api') contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA) - contest_data["password"] = "" - contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1) - self.contest = self.client.post(url, data=contest_data).data["data"] + contest_data['password'] = '' + contest_data['start_time'] = contest_data['start_time'] + timedelta(hours=1) + self.contest = self.client.post(url, data=contest_data).data['data'] self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) - self.problem.contest_id = self.contest["id"] + self.problem.contest_id = self.contest['id'] self.problem.save() - self.url = self.reverse("contest_problem_api") + self.url = self.reverse('contest_problem_api') def test_admin_get_contest_problem_list(self): - contest_id = self.contest["id"] - resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) + contest_id = self.contest['id'] + resp = self.client.get(self.url + '?contest_id=' + str(contest_id)) self.assertSuccess(resp) - self.assertEqual(len(resp.data["data"]), 1) + self.assertEqual(len(resp.data['data']), 1) def test_admin_get_one_contest_problem(self): - contest_id = self.contest["id"] + contest_id = self.contest['id'] problem_id = self.problem._id - resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id)) + resp = self.client.get( + '{}?contest_id={}&problem_id={}'.format(self.url, contest_id, problem_id) + ) self.assertSuccess(resp) def test_regular_user_get_not_started_contest_problem(self): - self.create_user("test", "test123") - resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"])) - self.assertDictEqual(resp.data, {"error": "error", "data": "Contest has not started yet."}) + self.create_user('test', 'test123') + resp = self.client.get(self.url + '?contest_id=' + str(self.contest['id'])) + self.assertDictEqual( + resp.data, {'error': 'error', 'data': 'Contest has not started yet.'} + ) def test_reguar_user_get_started_contest_problem(self): - self.create_user("test", "test123") + self.create_user('test', 'test123') contest = Contest.objects.first() contest.start_time = contest.start_time - timedelta(hours=1) contest.save() - resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"])) + resp = self.client.get(self.url + '?contest_id=' + str(self.contest['id'])) self.assertSuccess(resp) class AddProblemFromPublicProblemAPITest(ProblemCreateTestBase): def setUp(self): admin = self.create_admin() - url = self.reverse("contest_admin_api") + url = self.reverse('contest_admin_api') contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA) - contest_data["password"] = "" - contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1) - self.contest = self.client.post(url, data=contest_data).data["data"] + contest_data['password'] = '' + contest_data['start_time'] = contest_data['start_time'] + timedelta(hours=1) + self.contest = self.client.post(url, data=contest_data).data['data'] self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) - self.url = self.reverse("add_contest_problem_from_public_api") + self.url = self.reverse('add_contest_problem_from_public_api') self.data = { - "display_id": "1000", - "contest_id": self.contest["id"], - "problem_id": self.problem.id + 'display_id': '1000', + 'contest_id': self.contest['id'], + 'problem_id': self.problem.id, } def test_add_contest_problem(self): resp = self.client.post(self.url, data=self.data) self.assertSuccess(resp) self.assertTrue(Problem.objects.all().exists()) - self.assertTrue(Problem.objects.filter(contest_id=self.contest["id"]).exists()) + self.assertTrue(Problem.objects.filter(contest_id=self.contest['id']).exists()) class ParseProblemTemplateTest(APITestCase): @@ -300,9 +348,9 @@ def test_parse(self): """ ret = parse_problem_template(template_str) - self.assertEqual(ret["prepend"], "aaa\n") - self.assertEqual(ret["template"], "bbb\n") - self.assertEqual(ret["append"], "ccc\n") + self.assertEqual(ret['prepend'], 'aaa\n') + self.assertEqual(ret['template'], 'bbb\n') + self.assertEqual(ret['append'], 'ccc\n') def test_parse1(self): template_str = """ @@ -319,6 +367,6 @@ def test_parse1(self): """ ret = parse_problem_template(template_str) - self.assertEqual(ret["prepend"], "aaa\n") - self.assertEqual(ret["template"], "") - self.assertEqual(ret["append"], "ccc\n") + self.assertEqual(ret['prepend'], 'aaa\n') + self.assertEqual(ret['template'], '') + self.assertEqual(ret['append'], 'ccc\n') diff --git a/problem/urls/admin.py b/problem/urls/admin.py index e3a921fca..7e0552c7c 100644 --- a/problem/urls/admin.py +++ b/problem/urls/admin.py @@ -1,17 +1,37 @@ from django.conf.urls import url -from ..views.admin import (ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView, - CompileSPJAPI, AddContestProblemAPI, ExportProblemAPI, ImportProblemAPI, - FPSProblemImport) +from ..views.admin import ( + ContestProblemAPI, + ProblemAPI, + TestCaseAPI, + MakeContestProblemPublicAPIView, + CompileSPJAPI, + AddContestProblemAPI, + ExportProblemAPI, + ImportProblemAPI, + FPSProblemImport, +) urlpatterns = [ - url(r"^test_case/?$", TestCaseAPI.as_view(), name="test_case_api"), - url(r"^compile_spj/?$", CompileSPJAPI.as_view(), name="compile_spj"), - url(r"^problem/?$", ProblemAPI.as_view(), name="problem_admin_api"), - url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_admin_api"), - url(r"^contest_problem/make_public/?$", MakeContestProblemPublicAPIView.as_view(), name="make_public_api"), - url(r"^contest/add_problem_from_public/?$", AddContestProblemAPI.as_view(), name="add_contest_problem_from_public_api"), - url(r"^export_problem/?$", ExportProblemAPI.as_view(), name="export_problem_api"), - url(r"^import_problem/?$", ImportProblemAPI.as_view(), name="import_problem_api"), - url(r"^import_fps/?$", FPSProblemImport.as_view(), name="fps_problem_api"), + url(r'^test_case/?$', TestCaseAPI.as_view(), name='test_case_api'), + url(r'^compile_spj/?$', CompileSPJAPI.as_view(), name='compile_spj'), + url(r'^problem/?$', ProblemAPI.as_view(), name='problem_admin_api'), + url( + r'^contest/problem/?$', + ContestProblemAPI.as_view(), + name='contest_problem_admin_api', + ), + url( + r'^contest_problem/make_public/?$', + MakeContestProblemPublicAPIView.as_view(), + name='make_public_api', + ), + url( + r'^contest/add_problem_from_public/?$', + AddContestProblemAPI.as_view(), + name='add_contest_problem_from_public_api', + ), + url(r'^export_problem/?$', ExportProblemAPI.as_view(), name='export_problem_api'), + url(r'^import_problem/?$', ImportProblemAPI.as_view(), name='import_problem_api'), + url(r'^import_fps/?$', FPSProblemImport.as_view(), name='fps_problem_api'), ] diff --git a/problem/urls/oj.py b/problem/urls/oj.py index f7cd3ae3a..bae1f7cf6 100644 --- a/problem/urls/oj.py +++ b/problem/urls/oj.py @@ -3,8 +3,10 @@ from ..views.oj import ProblemTagAPI, ProblemAPI, ContestProblemAPI, PickOneAPI urlpatterns = [ - url(r"^problem/tags/?$", ProblemTagAPI.as_view(), name="problem_tag_list_api"), - url(r"^problem/?$", ProblemAPI.as_view(), name="problem_api"), - url(r"^pickone/?$", PickOneAPI.as_view(), name="pick_one_api"), - url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_api"), + url(r'^problem/tags/?$', ProblemTagAPI.as_view(), name='problem_tag_list_api'), + url(r'^problem/?$', ProblemAPI.as_view(), name='problem_api'), + url(r'^pickone/?$', PickOneAPI.as_view(), name='pick_one_api'), + url( + r'^contest/problem/?$', ContestProblemAPI.as_view(), name='contest_problem_api' + ), ] diff --git a/problem/utils.py b/problem/utils.py index c53026394..3ac0bb33d 100644 --- a/problem/utils.py +++ b/problem/utils.py @@ -17,12 +17,14 @@ @lru_cache(maxsize=100) def parse_problem_template(template_str): - prepend = re.findall(r"//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str) - template = re.findall(r"//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END", template_str) - append = re.findall(r"//APPEND BEGIN\n([\s\S]+?)//APPEND END", template_str) - return {"prepend": prepend[0] if prepend else "", - "template": template[0] if template else "", - "append": append[0] if append else ""} + prepend = re.findall(r'//PREPEND BEGIN\n([\s\S]+?)//PREPEND END', template_str) + template = re.findall(r'//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END', template_str) + append = re.findall(r'//APPEND BEGIN\n([\s\S]+?)//APPEND END', template_str) + return { + 'prepend': prepend[0] if prepend else '', + 'template': template[0] if template else '', + 'append': append[0] if append else '', + } @lru_cache(maxsize=100) diff --git a/problem/views/admin.py b/problem/views/admin.py index 5ce9413b2..7f3988338 100644 --- a/problem/views/admin.py +++ b/problem/views/admin.py @@ -1,6 +1,7 @@ import hashlib import json import os + # import shutil import tempfile import zipfile @@ -22,25 +23,35 @@ from utils.shortcuts import rand_str, natural_sort_key from utils.tasks import delete_files from ..models import Problem, ProblemRuleType, ProblemTag -from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer, - CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer, - ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer, - AddContestProblemSerializer, ExportProblemSerializer, - ExportProblemRequestSerialzier, UploadProblemForm, ImportProblemSerializer, - FPSProblemSerializer) +from ..serializers import ( + CreateContestProblemSerializer, + CompileSPJSerializer, + CreateProblemSerializer, + EditProblemSerializer, + EditContestProblemSerializer, + ProblemAdminSerializer, + TestCaseUploadForm, + ContestProblemMakePublicSerializer, + AddContestProblemSerializer, + ExportProblemSerializer, + ExportProblemRequestSerialzier, + UploadProblemForm, + ImportProblemSerializer, + FPSProblemSerializer, +) from ..utils import TEMPLATE_BASE, build_problem_template class TestCaseZipProcessor(object): - def process_zip(self, uploaded_zip_file, spj, dir=""): + def process_zip(self, uploaded_zip_file, spj, dir=''): try: - zip_file = zipfile.ZipFile(uploaded_zip_file, "r") + zip_file = zipfile.ZipFile(uploaded_zip_file, 'r') except zipfile.BadZipFile: - raise APIError("Bad zip file") + raise APIError('Bad zip file') name_list = zip_file.namelist() test_case_list = self.filter_name_list(name_list, spj=spj, dir=dir) if not test_case_list: - raise APIError("Empty file") + raise APIError('Empty file') test_case_id = rand_str() test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id) @@ -51,34 +62,36 @@ def process_zip(self, uploaded_zip_file, spj, dir=""): md5_cache = {} for item in test_case_list: - with open(os.path.join(test_case_dir, item), "wb") as f: - content = zip_file.read(f"{dir}{item}").replace(b"\r\n", b"\n") + with open(os.path.join(test_case_dir, item), 'wb') as f: + content = zip_file.read(f'{dir}{item}').replace(b'\r\n', b'\n') size_cache[item] = len(content) - if item.endswith(".out"): + if item.endswith('.out'): md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest() f.write(content) - test_case_info = {"spj": spj, "test_cases": {}} + test_case_info = {'spj': spj, 'test_cases': {}} info = [] if spj: for index, item in enumerate(test_case_list): - data = {"input_name": item, "input_size": size_cache[item]} + data = {'input_name': item, 'input_size': size_cache[item]} info.append(data) - test_case_info["test_cases"][str(index + 1)] = data + test_case_info['test_cases'][str(index + 1)] = data else: # ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")] test_case_list = zip(*[test_case_list[i::2] for i in range(2)]) for index, item in enumerate(test_case_list): - data = {"stripped_output_md5": md5_cache[item[1]], - "input_size": size_cache[item[0]], - "output_size": size_cache[item[1]], - "input_name": item[0], - "output_name": item[1]} + data = { + 'stripped_output_md5': md5_cache[item[1]], + 'input_size': size_cache[item[0]], + 'output_size': size_cache[item[1]], + 'input_name': item[0], + 'output_name': item[1], + } info.append(data) - test_case_info["test_cases"][str(index + 1)] = data + test_case_info['test_cases'][str(index + 1)] = data - with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f: + with open(os.path.join(test_case_dir, 'info'), 'w', encoding='utf-8') as f: f.write(json.dumps(test_case_info, indent=4)) for item in os.listdir(test_case_dir): @@ -86,13 +99,13 @@ def process_zip(self, uploaded_zip_file, spj, dir=""): return info, test_case_id - def filter_name_list(self, name_list, spj, dir=""): + def filter_name_list(self, name_list, spj, dir=''): ret = [] prefix = 1 if spj: while True: - in_name = f"{prefix}.in" - if f"{dir}{in_name}" in name_list: + in_name = f'{prefix}.in' + if f'{dir}{in_name}' in name_list: ret.append(in_name) prefix += 1 continue @@ -100,9 +113,9 @@ def filter_name_list(self, name_list, spj, dir=""): return sorted(ret, key=natural_sort_key) else: while True: - in_name = f"{prefix}.in" - out_name = f"{prefix}.out" - if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list: + in_name = f'{prefix}.in' + out_name = f'{prefix}.out' + if f'{dir}{in_name}' in name_list and f'{dir}{out_name}' in name_list: ret.append(in_name) ret.append(out_name) prefix += 1 @@ -115,13 +128,13 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor): request_parsers = () def get(self, request): - problem_id = request.GET.get("problem_id") + problem_id = request.GET.get('problem_id') if not problem_id: - return self.error("Parameter error, problem_id is required") + return self.error('Parameter error, problem_id is required') try: problem = Problem.objects.get(id=problem_id) except Problem.DoesNotExist: - return self.error("Problem does not exists") + return self.error('Problem does not exists') if problem.contest: ensure_created_by(problem.contest, request.user) @@ -130,34 +143,37 @@ def get(self, request): test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) if not os.path.isdir(test_case_dir): - return self.error("Test case does not exists") + return self.error('Test case does not exists') name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj) - name_list.append("info") - file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip") - with zipfile.ZipFile(file_name, "w") as file: + name_list.append('info') + file_name = os.path.join(test_case_dir, problem.test_case_id + '.zip') + with zipfile.ZipFile(file_name, 'w') as file: for test_case in name_list: - file.write(f"{test_case_dir}/{test_case}", test_case) - response = StreamingHttpResponse(FileWrapper(open(file_name, "rb")), - content_type="application/octet-stream") - - response["Content-Disposition"] = f"attachment; filename=problem_{problem.id}_test_cases.zip" - response["Content-Length"] = os.path.getsize(file_name) + file.write(f'{test_case_dir}/{test_case}', test_case) + response = StreamingHttpResponse( + FileWrapper(open(file_name, 'rb')), content_type='application/octet-stream' + ) + + response[ + 'Content-Disposition' + ] = f'attachment; filename=problem_{problem.id}_test_cases.zip' + response['Content-Length'] = os.path.getsize(file_name) return response def post(self, request): form = TestCaseUploadForm(request.POST, request.FILES) if form.is_valid(): - spj = form.cleaned_data["spj"] == "true" - file = form.cleaned_data["file"] + spj = form.cleaned_data['spj'] == 'true' + file = form.cleaned_data['file'] else: - return self.error("Upload failed") - zip_file = f"/tmp/{rand_str()}.zip" - with open(zip_file, "wb") as f: + return self.error('Upload failed') + zip_file = f'/tmp/{rand_str()}.zip' + with open(zip_file, 'wb') as f: for chunk in file: f.write(chunk) info, test_case_id = self.process_zip(zip_file, spj=spj) os.remove(zip_file) - return self.success({"id": test_case_id, "info": info, "spj": spj}) + return self.success({'id': test_case_id, 'info': info, 'spj': spj}) class CompileSPJAPI(APIView): @@ -165,7 +181,9 @@ class CompileSPJAPI(APIView): def post(self, request): data = request.data spj_version = rand_str(8) - error = SPJCompiler(data["spj_code"], spj_version, data["spj_language"]).compile_spj() + error = SPJCompiler( + data['spj_code'], spj_version, data['spj_language'] + ).compile_spj() if error: return self.error(error) else: @@ -175,25 +193,26 @@ def post(self, request): class ProblemBase(APIView): def common_checks(self, request): data = request.data - if data["spj"]: - if not data["spj_language"] or not data["spj_code"]: - return "Invalid spj" - if not data["spj_compile_ok"]: - return "SPJ code must be compiled successfully" - data["spj_version"] = hashlib.md5( - (data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest() + if data['spj']: + if not data['spj_language'] or not data['spj_code']: + return 'Invalid spj' + if not data['spj_compile_ok']: + return 'SPJ code must be compiled successfully' + data['spj_version'] = hashlib.md5( + (data['spj_language'] + ':' + data['spj_code']).encode('utf-8') + ).hexdigest() else: - data["spj_language"] = None - data["spj_code"] = None - if data["rule_type"] == ProblemRuleType.OI: + data['spj_language'] = None + data['spj_code'] = None + if data['rule_type'] == ProblemRuleType.OI: total_score = 0 - for item in data["test_case_score"]: - if item["score"] <= 0: - return "Invalid score" + for item in data['test_case_score']: + if item['score'] <= 0: + return 'Invalid score' else: - total_score += item["score"] - data["total_score"] = total_score - data["languages"] = list(data["languages"]) + total_score += item['score'] + data['total_score'] = total_score + data['languages'] = list(data['languages']) class ProblemAPI(ProblemBase): @@ -201,19 +220,19 @@ class ProblemAPI(ProblemBase): @validate_serializer(CreateProblemSerializer) def post(self, request): data = request.data - _id = data["_id"] + _id = data['_id'] if not _id: - return self.error("Display ID is required") + return self.error('Display ID is required') if Problem.objects.filter(_id=_id, contest_id__isnull=True).exists(): - return self.error("Display ID already exists") + return self.error('Display ID already exists') error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info - tags = data.pop("tags") - data["created_by"] = request.user + tags = data.pop('tags') + data['created_by'] = request.user problem = Problem.objects.create(**data) for item in tags: @@ -226,8 +245,8 @@ def post(self, request): @problem_permission_required def get(self, request): - problem_id = request.GET.get("id") - rule_type = request.GET.get("rule_type") + problem_id = request.GET.get('id') + rule_type = request.GET.get('rule_type') user = request.user if problem_id: try: @@ -235,46 +254,56 @@ def get(self, request): ensure_created_by(problem, request.user) return self.success(ProblemAdminSerializer(problem).data) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') - problems = Problem.objects.filter(contest_id__isnull=True).order_by("-create_time") + problems = Problem.objects.filter(contest_id__isnull=True).order_by( + '-create_time' + ) if rule_type: if rule_type not in ProblemRuleType.choices(): - return self.error("Invalid rule_type") + return self.error('Invalid rule_type') else: problems = problems.filter(rule_type=rule_type) - keyword = request.GET.get("keyword", "").strip() + keyword = request.GET.get('keyword', '').strip() if keyword: - problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword)) + problems = problems.filter( + Q(title__icontains=keyword) | Q(_id__icontains=keyword) + ) if not user.can_mgmt_all_problem(): problems = problems.filter(created_by=user) - return self.success(self.paginate_data(request, problems, ProblemAdminSerializer)) + return self.success( + self.paginate_data(request, problems, ProblemAdminSerializer) + ) @problem_permission_required @validate_serializer(EditProblemSerializer) def put(self, request): data = request.data - problem_id = data.pop("id") + problem_id = data.pop('id') try: problem = Problem.objects.get(id=problem_id) ensure_created_by(problem, request.user) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') - _id = data["_id"] + _id = data['_id'] if not _id: - return self.error("Display ID is required") - if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest_id__isnull=True).exists(): - return self.error("Display ID already exists") + return self.error('Display ID is required') + if ( + Problem.objects.exclude(id=problem_id) + .filter(_id=_id, contest_id__isnull=True) + .exists() + ): + return self.error('Display ID already exists') error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info - tags = data.pop("tags") - data["languages"] = list(data["languages"]) + tags = data.pop('tags') + data['languages'] = list(data['languages']) for k, v in data.items(): setattr(problem, k, v) @@ -292,13 +321,13 @@ def put(self, request): @problem_permission_required def delete(self, request): - id = request.GET.get("id") + id = request.GET.get('id') if not id: - return self.error("Invalid parameter, id is required") + return self.error('Invalid parameter, id is required') try: problem = Problem.objects.get(id=id, contest_id__isnull=True) except Problem.DoesNotExist: - return self.error("Problem does not exists") + return self.error('Problem does not exists') ensure_created_by(problem, request.user) # d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) # if os.path.isdir(d): @@ -312,29 +341,29 @@ class ContestProblemAPI(ProblemBase): def post(self, request): data = request.data try: - contest = Contest.objects.get(id=data.pop("contest_id")) + contest = Contest.objects.get(id=data.pop('contest_id')) ensure_created_by(contest, request.user) except Contest.DoesNotExist: - return self.error("Contest does not exist") + return self.error('Contest does not exist') - if data["rule_type"] != contest.rule_type: - return self.error("Invalid rule type") + if data['rule_type'] != contest.rule_type: + return self.error('Invalid rule type') - _id = data["_id"] + _id = data['_id'] if not _id: - return self.error("Display ID is required") + return self.error('Display ID is required') if Problem.objects.filter(_id=_id, contest=contest).exists(): - return self.error("Duplicate Display id") + return self.error('Duplicate Display id') error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info - data["contest"] = contest - tags = data.pop("tags") - data["created_by"] = request.user + data['contest'] = contest + tags = data.pop('tags') + data['created_by'] = request.user problem = Problem.objects.create(**data) for item in tags: @@ -346,31 +375,33 @@ def post(self, request): return self.success(ProblemAdminSerializer(problem).data) def get(self, request): - problem_id = request.GET.get("id") - contest_id = request.GET.get("contest_id") + problem_id = request.GET.get('id') + contest_id = request.GET.get('contest_id') user = request.user if problem_id: try: problem = Problem.objects.get(id=problem_id) ensure_created_by(problem.contest, user) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') return self.success(ProblemAdminSerializer(problem).data) if not contest_id: - return self.error("Contest id is required") + return self.error('Contest id is required') try: contest = Contest.objects.get(id=contest_id) ensure_created_by(contest, user) except Contest.DoesNotExist: - return self.error("Contest does not exist") - problems = Problem.objects.filter(contest=contest).order_by("-create_time") + return self.error('Contest does not exist') + problems = Problem.objects.filter(contest=contest).order_by('-create_time') if user.is_admin(): problems = problems.filter(contest__created_by=user) - keyword = request.GET.get("keyword") + keyword = request.GET.get('keyword') if keyword: problems = problems.filter(title__contains=keyword) - return self.success(self.paginate_data(request, problems, ProblemAdminSerializer)) + return self.success( + self.paginate_data(request, problems, ProblemAdminSerializer) + ) @validate_serializer(EditContestProblemSerializer) def put(self, request): @@ -378,33 +409,37 @@ def put(self, request): user = request.user try: - contest = Contest.objects.get(id=data.pop("contest_id")) + contest = Contest.objects.get(id=data.pop('contest_id')) ensure_created_by(contest, user) except Contest.DoesNotExist: - return self.error("Contest does not exist") + return self.error('Contest does not exist') - if data["rule_type"] != contest.rule_type: - return self.error("Invalid rule type") + if data['rule_type'] != contest.rule_type: + return self.error('Invalid rule type') - problem_id = data.pop("id") + problem_id = data.pop('id') try: problem = Problem.objects.get(id=problem_id, contest=contest) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') - _id = data["_id"] + _id = data['_id'] if not _id: - return self.error("Display ID is required") - if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest=contest).exists(): - return self.error("Display ID already exists") + return self.error('Display ID is required') + if ( + Problem.objects.exclude(id=problem_id) + .filter(_id=_id, contest=contest) + .exists() + ): + return self.error('Display ID already exists') error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info - tags = data.pop("tags") - data["languages"] = list(data["languages"]) + tags = data.pop('tags') + data['languages'] = list(data['languages']) for k, v in data.items(): setattr(problem, k, v) @@ -420,13 +455,13 @@ def put(self, request): return self.success() def delete(self, request): - id = request.GET.get("id") + id = request.GET.get('id') if not id: - return self.error("Invalid parameter, id is required") + return self.error('Invalid parameter, id is required') try: problem = Problem.objects.get(id=id, contest_id__isnull=False) except Problem.DoesNotExist: - return self.error("Problem does not exists") + return self.error('Problem does not exists') ensure_created_by(problem.contest, request.user) if Submission.objects.filter(problem=problem).exists(): return self.error("Can't delete the problem as it has submissions") @@ -442,17 +477,17 @@ class MakeContestProblemPublicAPIView(APIView): @problem_permission_required def post(self, request): data = request.data - display_id = data.get("display_id") + display_id = data.get('display_id') if Problem.objects.filter(_id=display_id, contest_id__isnull=True).exists(): - return self.error("Duplicate display ID") + return self.error('Duplicate display ID') try: - problem = Problem.objects.get(id=data["id"]) + problem = Problem.objects.get(id=data['id']) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') if not problem.contest or problem.is_public: - return self.error("Already be a public problem") + return self.error('Already be a public problem') problem.is_public = True problem.save() # https://docs.djangoproject.com/en/1.11/topics/db/queries/#copying-model-instances @@ -473,22 +508,22 @@ class AddContestProblemAPI(APIView): def post(self, request): data = request.data try: - contest = Contest.objects.get(id=data["contest_id"]) - problem = Problem.objects.get(id=data["problem_id"]) + contest = Contest.objects.get(id=data['contest_id']) + problem = Problem.objects.get(id=data['problem_id']) except (Contest.DoesNotExist, Problem.DoesNotExist): - return self.error("Contest or Problem does not exist") + return self.error('Contest or Problem does not exist') if contest.status == ContestStatus.CONTEST_ENDED: - return self.error("Contest has ended") - if Problem.objects.filter(contest=contest, _id=data["display_id"]).exists(): - return self.error("Duplicate display id in this contest") + return self.error('Contest has ended') + if Problem.objects.filter(contest=contest, _id=data['display_id']).exists(): + return self.error('Duplicate display id in this contest') tags = problem.tags.all() problem.pk = None problem.contest = contest problem.is_public = True problem.visible = True - problem._id = request.data["display_id"] + problem._id = request.data['display_id'] problem.submission_number = problem.accepted_number = 0 problem.statistic_info = {} problem.save() @@ -500,49 +535,68 @@ class ExportProblemAPI(APIView): def choose_answers(self, user, problem): ret = [] for item in problem.languages: - submission = Submission.objects.filter(problem=problem, - user_id=user.id, - language=item, - result=JudgeStatus.ACCEPTED).order_by("-create_time").first() + submission = ( + Submission.objects.filter( + problem=problem, + user_id=user.id, + language=item, + result=JudgeStatus.ACCEPTED, + ) + .order_by('-create_time') + .first() + ) if submission: - ret.append({"language": submission.language, "code": submission.code}) + ret.append({'language': submission.language, 'code': submission.code}) return ret def process_one_problem(self, zip_file, user, problem, index): info = ExportProblemSerializer(problem).data - info["answers"] = self.choose_answers(user, problem=problem) + info['answers'] = self.choose_answers(user, problem=problem) compression = zipfile.ZIP_DEFLATED - zip_file.writestr(zinfo_or_arcname=f"{index}/problem.json", - data=json.dumps(info, indent=4), - compress_type=compression) - problem_test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) - with open(os.path.join(problem_test_case_dir, "info")) as f: + zip_file.writestr( + zinfo_or_arcname=f'{index}/problem.json', + data=json.dumps(info, indent=4), + compress_type=compression, + ) + problem_test_case_dir = os.path.join( + settings.TEST_CASE_DIR, problem.test_case_id + ) + with open(os.path.join(problem_test_case_dir, 'info')) as f: info = json.load(f) - for k, v in info["test_cases"].items(): - zip_file.write(filename=os.path.join(problem_test_case_dir, v["input_name"]), - arcname=f"{index}/testcase/{v['input_name']}", - compress_type=compression) - if not info["spj"]: - zip_file.write(filename=os.path.join(problem_test_case_dir, v["output_name"]), - arcname=f"{index}/testcase/{v['output_name']}", - compress_type=compression) + for k, v in info['test_cases'].items(): + zip_file.write( + filename=os.path.join(problem_test_case_dir, v['input_name']), + arcname=f"{index}/testcase/{v['input_name']}", + compress_type=compression, + ) + if not info['spj']: + zip_file.write( + filename=os.path.join(problem_test_case_dir, v['output_name']), + arcname=f"{index}/testcase/{v['output_name']}", + compress_type=compression, + ) @validate_serializer(ExportProblemRequestSerialzier) def get(self, request): - problems = Problem.objects.filter(id__in=request.data["problem_id"]) + problems = Problem.objects.filter(id__in=request.data['problem_id']) for problem in problems: if problem.contest: ensure_created_by(problem.contest, request.user) else: ensure_created_by(problem, request.user) - path = f"/tmp/{rand_str()}.zip" - with zipfile.ZipFile(path, "w") as zip_file: + path = f'/tmp/{rand_str()}.zip' + with zipfile.ZipFile(path, 'w') as zip_file: for index, problem in enumerate(problems): - self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1) + self.process_one_problem( + zip_file=zip_file, + user=request.user, + problem=problem, + index=index + 1, + ) delete_files.send_with_options(args=(path,), delay=300_000) - resp = FileResponse(open(path, "rb")) - resp["Content-Type"] = "application/zip" - resp["Content-Disposition"] = "attachment;filename=problem-export.zip" + resp = FileResponse(open(path, 'rb')) + resp['Content-Type'] = 'application/zip' + resp['Content-Disposition'] = 'attachment;filename=problem-export.zip' return resp @@ -552,128 +606,142 @@ class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor): def post(self, request): form = UploadProblemForm(request.POST, request.FILES) if form.is_valid(): - file = form.cleaned_data["file"] - tmp_file = f"/tmp/{rand_str()}.zip" - with open(tmp_file, "wb") as f: + file = form.cleaned_data['file'] + tmp_file = f'/tmp/{rand_str()}.zip' + with open(tmp_file, 'wb') as f: for chunk in file: f.write(chunk) else: - return self.error("Upload failed") + return self.error('Upload failed') count = 0 - with zipfile.ZipFile(tmp_file, "r") as zip_file: + with zipfile.ZipFile(tmp_file, 'r') as zip_file: name_list = zip_file.namelist() for item in name_list: - if "/problem.json" in item: + if '/problem.json' in item: count += 1 with transaction.atomic(): for i in range(1, count + 1): - with zip_file.open(f"{i}/problem.json") as f: + with zip_file.open(f'{i}/problem.json') as f: problem_info = json.load(f) serializer = ImportProblemSerializer(data=problem_info) if not serializer.is_valid(): - return self.error(f"Invalid problem format, error is {serializer.errors}") + return self.error( + f'Invalid problem format, error is {serializer.errors}' + ) else: problem_info = serializer.data - for item in problem_info["template"].keys(): + for item in problem_info['template'].keys(): if item not in SysOptions.language_names: - return self.error(f"Unsupported language {item}") + return self.error(f'Unsupported language {item}') - problem_info["display_id"] = problem_info["display_id"][:24] - for k, v in problem_info["template"].items(): - problem_info["template"][k] = build_problem_template(v["prepend"], v["template"], - v["append"]) + problem_info['display_id'] = problem_info['display_id'][:24] + for k, v in problem_info['template'].items(): + problem_info['template'][k] = build_problem_template( + v['prepend'], v['template'], v['append'] + ) - spj = problem_info["spj"] is not None - rule_type = problem_info["rule_type"] - test_case_score = problem_info["test_case_score"] + spj = problem_info['spj'] is not None + rule_type = problem_info['rule_type'] + test_case_score = problem_info['test_case_score'] # process test case - _, test_case_id = self.process_zip(tmp_file, spj=spj, dir=f"{i}/testcase/") - - problem_obj = Problem.objects.create(_id=problem_info["display_id"], - title=problem_info["title"], - description=problem_info["description"]["value"], - input_description=problem_info["input_description"][ - "value"], - output_description=problem_info["output_description"][ - "value"], - hint=problem_info["hint"]["value"], - test_case_score=test_case_score if test_case_score else [], - time_limit=problem_info["time_limit"], - memory_limit=problem_info["memory_limit"], - samples=problem_info["samples"], - template=problem_info["template"], - rule_type=problem_info["rule_type"], - source=problem_info["source"], - spj=spj, - spj_code=problem_info["spj"]["code"] if spj else None, - spj_language=problem_info["spj"][ - "language"] if spj else None, - spj_version=rand_str(8) if spj else "", - languages=SysOptions.language_names, - created_by=request.user, - visible=False, - difficulty=Difficulty.MID, - total_score=sum(item["score"] for item in test_case_score) - if rule_type == ProblemRuleType.OI else 0, - test_case_id=test_case_id - ) - for tag_name in problem_info["tags"]: + _, test_case_id = self.process_zip( + tmp_file, spj=spj, dir=f'{i}/testcase/' + ) + + problem_obj = Problem.objects.create( + _id=problem_info['display_id'], + title=problem_info['title'], + description=problem_info['description']['value'], + input_description=problem_info['input_description'][ + 'value' + ], + output_description=problem_info['output_description'][ + 'value' + ], + hint=problem_info['hint']['value'], + test_case_score=test_case_score if test_case_score else [], + time_limit=problem_info['time_limit'], + memory_limit=problem_info['memory_limit'], + samples=problem_info['samples'], + template=problem_info['template'], + rule_type=problem_info['rule_type'], + source=problem_info['source'], + spj=spj, + spj_code=problem_info['spj']['code'] if spj else None, + spj_language=problem_info['spj']['language'] + if spj + else None, + spj_version=rand_str(8) if spj else '', + languages=SysOptions.language_names, + created_by=request.user, + visible=False, + difficulty=Difficulty.MID, + total_score=sum(item['score'] for item in test_case_score) + if rule_type == ProblemRuleType.OI + else 0, + test_case_id=test_case_id, + ) + for tag_name in problem_info['tags']: tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name) problem_obj.tags.add(tag_obj) - return self.success({"import_count": count}) + return self.success({'import_count': count}) class FPSProblemImport(CSRFExemptAPIView): request_parsers = () def _create_problem(self, problem_data, creator): - if problem_data["time_limit"]["unit"] == "ms": - time_limit = problem_data["time_limit"]["value"] + if problem_data['time_limit']['unit'] == 'ms': + time_limit = problem_data['time_limit']['value'] else: - time_limit = problem_data["time_limit"]["value"] * 1000 + time_limit = problem_data['time_limit']['value'] * 1000 template = {} prepend = {} append = {} - for t in problem_data["prepend"]: - prepend[t["language"]] = t["code"] - for t in problem_data["append"]: - append[t["language"]] = t["code"] - for t in problem_data["template"]: - our_lang = lang = t["language"] - if lang == "Python": - our_lang = "Python3" - template[our_lang] = TEMPLATE_BASE.format(prepend.get(lang, ""), t["code"], append.get(lang, "")) - spj = problem_data["spj"] is not None - Problem.objects.create(_id=f"fps-{rand_str(4)}", - title=problem_data["title"], - description=problem_data["description"], - input_description=problem_data["input"], - output_description=problem_data["output"], - hint=problem_data["hint"], - test_case_score=problem_data["test_case_score"], - time_limit=time_limit, - memory_limit=problem_data["memory_limit"]["value"], - samples=problem_data["samples"], - template=template, - rule_type=ProblemRuleType.ACM, - source=problem_data.get("source", ""), - spj=spj, - spj_code=problem_data["spj"]["code"] if spj else None, - spj_language=problem_data["spj"]["language"] if spj else None, - spj_version=rand_str(8) if spj else "", - visible=False, - languages=SysOptions.language_names, - created_by=creator, - difficulty=Difficulty.MID, - test_case_id=problem_data["test_case_id"]) + for t in problem_data['prepend']: + prepend[t['language']] = t['code'] + for t in problem_data['append']: + append[t['language']] = t['code'] + for t in problem_data['template']: + our_lang = lang = t['language'] + if lang == 'Python': + our_lang = 'Python3' + template[our_lang] = TEMPLATE_BASE.format( + prepend.get(lang, ''), t['code'], append.get(lang, '') + ) + spj = problem_data['spj'] is not None + Problem.objects.create( + _id=f'fps-{rand_str(4)}', + title=problem_data['title'], + description=problem_data['description'], + input_description=problem_data['input'], + output_description=problem_data['output'], + hint=problem_data['hint'], + test_case_score=problem_data['test_case_score'], + time_limit=time_limit, + memory_limit=problem_data['memory_limit']['value'], + samples=problem_data['samples'], + template=template, + rule_type=ProblemRuleType.ACM, + source=problem_data.get('source', ''), + spj=spj, + spj_code=problem_data['spj']['code'] if spj else None, + spj_language=problem_data['spj']['language'] if spj else None, + spj_version=rand_str(8) if spj else '', + visible=False, + languages=SysOptions.language_names, + created_by=creator, + difficulty=Difficulty.MID, + test_case_id=problem_data['test_case_id'], + ) def post(self, request): form = UploadProblemForm(request.POST, request.FILES) if form.is_valid(): - file = form.cleaned_data["file"] - with tempfile.NamedTemporaryFile("wb") as tf: + file = form.cleaned_data['file'] + with tempfile.NamedTemporaryFile('wb') as tf: for chunk in file.chunks(4096): tf.file.write(chunk) @@ -682,7 +750,7 @@ def post(self, request): problems = FPSParser(tf.name).parse() else: - return self.error("Parse upload file error") + return self.error('Parse upload file error') helper = FPSHelper() with transaction.atomic(): @@ -691,15 +759,24 @@ def post(self, request): test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id) os.mkdir(test_case_dir) score = [] - for item in helper.save_test_case(_problem, test_case_dir)["test_cases"].values(): - score.append({"score": 0, "input_name": item["input_name"], - "output_name": item.get("output_name")}) - problem_data = helper.save_image(_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX) + for item in helper.save_test_case(_problem, test_case_dir)[ + 'test_cases' + ].values(): + score.append( + { + 'score': 0, + 'input_name': item['input_name'], + 'output_name': item.get('output_name'), + } + ) + problem_data = helper.save_image( + _problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX + ) s = FPSProblemSerializer(data=problem_data) if not s.is_valid(): - return self.error(f"Parse FPS file error: {s.errors}") + return self.error(f'Parse FPS file error: {s.errors}') problem_data = s.data - problem_data["test_case_id"] = test_case_id - problem_data["test_case_score"] = score + problem_data['test_case_id'] = test_case_id + problem_data['test_case_score'] = score self._create_problem(problem_data, request.user) - return self.success({"import_count": len(problems)}) + return self.success({'import_count': len(problems)}) diff --git a/problem/views/oj.py b/problem/views/oj.py index 534fa6059..0ecd50dbf 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -10,10 +10,10 @@ class ProblemTagAPI(APIView): def get(self, request): qs = ProblemTag.objects - keyword = request.GET.get("keyword") + keyword = request.GET.get('keyword') if keyword: qs = ProblemTag.objects.filter(name__icontains=keyword) - tags = qs.annotate(problem_count=Count("problem")).filter(problem_count__gt=0) + tags = qs.annotate(problem_count=Count('problem')).filter(problem_count__gt=0) return self.success(TagSerializer(tags, many=True).data) @@ -22,7 +22,7 @@ def get(self, request): problems = Problem.objects.filter(contest_id__isnull=True, visible=True) count = problems.count() if count == 0: - return self.error("No problem to pick") + return self.error('No problem to pick') return self.success(problems[random.randint(0, count - 1)]._id) @@ -31,50 +31,61 @@ class ProblemAPI(APIView): def _add_problem_status(request, queryset_values): if request.user.is_authenticated: profile = request.user.userprofile - acm_problems_status = profile.acm_problems_status.get("problems", {}) - oi_problems_status = profile.oi_problems_status.get("problems", {}) + acm_problems_status = profile.acm_problems_status.get('problems', {}) + oi_problems_status = profile.oi_problems_status.get('problems', {}) # paginate data - results = queryset_values.get("results") + results = queryset_values.get('results') if results is not None: problems = results else: - problems = [queryset_values, ] + problems = [ + queryset_values, + ] for problem in problems: - if problem["rule_type"] == ProblemRuleType.ACM: - problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status") + if problem['rule_type'] == ProblemRuleType.ACM: + problem['my_status'] = acm_problems_status.get( + str(problem['id']), {} + ).get('status') else: - problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status") + problem['my_status'] = oi_problems_status.get( + str(problem['id']), {} + ).get('status') def get(self, request): # 问题详情页 - problem_id = request.GET.get("problem_id") + problem_id = request.GET.get('problem_id') if problem_id: try: - problem = Problem.objects.select_related("created_by") \ - .get(_id=problem_id, contest_id__isnull=True, visible=True) + problem = Problem.objects.select_related('created_by').get( + _id=problem_id, contest_id__isnull=True, visible=True + ) problem_data = ProblemSerializer(problem).data self._add_problem_status(request, problem_data) return self.success(problem_data) except Problem.DoesNotExist: - return self.error("Problem does not exist") + return self.error('Problem does not exist') - limit = request.GET.get("limit") + limit = request.GET.get('limit') if not limit: - return self.error("Limit is needed") + return self.error('Limit is needed') - problems = Problem.objects.select_related("created_by").filter(contest_id__isnull=True, visible=True) + problems = Problem.objects.select_related('created_by').filter( + contest_id__isnull=True, visible=True + ) # 按照标签筛选 - tag_text = request.GET.get("tag") + tag_text = request.GET.get('tag') if tag_text: problems = problems.filter(tags__name=tag_text) # 搜索的情况 - keyword = request.GET.get("keyword", "").strip() + keyword = request.GET.get('keyword', '').strip() if keyword: - problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword)) + problems = problems.filter( + Q(title__icontains=keyword) | Q(_id__icontains=keyword) + ) # 难度筛选 - difficulty = request.GET.get("difficulty") + difficulty = request.GET.get('difficulty') if difficulty: problems = problems.filter(difficulty=difficulty) # 根据profile 为做过的题目添加标记 @@ -88,30 +99,41 @@ def _add_problem_status(self, request, queryset_values): if request.user.is_authenticated: profile = request.user.userprofile if self.contest.rule_type == ContestRuleType.ACM: - problems_status = profile.acm_problems_status.get("contest_problems", {}) + problems_status = profile.acm_problems_status.get( + 'contest_problems', {} + ) else: - problems_status = profile.oi_problems_status.get("contest_problems", {}) + problems_status = profile.oi_problems_status.get('contest_problems', {}) for problem in queryset_values: - problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status") + problem['my_status'] = problems_status.get(str(problem['id']), {}).get( + 'status' + ) - @check_contest_permission(check_type="problems") + @check_contest_permission(check_type='problems') def get(self, request): - problem_id = request.GET.get("problem_id") + problem_id = request.GET.get('problem_id') if problem_id: try: - problem = Problem.objects.select_related("created_by").get(_id=problem_id, - contest=self.contest, - visible=True) + problem = Problem.objects.select_related('created_by').get( + _id=problem_id, contest=self.contest, visible=True + ) except Problem.DoesNotExist: - return self.error("Problem does not exist.") + return self.error('Problem does not exist.') if self.contest.problem_details_permission(request.user): problem_data = ProblemSerializer(problem).data - self._add_problem_status(request, [problem_data, ]) + self._add_problem_status( + request, + [ + problem_data, + ], + ) else: problem_data = ProblemSafeSerializer(problem).data return self.success(problem_data) - contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True) + contest_problems = Problem.objects.select_related('created_by').filter( + contest=self.contest, visible=True + ) if self.contest.problem_details_permission(request.user): data = ProblemSerializer(contest_problems, many=True).data self._add_problem_status(request, data) diff --git a/run_test.py b/run_test.py index a6c714f6e..627781742 100644 --- a/run_test.py +++ b/run_test.py @@ -2,26 +2,30 @@ import os import sys -opts, args = getopt.getopt(sys.argv[1:], "cm:", ["coverage=", "module="]) +opts, args = getopt.getopt(sys.argv[1:], 'cm:', ['coverage=', 'module=']) is_coverage = False -test_module = "" -setting = "oj.settings" +test_module = '' +setting = 'oj.settings' for opt, arg in opts: - if opt in ["-c", "--coverage"]: + if opt in ['-c', '--coverage']: is_coverage = True - if opt in ["-m", "--module"]: + if opt in ['-m', '--module']: test_module = arg -print("Coverage: {cov}".format(cov=is_coverage)) -print("Module: {mod}".format(mod=(test_module if test_module else "All"))) +print('Coverage: {cov}'.format(cov=is_coverage)) +print('Module: {mod}'.format(mod=(test_module if test_module else 'All'))) -print("running flake8...") -if os.system("flake8 --statistics ."): +print('running flake8...') +if os.system('flake8 --statistics .'): exit() -ret = os.system('coverage run --include="$PWD/*" manage.py test {module} --settings={setting}'.format(module=test_module, setting=setting)) +ret = os.system( + 'coverage run --include="$PWD/*" manage.py test {module} --settings={setting}'.format( + module=test_module, setting=setting + ) +) if not ret and is_coverage: - os.system("coverage html && open htmlcov/index.html") + os.system('coverage html && open htmlcov/index.html') diff --git a/submission/migrations/0001_initial.py b/submission/migrations/0001_initial.py index 42a5352f2..63ceb7c34 100644 --- a/submission/migrations/0001_initial.py +++ b/submission/migrations/0001_initial.py @@ -9,17 +9,24 @@ class Migration(migrations.Migration): - initial = True - dependencies = [ - ] + dependencies = [] operations = [ migrations.CreateModel( name='Submission', fields=[ - ('id', models.CharField(db_index=True, default=utils.shortcuts.rand_str, max_length=32, primary_key=True, serialize=False)), + ( + 'id', + models.CharField( + db_index=True, + default=utils.shortcuts.rand_str, + max_length=32, + primary_key=True, + serialize=False, + ), + ), ('contest_id', models.IntegerField(db_index=True, null=True)), ('problem_id', models.IntegerField(db_index=True)), ('created_time', models.DateTimeField(auto_now_add=True)), diff --git a/submission/migrations/0002_auto_20170509_1203.py b/submission/migrations/0002_auto_20170509_1203.py index 78dcbe9bc..c5642f19d 100644 --- a/submission/migrations/0002_auto_20170509_1203.py +++ b/submission/migrations/0002_auto_20170509_1203.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0001_initial'), ] @@ -34,5 +33,5 @@ class Migration(migrations.Migration): migrations.AlterModelOptions( name='submission', options={'ordering': ('-create_time',)}, - ) + ), ] diff --git a/submission/migrations/0005_submission_username.py b/submission/migrations/0005_submission_username.py index 68a324357..24c341a5b 100644 --- a/submission/migrations/0005_submission_username.py +++ b/submission/migrations/0005_submission_username.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0002_auto_20170509_1203'), ] @@ -15,7 +14,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name='submission', name='username', - field=models.CharField(default="", max_length=30), + field=models.CharField(default='', max_length=30), preserve_default=False, ), ] diff --git a/submission/migrations/0006_auto_20170830_1154.py b/submission/migrations/0006_auto_20170830_1154.py index 675cc8659..b492924f8 100644 --- a/submission/migrations/0006_auto_20170830_1154.py +++ b/submission/migrations/0006_auto_20170830_1154.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0005_submission_username'), ] diff --git a/submission/migrations/0007_auto_20170923_1318.py b/submission/migrations/0007_auto_20170923_1318.py index 635668074..6b8059d52 100644 --- a/submission/migrations/0007_auto_20170923_1318.py +++ b/submission/migrations/0007_auto_20170923_1318.py @@ -8,7 +8,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0006_auto_20170830_1154'), ] @@ -17,12 +16,18 @@ class Migration(migrations.Migration): migrations.AlterField( model_name='submission', name='contest_id', - field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, to='contest.Contest'), + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + to='contest.Contest', + ), ), migrations.AlterField( model_name='submission', name='problem_id', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problem.Problem'), + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to='problem.Problem' + ), ), migrations.RenameField( model_name='submission', diff --git a/submission/migrations/0008_submission_ip.py b/submission/migrations/0008_submission_ip.py index e60841bdc..44f2de468 100644 --- a/submission/migrations/0008_submission_ip.py +++ b/submission/migrations/0008_submission_ip.py @@ -6,7 +6,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0007_auto_20170923_1318'), ] diff --git a/submission/migrations/0009_delete_user_output.py b/submission/migrations/0009_delete_user_output.py index 66c90c41c..26d1843c8 100644 --- a/submission/migrations/0009_delete_user_output.py +++ b/submission/migrations/0009_delete_user_output.py @@ -5,16 +5,15 @@ def delete_user_output(apps, schema_editor): - Submission = apps.get_model("submission", "Submission") + Submission = apps.get_model('submission', 'Submission') for item in Submission.objects.all(): - if "data" in item.info and isinstance(item.info["data"], list): - for index in range(len(item.info["data"])): - item.info["data"][index]["output"] = "" + if 'data' in item.info and isinstance(item.info['data'], list): + for index in range(len(item.info['data'])): + item.info['data'][index]['output'] = '' item.save() class Migration(migrations.Migration): - dependencies = [ ('submission', '0008_submission_ip'), ] diff --git a/submission/migrations/0011_fix_submission_number.py b/submission/migrations/0011_fix_submission_number.py index f3d5e6b76..3879f0624 100644 --- a/submission/migrations/0011_fix_submission_number.py +++ b/submission/migrations/0011_fix_submission_number.py @@ -5,15 +5,15 @@ def fix_rejudge_bugs(apps, schema_editor): - Submission = apps.get_model("submission", "Submission") - User = apps.get_model("account", "User") + Submission = apps.get_model('submission', 'Submission') + User = apps.get_model('account', 'User') for user in User.objects.all(): submissions = Submission.objects.filter(user_id=user.id, contest__isnull=True) profile = user.userprofile profile.submission_number = submissions.count() profile.accepted_number = submissions.filter(result=0).count() - profile.save(update_fields=["submission_number", "accepted_number"]) + profile.save(update_fields=['submission_number', 'accepted_number']) class Migration(migrations.Migration): diff --git a/submission/migrations/0012_auto_20180501_0436.py b/submission/migrations/0012_auto_20180501_0436.py index b81087298..78ffbb720 100644 --- a/submission/migrations/0012_auto_20180501_0436.py +++ b/submission/migrations/0012_auto_20180501_0436.py @@ -7,7 +7,6 @@ class Migration(migrations.Migration): - dependencies = [ ('submission', '0011_fix_submission_number'), ] @@ -16,7 +15,12 @@ class Migration(migrations.Migration): migrations.AlterField( model_name='submission', name='id', - field=models.TextField(db_index=True, default=utils.shortcuts.rand_str, primary_key=True, serialize=False), + field=models.TextField( + db_index=True, + default=utils.shortcuts.rand_str, + primary_key=True, + serialize=False, + ), ), migrations.AlterField( model_name='submission', diff --git a/submission/models.py b/submission/models.py index 2918f8808..857eac677 100644 --- a/submission/models.py +++ b/submission/models.py @@ -41,7 +41,12 @@ class Submission(models.Model): ip = models.TextField(null=True) def check_user_permission(self, user, check_share=True): - if self.user_id == user.id or user.is_super_admin() or user.can_mgmt_all_problem() or self.problem.created_by_id == user.id: + if ( + self.user_id == user.id + or user.is_super_admin() + or user.can_mgmt_all_problem() + or self.problem.created_by_id == user.id + ): return True if check_share: @@ -52,8 +57,8 @@ def check_user_permission(self, user, check_share=True): return False class Meta: - db_table = "submission" - ordering = ("-create_time",) + db_table = 'submission' + ordering = ('-create_time',) def __str__(self): return self.id diff --git a/submission/serializers.py b/submission/serializers.py index 5e48f3e76..47a00816b 100644 --- a/submission/serializers.py +++ b/submission/serializers.py @@ -17,32 +17,31 @@ class ShareSubmissionSerializer(serializers.Serializer): class SubmissionModelSerializer(serializers.ModelSerializer): - class Meta: model = Submission - fields = "__all__" + fields = '__all__' # 不显示submission info的serializer, 用于ACM rule_type class SubmissionSafeModelSerializer(serializers.ModelSerializer): - problem = serializers.SlugRelatedField(read_only=True, slug_field="_id") + problem = serializers.SlugRelatedField(read_only=True, slug_field='_id') class Meta: model = Submission - exclude = ("info", "contest", "ip") + exclude = ('info', 'contest', 'ip') class SubmissionListSerializer(serializers.ModelSerializer): - problem = serializers.SlugRelatedField(read_only=True, slug_field="_id") + problem = serializers.SlugRelatedField(read_only=True, slug_field='_id') show_link = serializers.SerializerMethodField() def __init__(self, *args, **kwargs): - self.user = kwargs.pop("user", None) + self.user = kwargs.pop('user', None) super().__init__(*args, **kwargs) class Meta: model = Submission - exclude = ("info", "contest", "code", "ip") + exclude = ('info', 'contest', 'code', 'ip') def get_show_link(self, obj): # 没传user或为匿名user diff --git a/submission/tests.py b/submission/tests.py index 08ccbd46a..292c25257 100644 --- a/submission/tests.py +++ b/submission/tests.py @@ -5,25 +5,48 @@ from utils.api.tests import APITestCase from .models import Submission -DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "

test

", "input_description": "test", - "output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low", - "visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {}, - "samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C", - "spj_code": "", "test_case_id": "499b26290cc7994e0b497212e842ea85", - "test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0, - "stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e", - "input_size": 0, "score": 0}], - "rule_type": "ACM", "hint": "

test

", "source": "test"} +DEFAULT_PROBLEM_DATA = { + '_id': 'A-110', + 'title': 'test', + 'description': '

test

', + 'input_description': 'test', + 'output_description': 'test', + 'time_limit': 1000, + 'memory_limit': 256, + 'difficulty': 'Low', + 'visible': True, + 'tags': ['test'], + 'languages': ['C', 'C++', 'Java', 'Python2'], + 'template': {}, + 'samples': [{'input': 'test', 'output': 'test'}], + 'spj': False, + 'spj_language': 'C', + 'spj_code': '', + 'test_case_id': '499b26290cc7994e0b497212e842ea85', + 'test_case_score': [ + { + 'output_name': '1.out', + 'input_name': '1.in', + 'output_size': 0, + 'stripped_output_md5': 'd41d8cd98f00b204e9800998ecf8427e', + 'input_size': 0, + 'score': 0, + } + ], + 'rule_type': 'ACM', + 'hint': '

test

', + 'source': 'test', +} DEFAULT_SUBMISSION_DATA = { - "problem_id": "1", - "user_id": 1, - "username": "test", - "code": "xxxxxxxxxxxxxx", - "result": -2, - "info": {}, - "language": "C", - "statistic_info": {} + 'problem_id': '1', + 'user_id': 1, + 'username': 'test', + 'code': 'xxxxxxxxxxxxxx', + 'result': -2, + 'info': {}, + 'language': 'C', + 'statistic_info': {}, } @@ -32,37 +55,37 @@ class SubmissionPrepare(APITestCase): def _create_problem_and_submission(self): - user = self.create_admin("test", "test123", login=False) + user = self.create_admin('test', 'test123', login=False) problem_data = deepcopy(DEFAULT_PROBLEM_DATA) - tags = problem_data.pop("tags") - problem_data["created_by"] = user + tags = problem_data.pop('tags') + problem_data['created_by'] = user self.problem = Problem.objects.create(**problem_data) for tag in tags: tag = ProblemTag.objects.create(name=tag) self.problem.tags.add(tag) self.problem.save() self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA) - self.submission_data["problem_id"] = self.problem.id + self.submission_data['problem_id'] = self.problem.id self.submission = Submission.objects.create(**self.submission_data) class SubmissionListTest(SubmissionPrepare): def setUp(self): self._create_problem_and_submission() - self.create_user("123", "345") - self.url = self.reverse("submission_list_api") + self.create_user('123', '345') + self.url = self.reverse('submission_list_api') def test_get_submission_list(self): - resp = self.client.get(self.url, data={"limit": "10"}) + resp = self.client.get(self.url, data={'limit': '10'}) self.assertSuccess(resp) -@mock.patch("submission.views.oj.judge_task.send") +@mock.patch('submission.views.oj.judge_task.send') class SubmissionAPITest(SubmissionPrepare): def setUp(self): self._create_problem_and_submission() - self.user = self.create_user("123", "test123") - self.url = self.reverse("submission_api") + self.user = self.create_user('123', 'test123') + self.url = self.reverse('submission_api') def test_create_submission(self, judge_task): resp = self.client.post(self.url, self.submission_data) @@ -70,9 +93,11 @@ def test_create_submission(self, judge_task): judge_task.assert_called() def test_create_submission_with_wrong_language(self, judge_task): - self.submission_data.update({"language": "Python3"}) + self.submission_data.update({'language': 'Python3'}) resp = self.client.post(self.url, self.submission_data) self.assertFailed(resp) - self.assertDictEqual(resp.data, {"error": "error", - "data": "Python3 is now allowed in the problem"}) + self.assertDictEqual( + resp.data, + {'error': 'error', 'data': 'Python3 is now allowed in the problem'}, + ) judge_task.assert_not_called() diff --git a/submission/urls/admin.py b/submission/urls/admin.py index bf86022f6..55de961ed 100644 --- a/submission/urls/admin.py +++ b/submission/urls/admin.py @@ -3,5 +3,9 @@ from ..views.admin import SubmissionRejudgeAPI urlpatterns = [ - url(r"^submission/rejudge?$", SubmissionRejudgeAPI.as_view(), name="submission_rejudge_api"), + url( + r'^submission/rejudge?$', + SubmissionRejudgeAPI.as_view(), + name='submission_rejudge_api', + ), ] diff --git a/submission/urls/oj.py b/submission/urls/oj.py index 49116b9ff..a975c8527 100644 --- a/submission/urls/oj.py +++ b/submission/urls/oj.py @@ -1,10 +1,23 @@ from django.conf.urls import url -from ..views.oj import SubmissionAPI, SubmissionListAPI, ContestSubmissionListAPI, SubmissionExistsAPI +from ..views.oj import ( + SubmissionAPI, + SubmissionListAPI, + ContestSubmissionListAPI, + SubmissionExistsAPI, +) urlpatterns = [ - url(r"^submission/?$", SubmissionAPI.as_view(), name="submission_api"), - url(r"^submissions/?$", SubmissionListAPI.as_view(), name="submission_list_api"), - url(r"^submission_exists/?$", SubmissionExistsAPI.as_view(), name="submission_exists"), - url(r"^contest_submissions/?$", ContestSubmissionListAPI.as_view(), name="contest_submission_list_api"), + url(r'^submission/?$', SubmissionAPI.as_view(), name='submission_api'), + url(r'^submissions/?$', SubmissionListAPI.as_view(), name='submission_list_api'), + url( + r'^submission_exists/?$', + SubmissionExistsAPI.as_view(), + name='submission_exists', + ), + url( + r'^contest_submissions/?$', + ContestSubmissionListAPI.as_view(), + name='contest_submission_list_api', + ), ] diff --git a/submission/views/admin.py b/submission/views/admin.py index 879725614..2e78c417f 100644 --- a/submission/views/admin.py +++ b/submission/views/admin.py @@ -1,5 +1,6 @@ from account.decorators import super_admin_required from judge.tasks import judge_task + # from judge.dispatcher import JudgeDispatcher from utils.api import APIView from ..models import Submission @@ -8,13 +9,15 @@ class SubmissionRejudgeAPI(APIView): @super_admin_required def get(self, request): - id = request.GET.get("id") + id = request.GET.get('id') if not id: - return self.error("Parameter error, id is required") + return self.error('Parameter error, id is required') try: - submission = Submission.objects.select_related("problem").get(id=id, contest_id__isnull=True) + submission = Submission.objects.select_related('problem').get( + id=id, contest_id__isnull=True + ) except Submission.DoesNotExist: - return self.error("Submission does not exists") + return self.error('Submission does not exists') submission.statistic_info = {} submission.save() diff --git a/submission/views/oj.py b/submission/views/oj.py index 0f7c4a584..63ea8e6f0 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -4,6 +4,7 @@ from contest.models import ContestStatus, ContestRuleType from judge.tasks import judge_task from options.options import SysOptions + # from judge.dispatcher import JudgeDispatcher from problem.models import Problem, ProblemRuleType from utils.api import APIView, validate_serializer @@ -11,22 +12,26 @@ from utils.captcha import Captcha from utils.throttling import TokenBucket from ..models import Submission -from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer, - ShareSubmissionSerializer) +from ..serializers import ( + CreateSubmissionSerializer, + SubmissionModelSerializer, + ShareSubmissionSerializer, +) from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer class SubmissionAPI(APIView): def throttling(self, request): # 使用 open_api 的请求暂不做限制 - auth_method = getattr(request, "auth_method", "") - if auth_method == "api_key": + auth_method = getattr(request, 'auth_method', '') + if auth_method == 'api_key': return - user_bucket = TokenBucket(key=str(request.user.id), - redis_conn=cache, **SysOptions.throttling["user"]) + user_bucket = TokenBucket( + key=str(request.user.id), redis_conn=cache, **SysOptions.throttling['user'] + ) can_consume, wait = user_bucket.consume() if not can_consume: - return "Please wait %d seconds" % (int(wait)) + return 'Please wait %d seconds' % (int(wait)) # ip_bucket = TokenBucket(key=request.session["ip"], # redis_conn=cache, **SysOptions.throttling["ip"]) @@ -34,23 +39,26 @@ def throttling(self, request): # if not can_consume: # return "Captcha is required" - @check_contest_permission(check_type="problems") + @check_contest_permission(check_type='problems') def check_contest_permission(self, request): contest = self.contest if contest.status == ContestStatus.CONTEST_ENDED: - return self.error("The contest have ended") + return self.error('The contest have ended') if not request.user.is_contest_admin(contest): - user_ip = ipaddress.ip_address(request.session.get("ip")) + user_ip = ipaddress.ip_address(request.session.get('ip')) if contest.allowed_ip_ranges: - if not any(user_ip in ipaddress.ip_network(cidr, strict=False) for cidr in contest.allowed_ip_ranges): - return self.error("Your IP is not allowed in this contest") + if not any( + user_ip in ipaddress.ip_network(cidr, strict=False) + for cidr in contest.allowed_ip_ranges + ): + return self.error('Your IP is not allowed in this contest') @validate_serializer(CreateSubmissionSerializer) @login_required def post(self, request): data = request.data hide_id = False - if data.get("contest_id"): + if data.get('contest_id'): error = self.check_contest_permission(request) if error: return error @@ -58,52 +66,63 @@ def post(self, request): if not contest.problem_details_permission(request.user): hide_id = True - if data.get("captcha"): - if not Captcha(request).check(data["captcha"]): - return self.error("Invalid captcha") + if data.get('captcha'): + if not Captcha(request).check(data['captcha']): + return self.error('Invalid captcha') error = self.throttling(request) if error: return self.error(error) try: - problem = Problem.objects.get(id=data["problem_id"], contest_id=data.get("contest_id"), visible=True) + problem = Problem.objects.get( + id=data['problem_id'], contest_id=data.get('contest_id'), visible=True + ) except Problem.DoesNotExist: - return self.error("Problem not exist") - if data["language"] not in problem.languages: + return self.error('Problem not exist') + if data['language'] not in problem.languages: return self.error(f"{data['language']} is now allowed in the problem") - submission = Submission.objects.create(user_id=request.user.id, - username=request.user.username, - language=data["language"], - code=data["code"], - problem_id=problem.id, - ip=request.session["ip"], - contest_id=data.get("contest_id")) + submission = Submission.objects.create( + user_id=request.user.id, + username=request.user.username, + language=data['language'], + code=data['code'], + problem_id=problem.id, + ip=request.session['ip'], + contest_id=data.get('contest_id'), + ) # use this for debug # JudgeDispatcher(submission.id, problem.id).judge() judge_task.send(submission.id, problem.id) if hide_id: return self.success() else: - return self.success({"submission_id": submission.id}) + return self.success({'submission_id': submission.id}) @login_required def get(self, request): - submission_id = request.GET.get("id") + submission_id = request.GET.get('id') if not submission_id: return self.error("Parameter id doesn't exist") try: - submission = Submission.objects.select_related("problem").get(id=submission_id) + submission = Submission.objects.select_related('problem').get( + id=submission_id + ) except Submission.DoesNotExist: return self.error("Submission doesn't exist") if not submission.check_user_permission(request.user): - return self.error("No permission for this submission") + return self.error('No permission for this submission') - if submission.problem.rule_type == ProblemRuleType.OI or request.user.is_admin_role(): + if ( + submission.problem.rule_type == ProblemRuleType.OI + or request.user.is_admin_role() + ): submission_data = SubmissionModelSerializer(submission).data else: submission_data = SubmissionSafeModelSerializer(submission).data # 是否有权限取消共享 - submission_data["can_unshare"] = submission.check_user_permission(request.user, check_share=False) + submission_data['can_unshare'] = submission.check_user_permission( + request.user, check_share=False + ) return self.success(submission_data) @validate_serializer(ShareSubmissionSerializer) @@ -113,67 +132,82 @@ def put(self, request): share submission """ try: - submission = Submission.objects.select_related("problem").get(id=request.data["id"]) + submission = Submission.objects.select_related('problem').get( + id=request.data['id'] + ) except Submission.DoesNotExist: return self.error("Submission doesn't exist") if not submission.check_user_permission(request.user, check_share=False): - return self.error("No permission to share the submission") - if submission.contest and submission.contest.status == ContestStatus.CONTEST_UNDERWAY: - return self.error("Can not share submission now") - submission.shared = request.data["shared"] - submission.save(update_fields=["shared"]) + return self.error('No permission to share the submission') + if ( + submission.contest + and submission.contest.status == ContestStatus.CONTEST_UNDERWAY + ): + return self.error('Can not share submission now') + submission.shared = request.data['shared'] + submission.save(update_fields=['shared']) return self.success() class SubmissionListAPI(APIView): def get(self, request): - if not request.GET.get("limit"): - return self.error("Limit is needed") - if request.GET.get("contest_id"): - return self.error("Parameter error") - - submissions = Submission.objects.filter(contest_id__isnull=True).select_related("problem__created_by") - problem_id = request.GET.get("problem_id") - myself = request.GET.get("myself") - result = request.GET.get("result") - username = request.GET.get("username") + if not request.GET.get('limit'): + return self.error('Limit is needed') + if request.GET.get('contest_id'): + return self.error('Parameter error') + + submissions = Submission.objects.filter(contest_id__isnull=True).select_related( + 'problem__created_by' + ) + problem_id = request.GET.get('problem_id') + myself = request.GET.get('myself') + result = request.GET.get('result') + username = request.GET.get('username') if problem_id: try: - problem = Problem.objects.get(_id=problem_id, contest_id__isnull=True, visible=True) + problem = Problem.objects.get( + _id=problem_id, contest_id__isnull=True, visible=True + ) except Problem.DoesNotExist: return self.error("Problem doesn't exist") submissions = submissions.filter(problem=problem) - if (myself and myself == "1") or not SysOptions.submission_list_show_all: + if (myself and myself == '1') or not SysOptions.submission_list_show_all: submissions = submissions.filter(user_id=request.user.id) elif username: submissions = submissions.filter(username__icontains=username) if result: submissions = submissions.filter(result=result) data = self.paginate_data(request, submissions) - data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data + data['results'] = SubmissionListSerializer( + data['results'], many=True, user=request.user + ).data return self.success(data) class ContestSubmissionListAPI(APIView): - @check_contest_permission(check_type="submissions") + @check_contest_permission(check_type='submissions') def get(self, request): - if not request.GET.get("limit"): - return self.error("Limit is needed") + if not request.GET.get('limit'): + return self.error('Limit is needed') contest = self.contest - submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by") - problem_id = request.GET.get("problem_id") - myself = request.GET.get("myself") - result = request.GET.get("result") - username = request.GET.get("username") + submissions = Submission.objects.filter(contest_id=contest.id).select_related( + 'problem__created_by' + ) + problem_id = request.GET.get('problem_id') + myself = request.GET.get('myself') + result = request.GET.get('result') + username = request.GET.get('username') if problem_id: try: - problem = Problem.objects.get(_id=problem_id, contest_id=contest.id, visible=True) + problem = Problem.objects.get( + _id=problem_id, contest_id=contest.id, visible=True + ) except Problem.DoesNotExist: return self.error("Problem doesn't exist") submissions = submissions.filter(problem=problem) - if myself and myself == "1": + if myself and myself == '1': submissions = submissions.filter(user_id=request.user.id) elif username: submissions = submissions.filter(username__icontains=username) @@ -186,18 +220,25 @@ def get(self, request): # 封榜的时候只能看到自己的提交 if contest.rule_type == ContestRuleType.ACM: - if not contest.real_time_rank and not request.user.is_contest_admin(contest): + if not contest.real_time_rank and not request.user.is_contest_admin( + contest + ): submissions = submissions.filter(user_id=request.user.id) data = self.paginate_data(request, submissions) - data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data + data['results'] = SubmissionListSerializer( + data['results'], many=True, user=request.user + ).data return self.success(data) class SubmissionExistsAPI(APIView): def get(self, request): - if not request.GET.get("problem_id"): - return self.error("Parameter error, problem_id is required") - return self.success(request.user.is_authenticated and - Submission.objects.filter(problem_id=request.GET["problem_id"], - user_id=request.user.id).exists()) + if not request.GET.get('problem_id'): + return self.error('Parameter error, problem_id is required') + return self.success( + request.user.is_authenticated + and Submission.objects.filter( + problem_id=request.GET['problem_id'], user_id=request.user.id + ).exists() + ) diff --git a/utils/api/_serializers.py b/utils/api/_serializers.py index 745c69601..8827c6b1c 100644 --- a/utils/api/_serializers.py +++ b/utils/api/_serializers.py @@ -7,7 +7,7 @@ class UsernameSerializer(serializers.Serializer): real_name = serializers.SerializerMethodField() def __init__(self, *args, **kwargs): - self.need_real_name = kwargs.pop("need_real_name", False) + self.need_real_name = kwargs.pop('need_real_name', False) super().__init__(*args, **kwargs) def get_real_name(self, obj): diff --git a/utils/api/api.py b/utils/api/api.py index a603bfb97..fb192395f 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -7,7 +7,7 @@ from django.views.decorators.csrf import csrf_exempt from django.views.generic import View -logger = logging.getLogger("") +logger = logging.getLogger('') class APIError(Exception): @@ -18,10 +18,10 @@ def __init__(self, msg, err=None): class ContentType(object): - json_request = "application/json" - json_response = "application/json;charset=UTF-8" - url_encoded_request = "application/x-www-form-urlencoded" - binary_response = "application/octet-stream" + json_request = 'application/json' + json_response = 'application/json;charset=UTF-8' + url_encoded_request = 'application/x-www-form-urlencoded' + binary_response = 'application/octet-stream' class JSONParser(object): @@ -29,7 +29,7 @@ class JSONParser(object): @staticmethod def parse(body): - return json.loads(body.decode("utf-8")) + return json.loads(body.decode('utf-8')) class URLEncodedParser(object): @@ -59,15 +59,16 @@ class APIView(View): - self.response 返回一个django HttpResponse, 具体在self.response_class中实现 - parse请求的类需要定义在request_parser中, 目前只支持json和urlencoded的类型, 用来解析请求的数据 """ + request_parsers = (JSONParser, URLEncodedParser) response_class = JSONResponse def _get_request_data(self, request): - if request.method not in ["GET", "DELETE"]: + if request.method not in ['GET', 'DELETE']: body = request.body - content_type = request.META.get("CONTENT_TYPE") + content_type = request.META.get('CONTENT_TYPE') if not content_type: - raise ValueError("content_type is required") + raise ValueError('content_type is required') for parser in self.request_parsers: if content_type.startswith(parser.content_type): break @@ -83,15 +84,15 @@ def response(self, data): return self.response_class.response(data) def success(self, data=None): - return self.response({"error": None, "data": data}) + return self.response({'error': None, 'data': data}) - def error(self, msg="error", err="error"): - return self.response({"error": err, "data": msg}) + def error(self, msg='error', err='error'): + return self.response({'error': err, 'data': msg}) - def extract_errors(self, errors, key="field"): + def extract_errors(self, errors, key='field'): if isinstance(errors, dict): if not errors: - return key, "Invalid field" + return key, 'Invalid field' key = list(errors.keys())[0] return self.extract_errors(errors.pop(key), key) elif isinstance(errors, list): @@ -101,14 +102,14 @@ def extract_errors(self, errors, key="field"): def invalid_serializer(self, serializer): key, error = self.extract_errors(serializer.errors) - if key == "non_field_errors": + if key == 'non_field_errors': msg = error else: - msg = f"{key}: {error}" - return self.error(err=f"invalid-{key}", msg=msg) + msg = f'{key}: {error}' + return self.error(err=f'invalid-{key}', msg=msg) def server_error(self): - return self.error(err="server-error", msg="server error") + return self.error(err='server-error', msg='server error') def paginate_data(self, request, query_set, object_serializer=None): """ @@ -118,25 +119,24 @@ def paginate_data(self, request, query_set, object_serializer=None): :return: """ try: - limit = int(request.GET.get("limit", "10")) + limit = int(request.GET.get('limit', '10')) except ValueError: limit = 10 if limit < 0 or limit > 250: limit = 10 try: - offset = int(request.GET.get("offset", "0")) + offset = int(request.GET.get('offset', '0')) except ValueError: offset = 0 if offset < 0: offset = 0 - results = query_set[offset:offset + limit] + results = query_set[offset : offset + limit] if object_serializer: count = query_set.count() results = object_serializer(results, many=True).data else: count = query_set.count() - data = {"results": results, - "total": count} + data = {'results': results, 'total': count} return data def dispatch(self, request, *args, **kwargs): @@ -144,13 +144,13 @@ def dispatch(self, request, *args, **kwargs): try: request.data = self._get_request_data(self.request) except ValueError as e: - return self.error(err="invalid-request", msg=str(e)) + return self.error(err='invalid-request', msg=str(e)) try: return super(APIView, self).dispatch(request, *args, **kwargs) except APIError as e: - ret = {"msg": e.msg} + ret = {'msg': e.msg} if e.err: - ret["err"] = e.err + ret['err'] = e.err return self.error(**ret) except Exception as e: logger.exception(e) @@ -169,6 +169,7 @@ def validate_serializer(serializer): def post(self, request): return self.success(request.data) """ + def validate(view_method): @functools.wraps(view_method) def handle(*args, **kwargs): diff --git a/utils/api/tests.py b/utils/api/tests.py index 0f0a79b84..c2f8e9d1a 100644 --- a/utils/api/tests.py +++ b/utils/api/tests.py @@ -8,9 +8,19 @@ class APITestCase(TestCase): client_class = APIClient - def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, - problem_permission=ProblemPermission.NONE): - user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission) + def create_user( + self, + username, + password, + admin_type=AdminType.REGULAR_USER, + login=True, + problem_permission=ProblemPermission.NONE, + ): + user = User.objects.create( + username=username, + admin_type=admin_type, + problem_permission=problem_permission, + ) user.set_password(password) UserProfile.objects.create(user=user) user.save() @@ -18,23 +28,34 @@ def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, log self.client.login(username=username, password=password) return user - def create_admin(self, username="admin", password="admin", login=True): - return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, - problem_permission=ProblemPermission.OWN, - login=login) - - def create_super_admin(self, username="root", password="root", login=True): - return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, - problem_permission=ProblemPermission.ALL, login=login) + def create_admin(self, username='admin', password='admin', login=True): + return self.create_user( + username=username, + password=password, + admin_type=AdminType.ADMIN, + problem_permission=ProblemPermission.OWN, + login=login, + ) + + def create_super_admin(self, username='root', password='root', login=True): + return self.create_user( + username=username, + password=password, + admin_type=AdminType.SUPER_ADMIN, + problem_permission=ProblemPermission.ALL, + login=login, + ) def reverse(self, url_name, *args, **kwargs): return reverse(url_name, *args, **kwargs) def assertSuccess(self, response): - if not response.data["error"] is None: - raise AssertionError("response with errors, response: " + str(response.data)) + if response.data['error'] is not None: + raise AssertionError( + 'response with errors, response: ' + str(response.data) + ) def assertFailed(self, response, msg=None): - self.assertTrue(response.data["error"] is not None) + self.assertTrue(response.data['error'] is not None) if msg: - self.assertEqual(response.data["data"], msg) + self.assertEqual(response.data['data'], msg) diff --git a/utils/captcha/__init__.py b/utils/captcha/__init__.py index 8b8375cdd..c94dd8096 100644 --- a/utils/captcha/__init__.py +++ b/utils/captcha/__init__.py @@ -24,8 +24,8 @@ def __init__(self, request): 初始化,设置各种属性 """ self.django_request = request - self.session_key = "_django_captcha_key" - self.captcha_expires_time = "_django_captcha_expires_time" + self.session_key = '_django_captcha_key' + self.captcha_expires_time = '_django_captcha_expires_time' # 验证码图片尺寸 self.img_width = 90 @@ -50,20 +50,33 @@ def _make_code(self): """ 生成随机数或随机字符串 """ - string = random.sample("abcdefghkmnpqrstuvwxyzABCDEFGHGKMNOPQRSTUVWXYZ23456789", 4) - self._set_answer("".join(string)) + string = random.sample( + 'abcdefghkmnpqrstuvwxyzABCDEFGHGKMNOPQRSTUVWXYZ23456789', 4 + ) + self._set_answer(''.join(string)) return string def get(self): """ 生成验证码图片,返回值为图片的bytes """ - background = (random.randrange(200, 255), random.randrange(200, 255), random.randrange(200, 255)) - code_color = (random.randrange(0, 50), random.randrange(0, 50), random.randrange(0, 50), 255) + background = ( + random.randrange(200, 255), + random.randrange(200, 255), + random.randrange(200, 255), + ) + code_color = ( + random.randrange(0, 50), + random.randrange(0, 50), + random.randrange(0, 50), + 255, + ) - font_path = os.path.join(os.path.normpath(os.path.dirname(__file__)), "timesbi.ttf") + font_path = os.path.join( + os.path.normpath(os.path.dirname(__file__)), 'timesbi.ttf' + ) - image = Image.new("RGB", (self.img_width, self.img_height), background) + image = Image.new('RGB', (self.img_width, self.img_height), background) code = self._make_code() font_size = self._get_font_size(code) draw = ImageDraw.Draw(image) @@ -75,19 +88,21 @@ def get(self): # 字符y坐标 y = random.randrange(1, 7) # 随机字符大小 - font = ImageFont.truetype(font_path.replace("\\", "/"), font_size + random.randrange(-3, 7)) + font = ImageFont.truetype( + font_path.replace('\\', '/'), font_size + random.randrange(-3, 7) + ) draw.text((x, y), i, font=font, fill=code_color) # 随机化字符之间的距离 字符粘连可以降低识别率 x += font_size * random.randrange(6, 8) / 10 - self.django_request.session[self.session_key] = "".join(code) + self.django_request.session[self.session_key] = ''.join(code) return image def check(self, code): """ 检查用户输入的验证码是否正确 """ - _code = self.django_request.session.get(self.session_key) or "" + _code = self.django_request.session.get(self.session_key) or '' if not _code: return False expires_time = self.django_request.session.get(self.captcha_expires_time) or 0 diff --git a/utils/constants.py b/utils/constants.py index 749b20da7..f38a7360a 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -2,35 +2,35 @@ class Choices: @classmethod def choices(cls): d = cls.__dict__ - return [d[item] for item in d.keys() if not item.startswith("__")] + return [d[item] for item in d.keys() if not item.startswith('__')] class ContestType: - PUBLIC_CONTEST = "Public" - PASSWORD_PROTECTED_CONTEST = "Password Protected" + PUBLIC_CONTEST = 'Public' + PASSWORD_PROTECTED_CONTEST = 'Password Protected' class ContestStatus: - CONTEST_NOT_START = "1" - CONTEST_ENDED = "-1" - CONTEST_UNDERWAY = "0" + CONTEST_NOT_START = '1' + CONTEST_ENDED = '-1' + CONTEST_UNDERWAY = '0' class ContestRuleType(Choices): - ACM = "ACM" - OI = "OI" + ACM = 'ACM' + OI = 'OI' class CacheKey: - waiting_queue = "waiting_queue" - contest_rank_cache = "contest_rank_cache" - website_config = "website_config" + waiting_queue = 'waiting_queue' + contest_rank_cache = 'contest_rank_cache' + website_config = 'website_config' class Difficulty(Choices): - LOW = "Low" - MID = "Mid" - HIGH = "High" + LOW = 'Low' + MID = 'Mid' + HIGH = 'High' -CONTEST_PASSWORD_SESSION_KEY = "contest_password" +CONTEST_PASSWORD_SESSION_KEY = 'contest_password' diff --git a/utils/management/commands/inituser.py b/utils/management/commands/inituser.py index e77693675..2efe823d0 100644 --- a/utils/management/commands/inituser.py +++ b/utils/management/commands/inituser.py @@ -6,39 +6,48 @@ class Command(BaseCommand): def add_arguments(self, parser): - parser.add_argument("--username", type=str) - parser.add_argument("--password", type=str) - parser.add_argument("--action", type=str) + parser.add_argument('--username', type=str) + parser.add_argument('--password', type=str) + parser.add_argument('--action', type=str) def handle(self, *args, **options): - username = options["username"] - password = options["password"] - action = options["action"] + username = options['username'] + password = options['password'] + action = options['action'] - if not(username and password and action): - self.stdout.write(self.style.ERROR("Invalid args")) + if not (username and password and action): + self.stdout.write(self.style.ERROR('Invalid args')) exit(1) - if action == "create_super_admin": + if action == 'create_super_admin': if User.objects.filter(id=1).exists(): - self.stdout.write(self.style.SUCCESS(f"User {username} exists, operation ignored")) + self.stdout.write( + self.style.SUCCESS(f'User {username} exists, operation ignored') + ) exit() - user = User.objects.create(username=username, admin_type=AdminType.SUPER_ADMIN, - problem_permission=ProblemPermission.ALL) + user = User.objects.create( + username=username, + admin_type=AdminType.SUPER_ADMIN, + problem_permission=ProblemPermission.ALL, + ) user.set_password(password) user.save() UserProfile.objects.create(user=user) - self.stdout.write(self.style.SUCCESS("User created")) - elif action == "reset": + self.stdout.write(self.style.SUCCESS('User created')) + elif action == 'reset': try: user = User.objects.get(username=username) user.set_password(password) user.save() - self.stdout.write(self.style.SUCCESS("Password is rested")) + self.stdout.write(self.style.SUCCESS('Password is rested')) except User.DoesNotExist: - self.stdout.write(self.style.ERROR(f"User {username} doesnot exist, operation ignored")) + self.stdout.write( + self.style.ERROR( + f'User {username} doesnot exist, operation ignored' + ) + ) exit(1) else: - raise ValueError("Invalid action") + raise ValueError('Invalid action') diff --git a/utils/migrate_data.py b/utils/migrate_data.py index b9b9bd306..73d307921 100644 --- a/utils/migrate_data.py +++ b/utils/migrate_data.py @@ -7,8 +7,8 @@ import hashlib from json.decoder import JSONDecodeError -sys.path.append("../") -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings") +sys.path.append('../') +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'oj.settings') django.setup() from django.conf import settings from account.models import User, UserProfile, AdminType, ProblemPermission @@ -17,14 +17,10 @@ admin_type_map = { 0: AdminType.REGULAR_USER, 1: AdminType.ADMIN, - 2: AdminType.SUPER_ADMIN + 2: AdminType.SUPER_ADMIN, } -languages_map = { - 1: "C", - 2: "C++", - 3: "Java" -} -email_regex = re.compile(r"(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)") +languages_map = {1: 'C', 2: 'C++', 3: 'Java'} +email_regex = re.compile(r'(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)') # pk -> name tags = {} @@ -37,25 +33,28 @@ def get_input_result(): while True: resp = input() - if resp not in ["yes", "no"]: - print("Please input yes or no") + if resp not in ['yes', 'no']: + print('Please input yes or no') continue - return resp == "yes" + return resp == 'yes' def set_problem_display_id_prefix(): while True: - print("Please input a prefix which will be used in all the imported problem's displayID") + print( + "Please input a prefix which will be used in all the imported problem's displayID" + ) print( "For example, if your input is 'old'(no quote), the problems' display id will be old1, old2, old3..\ninput:", - end="") + end='', + ) resp = input() if resp.strip(): return resp.strip() else: - print("Empty prefix detected, sure to do that? (yes/no)") + print('Empty prefix detected, sure to do that? (yes/no)') if get_input_result(): - return "" + return '' def get_stripped_output_md5(test_case_id, output_name): @@ -65,148 +64,158 @@ def get_stripped_output_md5(test_case_id, output_name): def get_test_case_score(test_case_id): - info_path = os.path.join(settings.TEST_CASE_DIR, test_case_id, "info") + info_path = os.path.join(settings.TEST_CASE_DIR, test_case_id, 'info') if not os.path.exists(info_path): return [] - with open(info_path, "r") as info_file: + with open(info_path, 'r') as info_file: info = json.load(info_file) test_case_score = [] need_rewrite = True - for test_case in info["test_cases"].values(): - if test_case.__contains__("stripped_output_md5"): + for test_case in info['test_cases'].values(): + if test_case.__contains__('stripped_output_md5'): need_rewrite = False - elif test_case.__contains__("striped_output_md5"): - test_case["stripped_output_md5"] = test_case.pop("striped_output_md5") + elif test_case.__contains__('striped_output_md5'): + test_case['stripped_output_md5'] = test_case.pop('striped_output_md5') else: - test_case["stripped_output_md5"] = get_stripped_output_md5(test_case_id, test_case["output_name"]) - test_case_score.append({"input_name": test_case["input_name"], - "output_name": test_case.get("output_name", "-"), - "score": 0}) + test_case['stripped_output_md5'] = get_stripped_output_md5( + test_case_id, test_case['output_name'] + ) + test_case_score.append( + { + 'input_name': test_case['input_name'], + 'output_name': test_case.get('output_name', '-'), + 'score': 0, + } + ) if need_rewrite: - with open(info_path, "w") as f: + with open(info_path, 'w') as f: f.write(json.dumps(info)) return test_case_score def import_users(): i = 0 - print("Find %d users in old data." % len(users.keys())) - print("import users now? (yes/no)") + print('Find %d users in old data.' % len(users.keys())) + print('import users now? (yes/no)') if get_input_result(): for data in users.values(): - if not email_regex.match(data["email"]): - print("%s will not be created due to invalid email: %s" % (data["username"], data["email"])) + if not email_regex.match(data['email']): + print( + '%s will not be created due to invalid email: %s' + % (data['username'], data['email']) + ) continue - data["username"] = data["username"].lower() - user, created = User.objects.get_or_create(username=data["username"]) + data['username'] = data['username'].lower() + user, created = User.objects.get_or_create(username=data['username']) if not created: - print("%s already exists, omitted" % user.username) + print('%s already exists, omitted' % user.username) continue - user.password = data["password"] - user.email = data["email"] - admin_type = admin_type_map[data["admin_type"]] + user.password = data['password'] + user.email = data['email'] + admin_type = admin_type_map[data['admin_type']] user.admin_type = admin_type if admin_type == AdminType.ADMIN: user.problem_permission = ProblemPermission.OWN elif admin_type == AdminType.SUPER_ADMIN: user.problem_permission = ProblemPermission.ALL user.save() - UserProfile.objects.create(user=user, real_name=data["real_name"]) + UserProfile.objects.create(user=user, real_name=data['real_name']) i += 1 - print("%s imported successfully" % user.username) - print("%d users have successfully imported\n" % i) + print('%s imported successfully' % user.username) + print('%d users have successfully imported\n' % i) def import_tags(): i = 0 - print("\nFind these tags in old data:") - print(", ".join(tags.values()), '\n') - print("import tags now? (yes/no)") + print('\nFind these tags in old data:') + print(', '.join(tags.values()), '\n') + print('import tags now? (yes/no)') if get_input_result(): for tagname in tags.values(): tag, created = ProblemTag.objects.get_or_create(name=tagname) if not created: - print("%s already exists, omitted" % tagname) + print('%s already exists, omitted' % tagname) else: - print("%s tag created successfully" % tagname) + print('%s tag created successfully' % tagname) i += 1 - print("%d tags have successfully imported\n" % i) + print('%d tags have successfully imported\n' % i) else: - print("Problem depends on problem_tags and users, exit..") + print('Problem depends on problem_tags and users, exit..') exit(1) def import_problems(): i = 0 - print("\nFind %d problems in old data" % len(problems)) + print('\nFind %d problems in old data' % len(problems)) prefix = set_problem_display_id_prefix() - print("import problems using prefix: %s? (yes/no)" % prefix) + print('import problems using prefix: %s? (yes/no)' % prefix) if get_input_result(): default_creator = User.objects.first() for data in problems: - data["_id"] = prefix + str(data.pop("id")) - if Problem.objects.filter(_id=data["_id"]).exists(): - print("%s has the same display_id with the db problem" % data["title"]) + data['_id'] = prefix + str(data.pop('id')) + if Problem.objects.filter(_id=data['_id']).exists(): + print('%s has the same display_id with the db problem' % data['title']) continue try: - creator_id = \ - User.objects.filter(username=users[data["created_by"]]["username"]).values_list("id", flat=True)[0] + creator_id = User.objects.filter( + username=users[data['created_by']]['username'] + ).values_list('id', flat=True)[0] except (User.DoesNotExist, IndexError): - print("The origin creator does not exist, set it to default_creator") + print('The origin creator does not exist, set it to default_creator') creator_id = default_creator.id - data["created_by_id"] = creator_id - data.pop("created_by") - data["difficulty"] = ProblemDifficulty.Mid - if data["spj_language"]: - data["spj_language"] = languages_map[data["spj_language"]] - data["samples"] = json.loads(data["samples"]) - data["languages"] = ["C", "C++"] - test_case_score = get_test_case_score(data["test_case_id"]) + data['created_by_id'] = creator_id + data.pop('created_by') + data['difficulty'] = ProblemDifficulty.Mid + if data['spj_language']: + data['spj_language'] = languages_map[data['spj_language']] + data['samples'] = json.loads(data['samples']) + data['languages'] = ['C', 'C++'] + test_case_score = get_test_case_score(data['test_case_id']) if not test_case_score: - print("%s test_case files don't exist, omitted" % data["title"]) + print("%s test_case files don't exist, omitted" % data['title']) continue - data["test_case_score"] = test_case_score - data["rule_type"] = ProblemRuleType.ACM - data["template"] = {} - data.pop("total_submit_number") - data.pop("total_accepted_number") - tag_ids = data.pop("tags") + data['test_case_score'] = test_case_score + data['rule_type'] = ProblemRuleType.ACM + data['template'] = {} + data.pop('total_submit_number') + data.pop('total_accepted_number') + tag_ids = data.pop('tags') problem = Problem.objects.create(**data) - problem.create_time = data["create_time"] + problem.create_time = data['create_time'] problem.save() for tag_id in tag_ids: tag, _ = ProblemTag.objects.get_or_create(name=tags[tag_id]) problem.tags.add(tag) i += 1 - print("%s imported successfully" % data["title"]) - print("%d problems have successfully imported" % i) + print('%s imported successfully' % data['title']) + print('%d problems have successfully imported' % i) -if __name__ == "__main__": +if __name__ == '__main__': if len(sys.argv) == 1: - print("Usage: python3 %s [old_data_path]" % sys.argv[0]) + print('Usage: python3 %s [old_data_path]' % sys.argv[0]) exit(0) data_path = sys.argv[1] if not os.path.isfile(data_path): - print("Data file does not exist") + print('Data file does not exist') exit(1) try: - with open(data_path, "r") as data_file: + with open(data_path, 'r') as data_file: old_data = json.load(data_file) except JSONDecodeError: print("Data file format error, ensure it's a valid json file!") exit(1) - print("Read old data successfully.\n") + print('Read old data successfully.\n') for obj in old_data: - if obj["model"] == "problem.problemtag": - tags[obj["pk"]] = obj["fields"]["name"] - elif obj["model"] == "account.user": - users[obj["pk"]] = obj["fields"] - elif obj["model"] == "problem.problem": - obj["fields"]["id"] = obj["pk"] - problems.append(obj["fields"]) + if obj['model'] == 'problem.problemtag': + tags[obj['pk']] = obj['fields']['name'] + elif obj['model'] == 'account.user': + users[obj['pk']] = obj['fields'] + elif obj['model'] == 'problem.problem': + obj['fields']['id'] = obj['pk'] + problems.append(obj['fields']) import_users() import_tags() import_problems() diff --git a/utils/models.py b/utils/models.py index 7577f1fab..e5f6936c0 100644 --- a/utils/models.py +++ b/utils/models.py @@ -7,4 +7,4 @@ class RichTextField(models.TextField): def get_prep_value(self, value): with XSSHtml() as parser: - return parser.clean(value or "") + return parser.clean(value or '') diff --git a/utils/serializers.py b/utils/serializers.py index c543c56f1..2de377292 100644 --- a/utils/serializers.py +++ b/utils/serializers.py @@ -5,7 +5,7 @@ class InvalidLanguage(serializers.ValidationError): def __init__(self, name): - super().__init__(detail=f"{name} is not a valid language") + super().__init__(detail=f'{name} is not a valid language') class LanguageNameChoiceField(serializers.CharField): diff --git a/utils/shortcuts.py b/utils/shortcuts.py index 84e14fd28..05ad5f4f7 100644 --- a/utils/shortcuts.py +++ b/utils/shortcuts.py @@ -9,51 +9,60 @@ from envelopes import Envelope -def rand_str(length=32, type="lower_hex"): +def rand_str(length=32, type='lower_hex'): """ 生成指定长度的随机字符串或者数字, 可以用于密钥等安全场景 :param length: 字符串或者数字的长度 :param type: str 代表随机字符串,num 代表随机数字 :return: 字符串 """ - if type == "str": - return get_random_string(length, allowed_chars="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") - elif type == "lower_str": - return get_random_string(length, allowed_chars="abcdefghijklmnopqrstuvwxyz0123456789") - elif type == "lower_hex": - return random.choice("123456789abcdef") + get_random_string(length - 1, allowed_chars="0123456789abcdef") + if type == 'str': + return get_random_string( + length, + allowed_chars='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789', + ) + elif type == 'lower_str': + return get_random_string( + length, allowed_chars='abcdefghijklmnopqrstuvwxyz0123456789' + ) + elif type == 'lower_hex': + return random.choice('123456789abcdef') + get_random_string( + length - 1, allowed_chars='0123456789abcdef' + ) else: - return random.choice("123456789") + get_random_string(length - 1, allowed_chars="0123456789") + return random.choice('123456789') + get_random_string( + length - 1, allowed_chars='0123456789' + ) def build_query_string(kv_data, ignore_none=True): # {"a": 1, "b": "test"} -> "?a=1&b=test" - query_string = "" + query_string = '' for k, v in kv_data.items(): if ignore_none is True and kv_data[k] is None: continue - if query_string != "": - query_string += "&" + if query_string != '': + query_string += '&' else: - query_string = "?" - query_string += (k + "=" + str(v)) + query_string = '?' + query_string += k + '=' + str(v) return query_string def img2base64(img): with BytesIO() as buf: - img.save(buf, "gif") + img.save(buf, 'gif') buf_str = buf.getvalue() - img_prefix = "data:image/png;base64," - b64_str = img_prefix + b64encode(buf_str).decode("utf-8") + img_prefix = 'data:image/png;base64,' + b64_str = img_prefix + b64encode(buf_str).decode('utf-8') return b64_str -def datetime2str(value, format="iso-8601"): - if format.lower() == "iso-8601": +def datetime2str(value, format='iso-8601'): + if format.lower() == 'iso-8601': value = value.isoformat() - if value.endswith("+00:00"): - value = value[:-6] + "Z" + if value.endswith('+00:00'): + value = value[:-6] + 'Z' return value return value.strftime(format) @@ -62,29 +71,34 @@ def timestamp2utcstr(value): return datetime.datetime.utcfromtimestamp(value).isoformat() -def natural_sort_key(s, _nsre=re.compile(r"(\d+)")): - return [int(text) if text.isdigit() else text.lower() - for text in re.split(_nsre, s)] +def natural_sort_key(s, _nsre=re.compile(r'(\d+)')): + return [ + int(text) if text.isdigit() else text.lower() for text in re.split(_nsre, s) + ] def send_email(smtp_config, from_name, to_email, to_name, subject, content): - envelope = Envelope(from_addr=(smtp_config["email"], from_name), - to_addr=(to_email, to_name), - subject=subject, - html_body=content) - return envelope.send(smtp_config["server"], - login=smtp_config["email"], - password=smtp_config["password"], - port=smtp_config["port"], - tls=smtp_config["tls"]) - - -def get_env(name, default=""): + envelope = Envelope( + from_addr=(smtp_config['email'], from_name), + to_addr=(to_email, to_name), + subject=subject, + html_body=content, + ) + return envelope.send( + smtp_config['server'], + login=smtp_config['email'], + password=smtp_config['password'], + port=smtp_config['port'], + tls=smtp_config['tls'], + ) + + +def get_env(name, default=''): return os.environ.get(name, default) def DRAMATIQ_WORKER_ARGS(time_limit=3600_000, max_retries=0, max_age=7200_000): - return {"max_retries": max_retries, "time_limit": time_limit, "max_age": max_age} + return {'max_retries': max_retries, 'time_limit': time_limit, 'max_age': max_age} def check_is_id(value): diff --git a/utils/throttling.py b/utils/throttling.py index bab1184a7..62cae16b9 100644 --- a/utils/throttling.py +++ b/utils/throttling.py @@ -5,6 +5,7 @@ class TokenBucket: """ 注意:对于单个key的操作不是线程安全的 """ + def __init__(self, key, capacity, fill_rate, default_capacity, redis_conn): """ :param capacity: 最大容量 @@ -18,8 +19,8 @@ def __init__(self, key, capacity, fill_rate, default_capacity, redis_conn): self._default_capacity = default_capacity self._redis_conn = redis_conn - self._last_capacity_key = "last_capacity" - self._last_timestamp_key = "last_timestamp" + self._last_capacity_key = 'last_capacity' + self._last_timestamp_key = 'last_timestamp' def _init_key(self): self._last_capacity = self._default_capacity diff --git a/utils/urls.py b/utils/urls.py index 7e0128e5c..bd99df08f 100644 --- a/utils/urls.py +++ b/utils/urls.py @@ -3,6 +3,6 @@ from .views import SimditorImageUploadAPIView, SimditorFileUploadAPIView urlpatterns = [ - url(r"^upload_image/?$", SimditorImageUploadAPIView.as_view(), name="upload_image"), - url(r"^upload_file/?$", SimditorFileUploadAPIView.as_view(), name="upload_file") + url(r'^upload_image/?$', SimditorImageUploadAPIView.as_view(), name='upload_image'), + url(r'^upload_file/?$', SimditorFileUploadAPIView.as_view(), name='upload_file'), ] diff --git a/utils/views.py b/utils/views.py index ce30c717c..acb1e3abb 100644 --- a/utils/views.py +++ b/utils/views.py @@ -14,34 +14,34 @@ class SimditorImageUploadAPIView(CSRFExemptAPIView): def post(self, request): form = ImageUploadForm(request.POST, request.FILES) if form.is_valid(): - img = form.cleaned_data["image"] + img = form.cleaned_data['image'] else: - return self.response({ - "success": False, - "msg": "Upload failed", - "file_path": ""}) + return self.response( + {'success': False, 'msg': 'Upload failed', 'file_path': ''} + ) suffix = os.path.splitext(img.name)[-1].lower() - if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]: - return self.response({ - "success": False, - "msg": "Unsupported file format", - "file_path": ""}) + if suffix not in ['.gif', '.jpg', '.jpeg', '.bmp', '.png']: + return self.response( + {'success': False, 'msg': 'Unsupported file format', 'file_path': ''} + ) img_name = rand_str(10) + suffix try: - with open(os.path.join(settings.UPLOAD_DIR, img_name), "wb") as imgFile: + with open(os.path.join(settings.UPLOAD_DIR, img_name), 'wb') as imgFile: for chunk in img: imgFile.write(chunk) except IOError as e: logger.error(e) - return self.response({ - "success": False, - "msg": "Upload Error", - "file_path": ""}) - return self.response({ - "success": True, - "msg": "Success", - "file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"}) + return self.response( + {'success': False, 'msg': 'Upload Error', 'file_path': ''} + ) + return self.response( + { + 'success': True, + 'msg': 'Success', + 'file_path': f'{settings.UPLOAD_PREFIX}/{img_name}', + } + ) class SimditorFileUploadAPIView(CSRFExemptAPIView): @@ -50,26 +50,24 @@ class SimditorFileUploadAPIView(CSRFExemptAPIView): def post(self, request): form = FileUploadForm(request.POST, request.FILES) if form.is_valid(): - file = form.cleaned_data["file"] + file = form.cleaned_data['file'] else: - return self.response({ - "success": False, - "msg": "Upload failed" - }) + return self.response({'success': False, 'msg': 'Upload failed'}) suffix = os.path.splitext(file.name)[-1].lower() file_name = rand_str(10) + suffix try: - with open(os.path.join(settings.UPLOAD_DIR, file_name), "wb") as f: + with open(os.path.join(settings.UPLOAD_DIR, file_name), 'wb') as f: for chunk in file: f.write(chunk) except IOError as e: logger.error(e) - return self.response({ - "success": False, - "msg": "Upload Error"}) - return self.response({ - "success": True, - "msg": "Success", - "file_path": f"{settings.UPLOAD_PREFIX}/{file_name}", - "file_name": file.name}) + return self.response({'success': False, 'msg': 'Upload Error'}) + return self.response( + { + 'success': True, + 'msg': 'Success', + 'file_path': f'{settings.UPLOAD_PREFIX}/{file_name}', + 'file_name': file.name, + } + ) diff --git a/utils/xss_filter.py b/utils/xss_filter.py index fe4f7aa8c..904f29a1c 100644 --- a/utils/xss_filter.py +++ b/utils/xss_filter.py @@ -31,19 +31,63 @@ class XSSHtml(HTMLParser): - allow_tags = ['a', 'img', 'br', 'strong', 'b', 'code', 'pre', - 'p', 'div', 'em', 'span', 'h1', 'h2', 'h3', 'h4', - 'h5', 'h6', 'blockquote', 'ul', 'ol', 'tr', 'th', 'td', - 'hr', 'li', 'u', 'embed', 's', 'table', 'thead', 'tbody', - 'caption', 'small', 'q', 'sup', 'sub', 'font'] - common_attrs = ["style", "class", "name"] - nonend_tags = ["img", "hr", "br", "embed"] + allow_tags = [ + 'a', + 'img', + 'br', + 'strong', + 'b', + 'code', + 'pre', + 'p', + 'div', + 'em', + 'span', + 'h1', + 'h2', + 'h3', + 'h4', + 'h5', + 'h6', + 'blockquote', + 'ul', + 'ol', + 'tr', + 'th', + 'td', + 'hr', + 'li', + 'u', + 'embed', + 's', + 'table', + 'thead', + 'tbody', + 'caption', + 'small', + 'q', + 'sup', + 'sub', + 'font', + ] + common_attrs = ['style', 'class', 'name'] + nonend_tags = ['img', 'hr', 'br', 'embed'] tags_own_attrs = { - "img": ["src", "width", "height", "alt", "align"], - "a": ["href", "target", "rel", "title"], - "embed": ["src", "width", "height", "type", "allowfullscreen", "loop", "play", "wmode", "menu"], - "table": ["border", "cellpadding", "cellspacing"], - "font": ["color"] + 'img': ['src', 'width', 'height', 'alt', 'align'], + 'a': ['href', 'target', 'rel', 'title'], + 'embed': [ + 'src', + 'width', + 'height', + 'type', + 'allowfullscreen', + 'loop', + 'play', + 'wmode', + 'menu', + ], + 'table': ['border', 'cellpadding', 'cellspacing'], + 'font': ['color'], } def __init__(self, allows=[]): @@ -86,13 +130,13 @@ def handle_starttag(self, tag, attrs): attdict[attr[0]] = attr[1] attdict = self._wash_attr(attdict, tag) - if hasattr(self, "node_%s" % tag): - attdict = getattr(self, "node_%s" % tag)(attdict) + if hasattr(self, 'node_%s' % tag): + attdict = getattr(self, 'node_%s' % tag)(attdict) else: attdict = self.node_default(attdict) attrs = [] - for (key, value) in attdict.items(): + for key, value in attdict.items(): attrs.append('%s="%s"' % (key, self._htmlspecialchars(value))) attrs = (' ' + ' '.join(attrs)) if attrs else '' self.result.append('<' + tag + attrs + end_diagonal + '>') @@ -107,11 +151,11 @@ def handle_data(self, data): def handle_entityref(self, name): if name.isalpha(): - self.result.append("&%s;" % name) + self.result.append('&%s;' % name) def handle_charref(self, name): if name.isdigit(): - self.result.append("&#%s;" % name) + self.result.append('&#%s;' % name) def node_default(self, attrs): attrs = self._common_attr(attrs) @@ -119,44 +163,45 @@ def node_default(self, attrs): def node_a(self, attrs): attrs = self._common_attr(attrs) - attrs = self._get_link(attrs, "href") - attrs = self._set_attr_default(attrs, "target", "_blank") - attrs = self._limit_attr(attrs, { - "target": ["_blank", "_self"] - }) + attrs = self._get_link(attrs, 'href') + attrs = self._set_attr_default(attrs, 'target', '_blank') + attrs = self._limit_attr(attrs, {'target': ['_blank', '_self']}) return attrs def node_embed(self, attrs): attrs = self._common_attr(attrs) - attrs = self._get_link(attrs, "src") - attrs = self._limit_attr(attrs, { - "type": ["application/x-shockwave-flash"], - "wmode": ["transparent", "window", "opaque"], - "play": ["true", "false"], - "loop": ["true", "false"], - "menu": ["true", "false"], - "allowfullscreen": ["true", "false"] - }) - attrs["allowscriptaccess"] = "never" - attrs["allownetworking"] = "none" + attrs = self._get_link(attrs, 'src') + attrs = self._limit_attr( + attrs, + { + 'type': ['application/x-shockwave-flash'], + 'wmode': ['transparent', 'window', 'opaque'], + 'play': ['true', 'false'], + 'loop': ['true', 'false'], + 'menu': ['true', 'false'], + 'allowfullscreen': ['true', 'false'], + }, + ) + attrs['allowscriptaccess'] = 'never' + attrs['allownetworking'] = 'none' return attrs def _true_url(self, url): - prog = re.compile(r"(^(http|https|ftp)://.+)|(^/)", re.I | re.S) + prog = re.compile(r'(^(http|https|ftp)://.+)|(^/)', re.I | re.S) if prog.match(url): return url else: - return "http://%s" % url + return 'http://%s' % url def _true_style(self, style): if style: - style = re.sub(r"(\\|&#|/\*|\*/)", "_", style) - style = re.sub(r"e.*x.*p.*r.*e.*s.*s.*i.*o.*n", "_", style) + style = re.sub(r'(\\|&#|/\*|\*/)', '_', style) + style = re.sub(r'e.*x.*p.*r.*e.*s.*s.*i.*o.*n', '_', style) return style def _get_style(self, attrs): - if "style" in attrs: - attrs["style"] = self._true_style(attrs.get("style")) + if 'style' in attrs: + attrs['style'] = self._true_style(attrs.get('style')) return attrs def _get_link(self, attrs, name): @@ -185,24 +230,28 @@ def _set_attr_default(self, attrs, name, default=''): return attrs def _limit_attr(self, attrs, limit={}): - for (key, value) in limit.items(): + for key, value in limit.items(): if key in attrs and attrs[key] not in value: del attrs[key] return attrs def _htmlspecialchars(self, html): - return html.replace("<", "<") \ - .replace(">", ">") \ - .replace('"', """) \ - .replace("'", "'") + return ( + html.replace('<', '<') + .replace('>', '>') + .replace('"', '"') + .replace("'", ''') + ) -if "__main__" == __name__: +if '__main__' == __name__: with XSSHtml() as parser: - ret = parser.clean("""

+ ret = parser.clean( + """

>M MM

- """) + """ + ) print(ret)