Coverage for src / competitive_verifier / oj / problem.py: 82%

301 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-05 16:00 +0000

1import glob 

2import json 

3import os 

4import pathlib 

5import posixpath 

6import re 

7import subprocess 

8import sys 

9import urllib.parse 

10import zipfile 

11from abc import abstractmethod 

12from collections.abc import Iterable, Iterator 

13from dataclasses import dataclass 

14from io import BytesIO 

15from logging import getLogger 

16from typing import ClassVar, Optional, TypeVar 

17 

18import requests 

19 

20from competitive_verifier import config 

21from competitive_verifier.log import GitHubMessageParams 

22from competitive_verifier.models import ( 

23 Problem, 

24 TestCaseData, 

25 TestCaseFile, 

26 TestCaseProvider, 

27) 

28 

# Module-level logger named after this module.
logger = getLogger(name=__name__)

30 

31 

class NotLoggedInError(RuntimeError):
    """Raised when a judge needs an authenticated session that is not available."""

34 

35 

class _BaseProblem(Problem):
    """Common base for problems whose test cases are fetched from a judge.

    Subclasses implement :meth:`_download_cases`; this class caches the
    downloaded cases under ``test_directory`` and reads them back from there.
    """

    def iter_system_cases(self) -> Iterator[TestCaseFile]:
        """Iterate over the ``*.in``/``*.out`` pairs stored in ``test_directory``."""
        return iter_testcases(directory=self.test_directory)

    def download_system_cases(self) -> Iterable[TestCaseData] | bool:
        """Download and save the test cases unless they are already cached.

        Returns:
            ``True`` when ``test_directory`` already has entries (nothing is
            downloaded), ``False`` when the judge returned no cases, and
            otherwise the list of downloaded cases after saving them to
            ``test_directory``.
        """
        destination = self.test_directory

        already_cached = destination.exists() and any(destination.iterdir())
        if already_cached:
            logger.info("download:already exists: %s", self.url)
            return True

        self.problem_directory.mkdir(parents=True, exist_ok=True)

        downloaded = [*self._download_cases()]
        if len(downloaded) == 0:
            logger.error(
                "Sample not found",
                extra={"github": GitHubMessageParams()},
            )
            return False

        save_testcases(downloaded, directory=destination)
        return downloaded

    @abstractmethod
    def _download_cases(self) -> Iterable[TestCaseData]:
        """Fetch the test cases of this problem from the judge."""
        ...

65 

66 

