Coverage for src / competitive_verifier / oj / oj_test.py: 100%

211 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-05 16:00 +0000

1import contextlib 

2import math 

3import os 

4import pathlib 

5import platform 

6import shlex 

7import shutil 

8import signal 

9import subprocess 

10import sys 

11import tempfile 

12import time 

13from collections import Counter 

14from dataclasses import dataclass 

15from logging import getLogger 

16from typing import BinaryIO 

17 

18from competitive_verifier.log import GitHubMessageParams 

19from competitive_verifier.models import ( 

20 JudgeStatus, 

21 ResultStatus, 

22 TestCaseProvider, 

23 TestcaseResult, 

24 VerifcationTimeoutError, 

25 VerificationResult, 

26) 

27 

28from . import gnu 

29from .format import Printer, green, red 

30 

# Module-level logger, named after this module via ``__name__``.
logger = getLogger(__name__)

32 

33 

class CaseExecutionError(Exception):
    """Raised when the command for a test case cannot be executed."""

36 

37 

38@dataclass 

39class OjExecInfo: 

40 answer: str | None 

41 """The standard output of the executed command""" 

42 elapsed: float 

43 """The elapsed time of the executed command in seconds""" 

44 memory: float | None 

45 """The maximum memory usage of the executed command in megabytes""" 

46 returncode: int | None 

47 """The returncode of the executed command""" 

48 

49 

def measure_command(
    command: list[str] | str,
    *,
    env: dict[str, str] | None = None,
    stdin: BinaryIO | int | None = None,
    timeout: float | None = None,
    gnu_time: bool = False,
) -> OjExecInfo:
    """Run a command once and record its output, elapsed time and memory.

    Args:
        command: Argument vector, or a single string that is split with
            ``shlex.split``.
        env: Extra environment variables. When non-empty they are merged on
            top of ``os.environ``; otherwise the value is passed to
            ``subprocess.Popen`` unchanged.
        stdin: Passed to ``subprocess.Popen`` as the child's stdin.
        timeout: Passed to ``Popen.communicate``; ``None`` means no limit.
        gnu_time: Forwarded to ``gnu.GnuTimeWrapper(enabled=...)``.

    Returns:
        The captured stdout (``None`` when the timeout expired), the elapsed
        wall-clock time, the memory reported by the wrapper and
        ``proc.returncode``.

    Raises:
        CaseExecutionError: If the command is empty, its executable is not
            found by ``shutil.which``, or the process cannot be started.
    """
    argv = shlex.split(command) if isinstance(command, str) else command
    if not argv:
        raise CaseExecutionError

    with gnu.GnuTimeWrapper(enabled=gnu_time) as gw:
        if shutil.which(argv[0]) is None:
            raise CaseExecutionError

        argv = gw.get_command(argv)
        begin = time.perf_counter()

        # Processes spawned through the "time" command have to be killed via
        # their process group; otherwise they are left behind as orphans.
        # See https://github.com/kmyk/online-judge-tools/issues/640
        start_new_session = gnu.time_command() is not None and os.name == "posix"

        try:
            child_env = (os.environ | env) if env else env
            proc = subprocess.Popen(
                argv,
                stdin=stdin,
                stdout=subprocess.PIPE,
                env=child_env,
                stderr=sys.stderr,
                encoding="utf-8",
                start_new_session=start_new_session,
            )
        except Exception as e:
            logger.exception(
                "'%s' is not executable.",
                argv,
                extra={"github": GitHubMessageParams()},
            )
            raise CaseExecutionError from e

        # NOTE(review): after a timeout the child is only signalled, not
        # waited for, so ``proc.returncode`` presumably stays ``None`` --
        # confirm before adding a ``wait()`` here.
        try:
            answer, _ = proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            answer = None
        finally:  # pragma: no cover
            if not start_new_session:
                proc.terminate()
            else:
                with contextlib.suppress(ProcessLookupError):
                    os.killpg(os.getpgid(proc.pid), signal.SIGTERM)

        elapsed = time.perf_counter() - begin
        return OjExecInfo(
            answer=answer,
            elapsed=elapsed,
            memory=gw.get_memory(),
            returncode=proc.returncode,
        )

112 

113 

