Coverage for src / competitive_verifier / oj / problem.py: 82%
301 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-05 16:00 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-05 16:00 +0000
1import glob
2import json
3import os
4import pathlib
5import posixpath
6import re
7import subprocess
8import sys
9import urllib.parse
10import zipfile
11from abc import abstractmethod
12from collections.abc import Iterable, Iterator
13from dataclasses import dataclass
14from io import BytesIO
15from logging import getLogger
16from typing import ClassVar, Optional, TypeVar
18import requests
20from competitive_verifier import config
21from competitive_verifier.log import GitHubMessageParams
22from competitive_verifier.models import (
23 Problem,
24 TestCaseData,
25 TestCaseFile,
26 TestCaseProvider,
27)
29logger = getLogger(__name__)
class NotLoggedInError(RuntimeError):
    """Raised when a download requires an authenticated session that is missing."""
class _BaseProblem(Problem):
    """Common download/enumeration logic for problems stored in ``test_directory``.

    Subclasses only implement :meth:`_download_cases`; writing the cases to
    disk and listing them afterwards is handled here.
    """

    def iter_system_cases(self) -> Iterator[TestCaseFile]:
        """Return the test case files found directly in ``self.test_directory``."""
        return iter_testcases(directory=self.test_directory)

    def download_system_cases(self) -> Iterable[TestCaseData] | bool:
        """Fetch the test cases unless the test directory is already populated.

        Returns:
            ``True`` when the test directory already has entries, ``False``
            when the download yielded no cases, otherwise the list of the
            downloaded cases (which have also been written to disk).
        """
        target = self.test_directory

        # A non-empty test directory is treated as a finished download.
        has_cached_cases = target.exists() and any(target.iterdir())
        if has_cached_cases:
            logger.info("download:already exists: %s", self.url)
            return True

        self.problem_directory.mkdir(parents=True, exist_ok=True)

        downloaded = list(self._download_cases())
        if downloaded:
            # write samples to files
            save_testcases(downloaded, directory=target)
            return downloaded

        logger.error(
            "Sample not found",
            extra={"github": GitHubMessageParams()},
        )
        return False

    @abstractmethod
    def _download_cases(self) -> Iterable[TestCaseData]:
        """Produce the test cases fetched from the remote judge."""