class LibraryCheckerProblem(Problem):
    """A problem on Library Checker (https://judge.yosupo.jp).

    Test cases are generated locally: the ``yosupo06/library-checker-problems``
    repository is cloned into the cache directory, and its ``generate.py`` is
    run against this problem's ``info.toml``.
    """

    checker_exe_name: ClassVar[str] = (
        "checker.exe" if sys.platform == "win32" else "checker"
    )

    # Kept for backward compatibility: set on an instance once that instance
    # has synced the repository.
    _is_repository_updated = False
    # Repository paths already cloned/pulled by any instance in this process.
    # Every problem shares the same clone, so one sync per path is enough; the
    # per-instance flag alone re-ran `git pull` for every problem object.
    _updated_repositories: ClassVar[set[pathlib.Path]] = set()

    def __init__(self, *, problem_id: str):
        self.problem_id = problem_id
        # Resolved lazily by ``source_directory``.
        self._source_directory: pathlib.Path | None = None

    def __hash__(self) -> int:
        return hash((self.problem_id, self.repo_path))

    def __eq__(self, value: object) -> bool:
        if not isinstance(value, LibraryCheckerProblem):
            # Let Python try the reflected comparison, then fall back to identity.
            return NotImplemented
        return self.problem_id == value.problem_id and self.repo_path == value.repo_path

    @property
    def repo_path(self) -> pathlib.Path:
        """Location of the cloned ``library-checker-problems`` repository."""
        return config.get_cache_dir() / "library-checker-problems"

    def iter_system_cases(self) -> Iterator[TestCaseFile]:
        """Pair the generated ``in/*.in`` and ``out/*.out`` files by file stem."""
        inputs: dict[str, pathlib.Path] = {}
        outputs: dict[str, pathlib.Path] = {}
        for path in self.source_directory.glob("in/*.in"):
            inputs[path.stem] = path
        for path in self.source_directory.glob("out/*.out"):
            outputs[path.stem] = path
        return merge_testcase_files(inputs, outputs)

    def download_system_cases(self) -> bool:
        """Generate the test cases locally.

        Always returns ``True``; failures propagate as exceptions from
        :meth:`generate_test_cases`.
        """
        self.problem_directory.mkdir(parents=True, exist_ok=True)
        self.generate_test_cases()
        return True

    @property
    def checker(self) -> pathlib.Path | None:
        """Path of the checker executable inside the problem's source directory.

        NOTE(review): the file presumably exists only after generation; the
        existence is not checked here.
        """
        return self.source_directory / self.checker_exe_name

    def generate_test_cases(self) -> None:
        """Sync the repository and run its ``generate.py`` for this problem.

        Raises:
            subprocess.CalledProcessError: If ``generate.py`` (or a git
                command while syncing) fails.
        """
        self.update_cloned_repository()
        path = self.repo_path

        spec = str(self.source_directory / "info.toml")
        command = [sys.executable, str(path / "generate.py"), spec]
        logger.info("$ %s", " ".join(command))
        try:
            subprocess.check_call(command, stdout=sys.stderr, stderr=sys.stderr)
        except subprocess.CalledProcessError:
            logger.exception(
                "the generate.py failed: check https://github.com/yosupo06/library-checker-problems/issues",
                extra={"github": GitHubMessageParams()},
            )
            raise

    @property
    def source_directory(self) -> pathlib.Path:
        """Directory containing this problem's ``info.toml`` in the repository.

        Raises:
            RuntimeError: If zero or several ``info.toml`` files match.
        """
        if self._source_directory is None:
            problem_id = self.problem_id
            info_tomls = list(
                self.repo_path.glob(f"**/{glob.escape(problem_id)}/info.toml")
            )
            if len(info_tomls) != 1:
                raise RuntimeError(f"the problem {problem_id!r} not found or broken")
            self._source_directory = info_tomls[0].parent
        return self._source_directory

    @property
    def url(self) -> str:
        return f"https://judge.yosupo.jp/problem/{self.problem_id}"

    @classmethod
    def from_url(cls, url: str) -> Optional["LibraryCheckerProblem"]:
        """Parse a Library Checker problem URL; return ``None`` if it is not one."""
        # example: https://judge.yosupo.jp/problem/unionfind
        result = urllib.parse.urlparse(url)
        if result.scheme in ("", "http", "https") and result.netloc in (
            "judge.yosupo.jp",
            "old.yosupo.jp",
        ):
            m = re.match(r"/problem/(\w+)/?", result.path)
            if m:
                return cls(problem_id=m.group(1))
        return None

    def update_cloned_repository(self) -> None:
        """Clone the problems repository, or ``git pull`` it if already cloned.

        Runs at most once per repository path per process.

        Raises:
            FileNotFoundError: If the ``git`` command is not available.
            subprocess.CalledProcessError: If a git command fails.
        """
        path = self.repo_path
        if self._is_repository_updated or path in self._updated_repositories:
            return

        try:
            subprocess.check_call(
                ["git", "--version"],  # noqa: S607
                stdout=sys.stderr,
                stderr=sys.stderr,
            )
        except FileNotFoundError:
            logger.exception(
                "git command not found",
                exc_info=False,
                extra={"github": GitHubMessageParams()},
            )
            raise

        if not path.exists():
            # init the problem repository
            url = "https://github.com/yosupo06/library-checker-problems"
            logger.info("$ git clone %s %s", url, path)
            subprocess.check_call(
                ["git", "clone", url, str(path)],  # noqa: S607
                stdout=sys.stderr,
                stderr=sys.stderr,
            )
        else:
            # sync the problem repository
            logger.info("$ git -C %s pull", path)
            subprocess.check_call(
                ["git", "-C", str(path), "pull"],  # noqa: S607
                stdout=sys.stderr,
                stderr=sys.stderr,
            )

        self._is_repository_updated = True
        # Mutates the shared class-level set on purpose.
        self._updated_repositories.add(path)