@dataclass
class OjTestArguments:
    """Settings for one oj-test run.

    Port of onlinejudge_command.subcommand.test.add_subparser.
    """

    # Command to execute for every test case (string or argument vector).
    command: str | list[str]
    # Source of the test cases and of the optional checker.
    problem: TestCaseProvider
    # Time limit handed to the command runner as its timeout.
    tle: float | None
    # Memory limit compared against the measured memory usage.
    mle: float | None
    # Tolerance used when comparing floating point outputs.
    error: float | None
    # Extra environment variables for the executed command.
    env: dict[str, str] | None = None
    # ``time.perf_counter()`` value after which no further case is started.
    deadline: float = math.inf

128 

129 

@dataclass
class OjTestcaseResult:
    """Verdict and measurements for a single test case."""

    name: str
    """Name of the test case."""
    input: pathlib.Path
    """Path to the input of the test case."""
    answer: str
    """Output recorded for the test case."""
    expected: pathlib.Path
    """Path to the expected output of the test case."""

    status: JudgeStatus
    elapsed: float
    exitcode: int | None

    memory: float | None = None

    def __post_init__(self) -> None:
        # Anything that is not an int is normalised to "no exit code".
        if not isinstance(self.exitcode, int):
            self.exitcode = None

    def __str__(self) -> str:
        """Return a one-line summary: verdict, time, then optional memory and return code."""
        if self.status == JudgeStatus.AC:
            verdict = green("AC")
        else:
            verdict = red(self.status.name)

        parts = [
            f"{self.name}: {verdict}",
            f"time: {self.elapsed:f} sec",
        ]
        if self.memory is not None:
            parts.append(f"memory: {self.memory:f} MB")
        # A return code of 0 (or a missing one) is left out of the summary.
        if self.exitcode:
            parts.append(f"return code: {self.exitcode}")
        return ", ".join(parts)

    def log(self):
        """Log details for non-AC cases, then the one-line summary."""
        if self.status != JudgeStatus.AC:
            self._log_input()
            # The answer is not logged for RE and TLE.
            if self.status not in (JudgeStatus.RE, JudgeStatus.TLE):
                self._log_answer()
            self._log_expected()
        logger.info(self)

    def _log_input(self) -> None:
        logger.info("%s:input: %s", self.name, Printer(self.input))

    def _log_expected(self) -> None:
        logger.info("%s:expected: %s", self.name, Printer(self.expected))

    def _log_answer(self) -> None:
        logger.info("%s:answer: %s", self.name, Printer(self.answer))

184 

185 

@dataclass
class OjTestResult:
    """Aggregate result over a list of test cases."""

    is_success: bool

    elapsed: float

    slowest: float
    """Largest elapsed time among the cases, in seconds."""

    heaviest: float
    """Largest memory usage among the cases, in MB."""

    testcases: list[OjTestcaseResult]

201 

202 

203def _try_parse_float(value: str) -> float | None: 

204 try: 

205 return float(value) 

206 except ValueError: 

207 return None 

208 

209 

210def _equal_or_closed_float(actual: str, expected: str, *, error: float) -> bool: 

211 if actual == expected: 

212 return True 

213 

214 x = _try_parse_float(actual) 

215 y = _try_parse_float(expected) 

216 

217 return ( 

218 x is not None 

219 and y is not None 

220 and math.isclose(x, y, rel_tol=error, abs_tol=error) 

221 ) 

222 

223 

224def compare_answer(actual: str, expected: str, *, error: float | None) -> bool: 

225 """Compare two byte strings. 

226 

227 Args: 

228 actual (bytes): Actual output 

229 expected (bytes): Expected output 

230 error (float | None): Margin of error 

231 Returns: 

232 bool: True if they are considered equal 

233 """ 

234 actual = actual.replace("\r\n", "\n") 

235 expected = expected.replace("\r\n", "\n") 

236 

237 # match 

238 if actual == expected: 

239 return True 

240 

241 try: 

242 if error is None: 

243 actual_words = actual.split() 

244 expected_words = expected.split() 

245 if all(x == y for x, y in zip(actual_words, expected_words, strict=True)): 

246 logger.warning("This was AC if spaces and newlines were ignored.") 

247 return False 

248 

249 actual_lines = actual.rstrip("\n").split("\n") 

250 expected_lines = expected.rstrip("\n").split("\n") 

251 

252 for actual_line, expected_line in zip( 

253 actual_lines, expected_lines, strict=True 

254 ): 

255 actual_words = actual_line.split() 

256 expected_words = expected_line.split() 

257 

258 for x, y in zip(actual_words, expected_words, strict=True): 

259 if not _equal_or_closed_float(x, y, error=error): 

260 return False 