class LibraryCheckerProblem(Problem):
    """A Library Checker problem (``https://judge.yosupo.jp/problem/<id>``).

    Test cases are generated locally: the ``library-checker-problems``
    repository is cloned (or pulled) into the cache directory and its
    ``generate.py`` is run against the problem's ``info.toml``.
    """

    checker_exe_name: ClassVar[str] = (
        "checker.exe" if sys.platform == "win32" else "checker"
    )

    # Process-wide flag: once the repository has been cloned/pulled there is
    # no need to repeat the (slow, networked) git call for every problem.
    _is_repository_updated: ClassVar[bool] = False

    def __init__(self, *, problem_id: str):
        self.problem_id = problem_id
        # Resolved lazily by ``source_directory``.
        self._source_directory = None

    def __hash__(self) -> int:
        return hash((self.problem_id, self.repo_path))

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, LibraryCheckerProblem):
            return False
        return self.problem_id == value.problem_id and self.repo_path == value.repo_path

    @property
    def repo_path(self):
        """Location of the local ``library-checker-problems`` clone."""
        return config.get_cache_dir() / "library-checker-problems"

    def iter_system_cases(self) -> Iterator[TestCaseFile]:
        """Pair up the generated ``in/*.in`` and ``out/*.out`` files by stem."""
        inputs: dict[str, pathlib.Path] = {
            path.stem: path for path in self.source_directory.glob("in/*.in")
        }
        outputs: dict[str, pathlib.Path] = {
            path.stem: path for path in self.source_directory.glob("out/*.out")
        }
        return merge_testcase_files(inputs, outputs)

    def download_system_cases(self) -> bool:
        """Generate the test cases; always returns ``True`` (errors are raised)."""
        self.problem_directory.mkdir(parents=True, exist_ok=True)
        self.generate_test_cases()
        return True

    @property
    def checker(self) -> pathlib.Path | None:
        """Path of the checker executable inside the problem's source directory."""
        return self.source_directory / self.checker_exe_name

    def generate_test_cases(self) -> None:
        """Run ``generate.py`` of the cloned repository for this problem.

        Raises:
            subprocess.CalledProcessError: If ``generate.py`` exits with a
                non-zero status (the failure is logged before re-raising).
        """
        self.update_cloned_repository()
        path = self.repo_path

        spec = str(self.source_directory / "info.toml")
        command = [sys.executable, str(path / "generate.py"), spec]
        logger.info("$ %s", " ".join(command))
        try:
            subprocess.check_call(command, stdout=sys.stderr, stderr=sys.stderr)
        except subprocess.CalledProcessError:
            logger.exception(
                "the generate.py failed: check https://github.com/yosupo06/library-checker-problems/issues",
                extra={"github": GitHubMessageParams()},
            )
            raise

    @property
    def source_directory(self):
        """Directory of the cloned repository that holds this problem's ``info.toml``.

        Raises:
            RuntimeError: If zero or more than one matching ``info.toml`` exists.
        """
        if self._source_directory is None:
            problem_id = self.problem_id
            info_tomls = list(
                self.repo_path.glob(f"**/{glob.escape(problem_id)}/info.toml")
            )
            if len(info_tomls) != 1:
                raise RuntimeError(f"the problem {problem_id!r} not found or broken")
            self._source_directory = info_tomls[0].parent
        return self._source_directory

    @property
    def url(self) -> str:
        return f"https://judge.yosupo.jp/problem/{self.problem_id}"

    @classmethod
    def from_url(cls, url: str) -> Optional["LibraryCheckerProblem"]:
        """Build a problem from a Library Checker URL, or return ``None``."""
        # example: https://judge.yosupo.jp/problem/unionfind
        result = urllib.parse.urlparse(url)
        if result.scheme in ("", "http", "https") and result.netloc in (
            "judge.yosupo.jp",
            "old.yosupo.jp",
        ):
            m = re.match(r"/problem/(\w+)/?", result.path)
            if m:
                return cls(problem_id=m.group(1))
        return None

    def update_cloned_repository(self) -> None:
        """Clone the problem repository, or pull it if it is already cloned.

        The git call is made at most once per process: after a successful
        clone/pull the class-level flag is set so later instances skip it.

        Raises:
            FileNotFoundError: If the ``git`` command is not available.
            subprocess.CalledProcessError: If a git command fails.
        """
        if self._is_repository_updated:
            return

        try:
            subprocess.check_call(
                ["git", "--version"],  # noqa: S607
                stdout=sys.stderr,
                stderr=sys.stderr,
            )
        except FileNotFoundError:
            logger.exception(
                "git command not found",
                exc_info=False,
                extra={"github": GitHubMessageParams()},
            )
            raise

        path = self.repo_path
        if not path.exists():
            # init the problem repository
            url = "https://github.com/yosupo06/library-checker-problems"
            logger.info("$ git clone %s %s", url, path)
            subprocess.check_call(
                ["git", "clone", url, str(path)],  # noqa: S607
                stdout=sys.stderr,
                stderr=sys.stderr,
            )
        else:
            # sync the problem repository
            logger.info("$ git -C %s pull", path)
            subprocess.check_call(
                ["git", "-C", str(path), "pull"],  # noqa: S607
                stdout=sys.stderr,
                stderr=sys.stderr,
            )

        # Assign on the class, not on ``self``: an instance attribute would be
        # invisible to other problem instances and every one of them would
        # trigger another ``git pull``.
        LibraryCheckerProblem._is_repository_updated = True