191 

192 

193class _YukicoderProblemNo(int): 

194 def __new__(cls, value: int): 

195 return super().__new__(cls, value) 

196 

197 def __str__(self) -> str: 

198 return "no/" + super().__str__() 

199 

200 

201class _YukicoderProblemId(int): 

202 def __new__(cls, value: int): 

203 return super().__new__(cls, value) 

204 

205 

class YukicoderProblem(_BaseProblem):
    """A problem on yukicoder (https://yukicoder.me).

    The problem is addressed either by its public number (``/problems/no/N``)
    or by its internal id (``/problems/N``).
    """

    problem: _YukicoderProblemNo | _YukicoderProblemId

    def __init__(self, *, problem_no: int | None = None, problem_id: int | None = None):
        """Create the problem from ``problem_no`` (preferred) or ``problem_id``.

        Raises:
            ValueError: If neither ``problem_no`` nor ``problem_id`` is given.
        """
        if problem_no is not None:
            self.problem = _YukicoderProblemNo(problem_no)
        elif problem_id is not None:
            self.problem = _YukicoderProblemId(problem_id)
        else:
            raise ValueError("Needs problem_no or problem_id")

    def _download_cases(self) -> list[TestCaseData]:
        """Download the problem's test cases from its ``testcase.zip`` archive.

        When ``YUKICODER_TOKEN`` is set, it is sent as a bearer token.
        Archive entries under ``test_in/`` and ``test_out/`` are paired by
        file stem.

        Raises:
            NotLoggedInError: If the yukicoder top page still shows a login
                button for the request (e.g. ``YUKICODER_TOKEN`` is unset or
                invalid).
            requests.HTTPError: If downloading the archive fails.
        """
        headers: dict[str, str] | None = None
        if yukicoder_token := os.environ.get("YUKICODER_TOKEN"):
            headers = {"Authorization": f"Bearer {yukicoder_token}"}

        if not self._is_logged_in(headers=headers):
            raise NotLoggedInError("Required: $YUKICODER_TOKEN environment variable")
        url = f"{self.url}/testcase.zip"
        resp = requests.get(url, headers=headers, allow_redirects=True, timeout=10)
        # Without this, an error page would surface as a confusing BadZipFile.
        resp.raise_for_status()

        with zipfile.ZipFile(BytesIO(resp.content)) as fh:
            inputs: dict[str, bytes] = {}
            outputs: dict[str, bytes] = {}
            for filename in fh.namelist():
                if filename.endswith("/"):
                    # directory entry
                    continue
                file = fh.read(filename)
                path = pathlib.Path(filename)
                if filename.startswith("test_in/"):
                    inputs[path.stem] = file
                elif filename.startswith("test_out/"):
                    outputs[path.stem] = file
        return [
            TestCaseData(name=name, input_data=i, output_data=o)
            for name, i, o in enumerate_inouts(inputs, outputs)
        ]

    @property
    def url(self) -> str:
        return f"https://yukicoder.me/problems/{self.problem}"

    @classmethod
    def from_url(cls, url: str) -> Optional["YukicoderProblem"]:
        """Parse a yukicoder problem URL; return ``None`` if it is not one."""
        # example: https://yukicoder.me/problems/no/499
        # example: http://yukicoder.me/problems/1476
        result = urllib.parse.urlparse(url)
        dirname, basename = posixpath.split(_normpath(result.path))
        if result.scheme in ("", "http", "https") and result.netloc == "yukicoder.me":
            try:
                n = int(basename)
            except ValueError:
                pass
            else:
                if dirname == "/problems/no":
                    return cls(problem_no=n)
                if dirname == "/problems":
                    return cls(problem_id=n)
        return None

    def _is_logged_in(self, *, headers: dict[str, str] | None = None) -> bool:
        """Return whether the yukicoder top page lacks a login button for ``headers``.

        Raises:
            requests.HTTPError: If the top page request fails.
        """
        url = "https://yukicoder.me"
        resp = requests.get(url, headers=headers, allow_redirects=True, timeout=10)
        resp.raise_for_status()
        return "login-btn" not in str(resp.content)