261 except ValueError: 

262 return False 

263 

264 return True 

265 

266 

def special_judge(
    judge_command: str,
    output: str,
    *,
    input_path: pathlib.Path,
    expected_output_path: pathlib.Path | None,
) -> bool:
    """Run an external checker and report whether it accepted ``output``.

    ``output`` is written to a temporary file and the checker is invoked as
    ``<judge_command> <input> <actual> <expected>``, where ``<expected>`` is
    an empty string when no expected output path is given.

    Args:
        judge_command: Checker command line; split with ``shlex.split``.
        output: Output text to be judged.
        input_path: Path of the test case input.
        expected_output_path: Path of the expected output, if any.

    Returns:
        True if the checker exits with return code 0.

    Raises:
        CaseExecutionError: Propagated from ``measure_command`` when the
            checker cannot be executed.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        actual_output_path = pathlib.Path(tempdir) / "actual.out"
        # measure_command decodes stdout as UTF-8, so write the text back
        # with the same encoding instead of the locale's default; otherwise
        # non-ASCII output may be re-encoded differently or fail to encode.
        actual_output_path.write_text(output, encoding="utf-8")

        expected_arg = (
            str(expected_output_path.resolve())
            if expected_output_path is not None
            else ""
        )
        command = [
            *shlex.split(judge_command),
            str(input_path.resolve()),
            str(actual_output_path.resolve()),
            expected_arg,
        ]

        logger.debug("$ %s", command)
        info = measure_command(command)
        logger.debug("judge's output: %s", Printer(info.answer or ""))
        return info.returncode == 0

293 

294 

def determine_status(
    *,
    exitcode: int | None,
    memory: float | None,
    mle: float | None,
    match_result: bool | None,
) -> JudgeStatus:
    """Derive the verdict of one run.

    Checks are applied in priority order: TLE (no exit code), MLE (memory
    above the limit, when both are known), RE (non-zero exit code), WA
    (output judged as not matching), otherwise AC.
    """
    if exitcode is None:
        return JudgeStatus.TLE

    memory_exceeded = memory is not None and mle is not None and memory > mle
    if memory_exceeded:
        return JudgeStatus.MLE

    if exitcode != 0:
        return JudgeStatus.RE

    # ``None`` means "no comparison result" and is not treated as a mismatch.
    if match_result is None or match_result:
        return JudgeStatus.AC
    return JudgeStatus.WA

311 

312 

def single_case(
    test_name: str,
    test_input_path: pathlib.Path,
    test_output_path: pathlib.Path,
    *,
    args: OjTestArguments,
) -> OjTestcaseResult:
    """Run one test case and judge its output.

    The command is run with the input file as stdin. Its output is judged by
    the problem's checker when one is set, otherwise by ``compare_answer``
    against the expected output file.

    Returns:
        The result of the case. When the command (or the checker) cannot be
        executed, an RE result with exit code 255 is returned instead.
    """
    try:
        logger.info("%s: start", test_name)

        # Run the binary with the test input on stdin.
        with test_input_path.open("rb") as infp:
            info = measure_command(
                args.command,
                env=args.env,
                stdin=infp,
                timeout=args.tle,
                gnu_time=True,
            )
        answer = info.answer or ""
        elapsed: float = info.elapsed
        memory: float | None = info.memory

        if args.problem.checker:
            match_result = special_judge(
                str(args.problem.checker),
                answer,
                input_path=test_input_path,
                expected_output_path=test_output_path,
            )
        else:
            match_result = compare_answer(
                answer, test_output_path.read_text(), error=args.error
            )

        status = determine_status(
            exitcode=info.returncode,
            memory=memory,
            mle=args.mle,
            match_result=match_result,
        )

        result = OjTestcaseResult(
            name=test_name,
            input=test_input_path,
            expected=test_output_path,
            answer=answer,
            status=status,
            exitcode=info.returncode,
            elapsed=elapsed,
            memory=memory,
        )
    except CaseExecutionError:
        logger.exception(
            "Failed to run: %s",
            args,
            extra={"github": GitHubMessageParams()},
        )
        return OjTestcaseResult(
            name=test_name,
            input=test_input_path,
            expected=test_output_path,
            answer="",
            status=JudgeStatus.RE,
            exitcode=255,
            elapsed=0,
            memory=None,
        )

    # Only reached when the case ran without CaseExecutionError.
    result.log()
    return result

383 

384 

def gnu_time_message(args: OjTestArguments):
    """Check whether GNU time is available.

    Nothing is logged when ``gnu.time_command()`` returns a command.
    Otherwise an install hint is logged on macOS, and a warning is logged
    when a memory limit is set.
    """
    if gnu.time_command() is not None:
        return

    if platform.system() == "Darwin":
        logger.info(
            "[HINT]: You can install GNU time with: $ brew install gnu-time",
            extra={"github": GitHubMessageParams()},
        )

    if args.mle is not None:
        logger.warning(
            "--mle is used but GNU time does not exist",
            extra={"github": GitHubMessageParams()},
        )

401 

402 

class _StatusCounter(Counter[JudgeStatus]):
    """Counter of verdicts that renders as e.g. ``"2 WA, 1 TLE"``."""

    def __str__(self) -> str:
        # Walk the statuses in enum order and skip those never counted.
        parts: list[str] = []
        for status in JudgeStatus:
            count = self.get(status)
            if count:
                parts.append(f"{count} {status.name}")
        return ", ".join(parts)

410 

411 

def summarize(history: list[OjTestcaseResult]):
    """Aggregate per-case results, log a summary and return the totals.

    The run counts as a success when every case in ``history`` is AC.
    ``slowest`` and ``heaviest`` stay at ``-1.0`` when no case provided a
    value for them.
    """
    counter = _StatusCounter()
    total_elapsed: float = 0.0
    slowest: float = -1.0
    slowest_name: str | None = None
    heaviest: float = -1.0
    heaviest_name: str | None = None

    for case in history:
        counter[case.status] += 1
        total_elapsed += case.elapsed
        if case.elapsed > slowest:
            slowest, slowest_name = case.elapsed, case.name
        if case.memory is not None and case.memory > heaviest:
            heaviest, heaviest_name = case.memory, case.name

    # Print the summary.
    if slowest_name is not None:
        logger.info("slowest: %f sec (for %s)", slowest, slowest_name)
    if heaviest_name is not None:
        logger.info("max memory: %f MB (for %s)", heaviest, heaviest_name)

    total = len(history)
    is_success = counter[JudgeStatus.AC] == total
    if not is_success:
        logger.info("%s %s / %d cases", red("FAILURE"), counter, total)
    else:
        logger.info("%s %d cases", green("SUCCESS"), total)

    return OjTestResult(
        is_success=is_success,
        slowest=slowest,
        elapsed=total_elapsed,
        heaviest=heaviest,
        testcases=history,
    )

450 

451 

def _run(args: OjTestArguments) -> OjTestResult:
    """Run every system test case of the problem and summarise the results.

    Raises:
        VerifcationTimeoutError: If ``args.deadline`` has passed before a
            case is started.
    """
    gnu_time_message(args)

    tolerance = args.error
    if tolerance is not None and tolerance > 1:
        logger.warning(
            "the tolerance is too large: relative = %s",
            tolerance,
            extra={"github": GitHubMessageParams()},
        )

    # Collect all cases up front, then run them one by one.
    cases = list(args.problem.iter_system_cases())

    history: list[OjTestcaseResult] = []
    for case in cases:
        if time.perf_counter() > args.deadline:
            raise VerifcationTimeoutError

        outcome = single_case(
            case.name, case.input_path, case.output_path, args=args
        )
        history.append(outcome)

    return summarize(history)

473 

474 

def main(
    *,
    problem: TestCaseProvider,
    command: str | list[str],
    env: dict[str, str] | None,
    tle: float | None,
    mle: float | None,
    error: float | None,
    deadline: float = float("inf"),
) -> VerificationResult:
    """Run the test command on every case and build a ``VerificationResult``.

    The keyword arguments are bundled into ``OjTestArguments`` and handed to
    ``_run``; its per-case results are converted to ``TestcaseResult``
    entries.
    """
    result = _run(
        OjTestArguments(
            command=command,
            problem=problem,
            env=env,
            tle=tle,
            mle=mle,
            error=error,
            deadline=deadline,
        )
    )

    status = ResultStatus.SUCCESS if result.is_success else ResultStatus.FAILURE
    testcases = [
        TestcaseResult(
            name=case.name,
            elapsed=case.elapsed,
            memory=case.memory,
            status=case.status,
        )
        for case in result.testcases
    ]
    return VerificationResult(
        status=status,
        elapsed=result.elapsed,
        slowest=result.slowest,
        heaviest=result.heaviest,
        testcases=testcases,
    )

509 )