193class _YukicoderProblemNo(int):
194 def __new__(cls, value: int):
195 return super().__new__(cls, value)
197 def __str__(self) -> str:
198 return "no/" + super().__str__()
201class _YukicoderProblemId(int):
202 def __new__(cls, value: int):
203 return super().__new__(cls, value)
class YukicoderProblem(_BaseProblem):
    """A yukicoder problem, identified by its problem number or its problem id."""

    problem: _YukicoderProblemNo | _YukicoderProblemId

    def __init__(self, *, problem_no: int | None = None, problem_id: int | None = None):
        """Create the problem from exactly one kind of identifier.

        ``problem_no`` takes precedence when both are given.

        Raises:
            ValueError: If neither ``problem_no`` nor ``problem_id`` is given.
        """
        if problem_no is not None:
            self.problem = _YukicoderProblemNo(problem_no)
        elif problem_id is not None:
            self.problem = _YukicoderProblemId(problem_id)
        else:
            raise ValueError("Needs problem_no or problem_id")

    def _download_cases(self) -> list[TestCaseData]:
        """Download the test case archive of the problem.

        Raises:
            NotLoggedInError: If the session is not logged in to yukicoder
                (``$YUKICODER_TOKEN`` is unset or not accepted).
            requests.HTTPError: If yukicoder responds with an error status.
        """
        headers: dict[str, str] | None = None
        if yukicoder_token := os.environ.get("YUKICODER_TOKEN"):
            headers = {"Authorization": f"Bearer {yukicoder_token}"}

        if not self._is_logged_in(headers=headers):
            raise NotLoggedInError("Required: $YUKICODER_TOKEN environment variable")
        url = f"{self.url}/testcase.zip"
        resp = requests.get(url, headers=headers, allow_redirects=True, timeout=10)
        # Surface HTTP failures as such; otherwise an error page would be fed
        # to ZipFile and fail with an unrelated BadZipFile.
        resp.raise_for_status()

        inputs: dict[str, bytes] = {}
        outputs: dict[str, bytes] = {}
        with zipfile.ZipFile(BytesIO(resp.content)) as fh:
            for filename in fh.namelist():
                if filename.endswith("/"):
                    # directory entry
                    continue
                stem = pathlib.Path(filename).stem
                # Only decompress the members that are actually used.
                if filename.startswith("test_in/"):
                    inputs[stem] = fh.read(filename)
                elif filename.startswith("test_out/"):
                    outputs[stem] = fh.read(filename)
        return [
            TestCaseData(name=name, input_data=i, output_data=o)
            for name, i, o in enumerate_inouts(inputs, outputs)
        ]

    @property
    def url(self) -> str:
        return f"https://yukicoder.me/problems/{self.problem}"

    @classmethod
    def from_url(cls, url: str) -> Optional["YukicoderProblem"]:
        """Build a problem from a yukicoder URL, or return ``None``."""
        # example: https://yukicoder.me/problems/no/499
        # example: http://yukicoder.me/problems/1476
        result = urllib.parse.urlparse(url)
        dirname, basename = posixpath.split(_normpath(result.path))
        if result.scheme in ("", "http", "https") and result.netloc == "yukicoder.me":
            try:
                n = int(basename)
            except ValueError:
                # The last path component is not a number: not a problem URL.
                pass
            else:
                if dirname == "/problems/no":
                    return cls(problem_no=n)
                if dirname == "/problems":
                    return cls(problem_id=n)
        return None

    def _is_logged_in(self, *, headers: dict[str, str] | None = None) -> bool:
        """Return whether the top page is served without a login button.

        Raises:
            requests.HTTPError: If yukicoder responds with an error status.
        """
        url = "https://yukicoder.me"
        resp = requests.get(url, headers=headers, allow_redirects=True, timeout=10)
        resp.raise_for_status()
        # Search the raw bytes directly instead of the repr() of the bytes.
        return b"login-btn" not in resp.content
class AOJProblem(_BaseProblem):
    """A problem of Aizu Online Judge, identified by its problem id."""

    def __init__(self, *, problem_id: str):
        self.problem_id = problem_id

    def _download_cases(self) -> Iterable[TestCaseData]:
        return AOJProblem.download_cases(self.problem_id)

    @staticmethod
    def download_cases(problem_id: str) -> Iterable[TestCaseData]:
        """Lazily download every test case of ``problem_id`` from the judge data API.

        Raises:
            requests.HTTPError: If any of the requests gets an error status.
        """

        def fetch(target: str):
            response = requests.get(target, allow_redirects=True, timeout=10)
            response.raise_for_status()
            return response

        # get header
        # reference: http://developers.u-aizu.ac.jp/api?key=judgedat%2Ftestcases%2F%7BproblemId%7D%2Fheader_GET
        header_res = json.loads(
            fetch(f"https://judgedat.u-aizu.ac.jp/testcases/{problem_id}/header").text
        )

        # get testcases via the official API
        for header in header_res["headers"]:
            # NOTE: the endpoints are not same to http://developers.u-aizu.ac.jp/api?key=judgedat%2Ftestcases%2F%7BproblemId%7D%2F%7Bserial%7D_GET since the json API often says "..... (terminated because of the limitation)"
            # NOTE: even when using https://judgedat.u-aizu.ac.jp/testcases/PROBLEM_ID/SERIAL, there is the 1G limit (see https://twitter.com/beet_aizu/status/1194947611100188672)
            serial = header["serial"]
            url = f"https://judgedat.u-aizu.ac.jp/testcases/{problem_id}/{serial}"

            # Input is requested (and checked) before output.
            case_input = fetch(url + "/in")
            case_output = fetch(url + "/out")

            yield TestCaseData(
                header["name"],
                case_input.content,
                case_output.content,
            )

    @property
    def url(self) -> str:
        return f"http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id={self.problem_id}"

    @classmethod
    def from_url(cls, url: str) -> Optional["AOJProblem"]:
        """Build a problem from any of the supported AOJ URL shapes, or return ``None``."""
        parsed = urllib.parse.urlparse(url)
        if parsed.scheme not in ("", "http", "https"):
            return None
        path = _normpath(parsed.path)

        if parsed.netloc == "judge.u-aizu.ac.jp":
            # example: http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1169
            # example: http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_1_A&lang=jp
            ids = urllib.parse.parse_qs(parsed.query).get("id")
            if path == "/onlinejudge/description.jsp" and ids and len(ids) == 1:
                (n,) = ids
                return cls(problem_id=n)
            return None

        if parsed.netloc != "onlinejudge.u-aizu.ac.jp":
            return None

        # example: https://onlinejudge.u-aizu.ac.jp/challenges/sources/JAG/Prelim/2881
        # example: https://onlinejudge.u-aizu.ac.jp/courses/library/4/CGL/3/CGL_3_B
        m = re.match(
            r"^/(challenges|courses)/(sources|library/\d+|lesson/\d+)/(\w+)/(\w+)/(\w+)$",
            path,
        )
        if m:
            return cls(problem_id=m.group(5))

        # example: https://onlinejudge.u-aizu.ac.jp/problems/0423
        # example: https://onlinejudge.u-aizu.ac.jp/problems/CGL_3_B
        m = re.match(r"^/problems/(\w+)$", path)
        if m:
            return cls(problem_id=m.group(1))

        return None