276 

277 

class AOJProblem(_BaseProblem):
    """A problem on Aizu Online Judge, identified by its problem id."""

    def __init__(self, *, problem_id: str):
        self.problem_id = problem_id

    def _download_cases(self) -> Iterable[TestCaseData]:
        return AOJProblem.download_cases(self.problem_id)

    @staticmethod
    def download_cases(problem_id: str) -> Iterable[TestCaseData]:
        """Lazily yield every judge test case of ``problem_id`` from judgedat.

        Raises:
            requests.HTTPError: If any request returns an error status.
        """
        base = f"https://judgedat.u-aizu.ac.jp/testcases/{problem_id}"

        def fetch(target: str) -> requests.Response:
            response = requests.get(target, allow_redirects=True, timeout=10)
            response.raise_for_status()
            return response

        # get header
        # reference: http://developers.u-aizu.ac.jp/api?key=judgedat%2Ftestcases%2F%7BproblemId%7D%2Fheader_GET
        entries = json.loads(fetch(f"{base}/header").text)["headers"]

        # get testcases via the official API
        for entry in entries:
            # NOTE: the endpoints are not same to http://developers.u-aizu.ac.jp/api?key=judgedat%2Ftestcases%2F%7BproblemId%7D%2F%7Bserial%7D_GET since the json API often says "..... (terminated because of the limitation)"
            # NOTE: even when using https://judgedat.u-aizu.ac.jp/testcases/PROBLEM_ID/SERIAL, there is the 1G limit (see https://twitter.com/beet_aizu/status/1194947611100188672)
            case_url = f"{base}/{entry['serial']}"
            input_bytes = fetch(case_url + "/in").content
            output_bytes = fetch(case_url + "/out").content
            yield TestCaseData(entry["name"], input_bytes, output_bytes)

    @property
    def url(self) -> str:
        return f"http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id={self.problem_id}"

    @classmethod
    def from_url(cls, url: str) -> Optional["AOJProblem"]:
        """Parse an AOJ problem URL (old or new site); ``None`` if not one."""
        parsed = urllib.parse.urlparse(url)
        if parsed.scheme not in ("", "http", "https"):
            return None
        normalized = _normpath(parsed.path)

        # example: http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1169
        # example: http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_1_A&lang=jp
        if parsed.netloc == "judge.u-aizu.ac.jp":
            ids = urllib.parse.parse_qs(parsed.query).get("id")
            if normalized == "/onlinejudge/description.jsp" and ids and len(ids) == 1:
                return cls(problem_id=ids[0])
            return None

        if parsed.netloc != "onlinejudge.u-aizu.ac.jp":
            return None

        # example: https://onlinejudge.u-aizu.ac.jp/challenges/sources/JAG/Prelim/2881
        # example: https://onlinejudge.u-aizu.ac.jp/courses/library/4/CGL/3/CGL_3_B
        # example: https://onlinejudge.u-aizu.ac.jp/problems/0423
        # example: https://onlinejudge.u-aizu.ac.jp/problems/CGL_3_B
        patterns = (
            (r"^/(challenges|courses)/(sources|library/\d+|lesson/\d+)/(\w+)/(\w+)/(\w+)$", 5),
            (r"^/problems/(\w+)$", 1),
        )
        for pattern, group in patterns:
            found = re.match(pattern, normalized)
            if found:
                return cls(problem_id=found.group(group))
        return None

359 

360 