class AOJArenaProblem(_BaseProblem):
    """A problem of an AOJ arena (contest room), addressed by arena id and letter."""

    def __init__(self, *, arena_id: str, alphabet: str):
        """Store the arena id and the single upper-case problem letter.

        Raises:
            ValueError: If ``alphabet`` is not exactly one upper-case character.
        """
        is_single_upper = len(alphabet) == 1 and alphabet.isupper()
        if not is_single_upper:
            raise ValueError(arena_id, alphabet)
        self.arena_id = arena_id
        self.alphabet = alphabet

        # Filled in on the first call of get_problem_id().
        self._problem_id: str | None = None

    def get_problem_id(self) -> str:
        """Resolve (and cache) the underlying AOJ problem id via the arena API.

        Raises:
            requests.HTTPError: If the API request gets an error status.
            ValueError: If the arena has no problem with this letter.
        """
        if self._problem_id is not None:
            return self._problem_id

        url = f"https://judgeapi.u-aizu.ac.jp/arenas/{self.arena_id}/problems"
        resp = requests.get(url, allow_redirects=True, timeout=10)
        resp.raise_for_status()
        for entry in json.loads(resp.text):
            if entry["id"] != self.alphabet:
                continue
            found = entry["problemId"]
            logger.debug("problem: %s", found)
            self._problem_id = found
            return found
        raise ValueError("Problem is not found.")

    def _download_cases(self) -> Iterable[TestCaseData]:
        return AOJProblem.download_cases(self.get_problem_id())

    @property
    def url(self) -> str:
        return f"https://onlinejudge.u-aizu.ac.jp/services/room.html#{self.arena_id}/problems/{self.alphabet}"

    @classmethod
    def from_url(cls, url: str) -> Optional["AOJArenaProblem"]:
        """Build a problem from an arena room URL, or return ``None``."""
        # example: https://onlinejudge.u-aizu.ac.jp/services/room.html#RitsCamp19Day2/problems/A
        parsed = urllib.parse.urlparse(url)
        is_room_page = (
            parsed.scheme in ("", "http", "https")
            and parsed.netloc == "onlinejudge.u-aizu.ac.jp"
            and _normpath(parsed.path) == "/services/room.html"
        )
        if not is_room_page:
            return None

        # The fragment looks like "<arena id>/problems/<letter>".
        parts = parsed.fragment.split("/")
        if len(parts) == 3 and parts[1] == "problems":  # noqa: PLR2004
            return cls(arena_id=parts[0], alphabet=parts[2].upper())
        return None
@dataclass
class LocalProblem(TestCaseProvider):
    """Test cases read from a local directory tree instead of an online judge."""

    # Root directory that is searched recursively for ``*.in`` / ``*.out`` files.
    path: pathlib.Path

    def download_system_cases(self) -> Iterable[TestCaseData] | bool:
        """Nothing is downloaded; report whether any local case is available."""
        return any(self.iter_system_cases())

    def iter_system_cases(self) -> Iterable[TestCaseFile]:
        """Return the paired test case files found anywhere below ``self.path``."""
        return iter_testcases(directory=self.path, recursive=True)
418def _normpath(path: str) -> str:
419 """A wrapper of posixpath.normpath.
421 posixpath.normpath doesn't collapse a leading duplicated slashes.
422 """
423 path = posixpath.normpath(path)
424 if path.startswith("//"):
425 path = "/" + path.lstrip("/")
426 return path
def _subclasses_recursive(cls: type[Problem]) -> Iterable[type[Problem]]:
    """Yield every direct subclass of ``cls``, then the descendants of each one."""
    direct = cls.__subclasses__()
    yield from direct
    for child in direct:
        yield from _subclasses_recursive(child)
def problem_from_url(url: str) -> Problem | None:
    """Return the first problem any ``Problem`` subclass builds from ``url``.

    Each (transitive) subclass is asked once via its ``from_url``; ``None`` is
    returned when none of them recognizes the URL.
    """
    # The set removes classes that are reachable through several parents.
    candidates = (
        klass.from_url(url) for klass in set(_subclasses_recursive(Problem))
    )
    return next((found for found in candidates if found is not None), None)
_InOut = TypeVar("_InOut")


def enumerate_inouts(
    inputs: dict[str, _InOut],
    outputs: dict[str, _InOut],
) -> Iterator[tuple[str, _InOut, _InOut]]:
    """Yield ``(name, input, output)`` for every name present in both mappings.

    Names are yielded in sorted order. A warning is logged when some name
    exists in only one of the mappings, and another one when no name is
    shared at all.
    """
    paired = sorted(inputs.keys() & outputs.keys())
    pair_count = len(paired)

    # The shared names are a subset of both mappings, so a differing size
    # means at least one entry has no counterpart.
    if pair_count != len(inputs) or pair_count != len(outputs):
        logger.warning("dangling output case")

    if not paired:
        logger.warning("no cases found")

    yield from ((name, inputs[name], outputs[name]) for name in paired)
def merge_testcase_files(
    inputs: dict[str, pathlib.Path],
    outputs: dict[str, pathlib.Path],
) -> Iterator[TestCaseFile]:
    """Yield a ``TestCaseFile`` for each name that has both an input and an output path."""
    for case_name, in_path, out_path in enumerate_inouts(inputs, outputs):
        paired_file = TestCaseFile(
            name=case_name,
            input_path=in_path,
            output_path=out_path,
        )
        yield paired_file
468def _casename(path: pathlib.Path, *, directory: pathlib.Path) -> str:
469 return path.relative_to(directory).with_suffix("").as_posix()
def iter_testcases(
    *, directory: pathlib.Path, recursive: bool = False
) -> Iterator[TestCaseFile]:
    """Pair up the ``*.in`` and ``*.out`` files found in ``directory``.

    Args:
        directory: Directory to search.
        recursive: Also search all subdirectories when true.

    Returns:
        An iterator over the cases that have both an input and an output file.
    """
    prefix = "**/" if recursive else ""

    def collect(pattern: str) -> dict[str, pathlib.Path]:
        # Only regular files count; the key is the suffix-less relative path.
        return {
            _casename(found, directory=directory): found
            for found in directory.glob(prefix + pattern)
            if found.is_file()
        }

    inputs = collect("*.in")
    outputs = collect("*.out")
    return merge_testcase_files(inputs, outputs)
489def _name_to_filename(name: str, ext: str):
490 return pathlib.Path(name).with_suffix(f".{ext}").name
def save_testcases(samples: Iterable[TestCaseData], *, directory: pathlib.Path):
    """Write the input and output of every sample into ``directory``.

    Each sample produces ``<name>.in`` and ``<name>.out``; missing parent
    directories are created.
    """
    for sample in samples:
        payloads = {"in": sample.input_data, "out": sample.output_data}
        for ext, data in payloads.items():
            destination = directory / _name_to_filename(sample.name, ext)

            # NOTE(review): the collision is only logged; the file is still
            # overwritten below. Confirm whether that is intended.
            if destination.exists():
                logger.error(
                    "Failed to download since file already exists: %s", destination
                )
            destination.parent.mkdir(parents=True, exist_ok=True)
            destination.write_bytes(data)
            logger.debug("saved to: %s", destination)