class AOJArenaProblem(_BaseProblem):
    """A problem in an AOJ arena, addressed by arena id and problem letter."""

    def __init__(self, *, arena_id: str, alphabet: str):
        """Create the arena problem.

        Raises:
            ValueError: If ``alphabet`` is not a single uppercase character.
        """
        if len(alphabet) != 1 or not alphabet.isupper():
            raise ValueError(arena_id, alphabet)
        self.arena_id = arena_id
        self.alphabet = alphabet

        # Cached result of get_problem_id().
        self._problem_id: str | None = None

    def get_problem_id(self) -> str:
        """Resolve, and cache, the AOJ problem id behind this arena letter.

        Raises:
            requests.HTTPError: If the arena API request fails.
            ValueError: If the arena lists no problem with this letter.
        """
        if self._problem_id is None:
            url = f"https://judgeapi.u-aizu.ac.jp/arenas/{self.arena_id}/problems"
            resp = requests.get(url, allow_redirects=True, timeout=10)
            resp.raise_for_status()
            problems = json.loads(resp.text)
            for problem in problems:
                if problem["id"] == self.alphabet:
                    p = problem["problemId"]
                    logger.debug("problem: %s", p)
                    self._problem_id = p
                    return p
            raise ValueError("Problem is not found.")
        return self._problem_id

    def _download_cases(self) -> Iterable[TestCaseData]:
        return AOJProblem.download_cases(self.get_problem_id())

    @property
    def url(self) -> str:
        return f"https://onlinejudge.u-aizu.ac.jp/services/room.html#{self.arena_id}/problems/{self.alphabet}"

    @classmethod
    def from_url(cls, url: str) -> Optional["AOJArenaProblem"]:
        """Parse an AOJ arena room URL; return ``None`` if it is not one.

        Room URLs whose fragment does not name a non-empty arena and a single
        letter are treated as unrecognized rather than raising ``ValueError``.
        """
        # example: https://onlinejudge.u-aizu.ac.jp/services/room.html#RitsCamp19Day2/problems/A
        result = urllib.parse.urlparse(url)
        if (
            result.scheme in ("", "http", "https")
            and result.netloc == "onlinejudge.u-aizu.ac.jp"
            and _normpath(result.path) == "/services/room.html"
        ):
            fragment = result.fragment.split("/")
            if len(fragment) == 3 and fragment[1] == "problems":  # noqa: PLR2004
                arena_id = fragment[0]
                alphabet = fragment[2].upper()
                # __init__ rejects anything but one uppercase letter; URL
                # detection must report "no match" instead of raising.
                if arena_id and len(alphabet) == 1 and alphabet.isupper():
                    return cls(arena_id=arena_id, alphabet=alphabet)
        return None

405 

406 

@dataclass
class LocalProblem(TestCaseProvider):
    """Test cases stored as ``*.in``/``*.out`` pairs anywhere under ``path``."""

    path: pathlib.Path

    def download_system_cases(self) -> Iterable[TestCaseData] | bool:
        """Nothing is downloaded; report whether at least one case is present."""
        for case in self.iter_system_cases():
            if case:
                return True
        return False

    def iter_system_cases(self) -> Iterable[TestCaseFile]:
        """Iterate over the cases found recursively below ``path``."""
        return iter_testcases(recursive=True, directory=self.path)

416 

417 

418def _normpath(path: str) -> str: 

419 """A wrapper of posixpath.normpath. 

420 

421 posixpath.normpath doesn't collapse a leading duplicated slashes. 

422 """ 

423 path = posixpath.normpath(path) 

424 if path.startswith("//"): 

425 path = "/" + path.lstrip("/") 

426 return path 

427 

428 

def _subclasses_recursive(cls: type[Problem]) -> Iterable[type[Problem]]:
    """Yield all subclasses of ``cls``.

    The direct subclasses of ``cls`` come first, followed by the descendants
    of each of them in turn. A class reachable via several bases may be
    yielded more than once.
    """
    direct = cls.__subclasses__()
    for sub in direct:
        yield sub
    for sub in direct:
        yield from _subclasses_recursive(sub)

433 

434 

def problem_from_url(url: str) -> Problem | None:
    """Build a problem from ``url`` using the first ``Problem`` subclass that accepts it.

    Returns:
        The problem, or ``None`` if no subclass's ``from_url`` recognizes the URL.
    """
    # dict.fromkeys drops duplicate classes while keeping discovery order. A
    # set would order classes by their identity-based hashes, making the
    # winner for overlapping URL patterns vary from run to run.
    for problem_cls in dict.fromkeys(_subclasses_recursive(Problem)):
        if (problem := problem_cls.from_url(url)) is not None:
            return problem
    return None

440 

441 

# Payload type of a case: file paths or raw bytes, depending on the caller.
_InOut = TypeVar("_InOut")


def enumerate_inouts(
    inputs: dict[str, _InOut],
    outputs: dict[str, _InOut],
) -> Iterator[tuple[str, _InOut, _InOut]]:
    """Pair inputs with outputs that share the same case name.

    Names present on only one side are skipped, and a warning names them.
    Logging happens on the first iteration, because this is a generator.

    Yields:
        ``(name, input, output)`` tuples in ascending name order.
    """
    common_keys = inputs.keys() & outputs.keys()

    # Report each side separately; a single "dangling output" message was
    # misleading when it was the output that was missing.
    dangling_inputs = sorted(inputs.keys() - common_keys)
    if dangling_inputs:
        logger.warning("dangling input case: %s", ", ".join(dangling_inputs))
    dangling_outputs = sorted(outputs.keys() - common_keys)
    if dangling_outputs:
        logger.warning("dangling output case: %s", ", ".join(dangling_outputs))

    if not common_keys:
        logger.warning("no cases found")

    for key in sorted(common_keys):
        yield (key, inputs[key], outputs[key])

458 

459 

def merge_testcase_files(
    inputs: dict[str, pathlib.Path],
    outputs: dict[str, pathlib.Path],
) -> Iterator[TestCaseFile]:
    """Lazily yield a ``TestCaseFile`` for each name present in both mappings."""
    yield from (
        TestCaseFile(name=case_name, input_path=in_path, output_path=out_path)
        for case_name, in_path, out_path in enumerate_inouts(inputs, outputs)
    )

466 

467 

468def _casename(path: pathlib.Path, *, directory: pathlib.Path) -> str: 

469 return path.relative_to(directory).with_suffix("").as_posix() 

470 

471 

def iter_testcases(
    *, directory: pathlib.Path, recursive: bool = False
) -> Iterator[TestCaseFile]:
    """Pair the ``*.in``/``*.out`` files under ``directory`` by case name.

    Args:
        directory: Directory to scan.
        recursive: When true, also scan all subdirectories. Case names are
            then relative paths such as ``sub/case``.
    """
    prefix = "**/" if recursive else ""

    def collect(extension: str) -> dict[str, pathlib.Path]:
        # Only regular files count; later duplicates overwrite earlier ones.
        return {
            _casename(found, directory=directory): found
            for found in directory.glob(f"{prefix}*.{extension}")
            if found.is_file()
        }

    input_files = collect("in")
    output_files = collect("out")
    return merge_testcase_files(input_files, output_files)

487 

488 

489def _name_to_filename(name: str, ext: str): 

490 return pathlib.Path(name).with_suffix(f".{ext}").name 

491 

492 

def save_testcases(samples: Iterable[TestCaseData], *, directory: pathlib.Path):
    """Write each sample's input and output data as files under ``directory``.

    File names come from ``_name_to_filename(sample.name, "in"/"out")``.

    NOTE(review): if a target file already exists, an error is logged but the
    file is still overwritten. Confirm whether it should be skipped instead.
    """
    for sample in samples:
        pairs = ((sample.input_data, "in"), (sample.output_data, "out"))
        for content, extension in pairs:
            target = directory / _name_to_filename(sample.name, extension)

            if target.exists():
                logger.error("Failed to download since file already exists: %s", target)
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_bytes(content)
            logger.debug("saved to: %s", target)