Coverage for src / competitive_verifier / oj / oj_test.py: 100%

210 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-04-26 12:38 +0900

1import math 

2import os 

3import pathlib 

4import platform 

5import shlex 

6import shutil 

7import subprocess 

8import sys 

9import tempfile 

10import time 

11from collections import Counter 

12from dataclasses import dataclass 

13from logging import getLogger 

14from typing import BinaryIO 

15 

16from competitive_verifier.log import GitHubMessageParams 

17from competitive_verifier.models import ( 

18 JudgeStatus, 

19 ResultStatus, 

20 TestCaseProvider, 

21 TestcaseResult, 

22 VerifcationTimeoutError, 

23 VerificationResult, 

24) 

25 

26from . import gnu 

27from .format import Printer, green, red 

28 

# Module-level logger used by every function in this file.
logger = getLogger(__name__)

30 

31 

class CaseExecutionError(Exception):
    """Raised when a command for a test case cannot be run.

    ``measure_command`` raises it for an empty command, for an executable
    that ``shutil.which`` cannot resolve, and (chained) for any error while
    running the process.
    """

34 

35 

36@dataclass 

37class OjExecInfo: 

38 answer: str | None 

39 """The standard output of the executed command""" 

40 elapsed: float 

41 """The elapsed time of the executed command in seconds""" 

42 memory: float | None 

43 """The maximum memory usage of the executed command in megabytes""" 

44 returncode: int | None 

45 """The returncode of the executed command""" 

46 

47 

def _kill_process(proc: subprocess.Popen[str], *, process_group: bool) -> None:
    """Forcibly stop ``proc``; with ``process_group``, also every process in its group.

    ``process_group`` must only be set when ``proc`` was started with
    ``start_new_session=True``. The child then leads its own process group
    (pgid == pid), so signalling that group also reaches processes it spawned,
    such as the program run under GNU time.
    """
    if process_group and os.name == "posix":
        import signal  # only needed on this POSIX-only path

        try:
            os.killpg(proc.pid, signal.SIGKILL)
        except OSError:
            # Best effort: the group may already be gone (ProcessLookupError)
            # or not be signalable; the direct child is still killed below.
            pass
    proc.kill()


def measure_command(
    command: list[str] | str,
    *,
    env: dict[str, str] | None = None,
    stdin: BinaryIO | int | None = None,
    timeout: float | None = None,
    gnu_time: bool = False,
) -> OjExecInfo:
    """Run a command until it exits or times out, and measure it.

    Args:
        command: The command as an argv list, or a string split with
            ``shlex.split``.
        env: Extra environment variables merged over ``os.environ``; ``None``
            inherits the current environment unchanged.
        stdin: Passed to the child process as its standard input.
        timeout: Time limit in seconds; ``None`` means no limit.
        gnu_time: Whether to wrap the command with ``gnu.GnuTimeWrapper`` so
            that memory usage can be reported.

    Returns:
        The decoded (UTF-8) standard output, the elapsed wall-clock seconds,
        the memory reported by the wrapper, and the return code. On timeout,
        ``answer`` and ``returncode`` are ``None``.

    Raises:
        CaseExecutionError: If the command is empty, ``command[0]`` cannot be
            resolved by ``shutil.which``, or starting/running the process
            raises (e.g. the executable cannot be run, or its output is not
            valid UTF-8).
    """
    if isinstance(command, str):
        command = shlex.split(command)

    if len(command) == 0:
        raise CaseExecutionError

    with gnu.GnuTimeWrapper(enabled=gnu_time) as gw:
        if shutil.which(command[0]) is None:
            raise CaseExecutionError

        command = gw.get_command(command)
        begin = time.perf_counter()

        # Under GNU time the program being measured is a grandchild of this
        # process. Start the child in a new session (its own process group)
        # so that on timeout the whole group is killed; killing only the
        # "time" process would leave the program running as an orphan.
        # see https://github.com/kmyk/online-judge-tools/issues/640
        start_new_session = gw.gnu_time is not None

        try:
            if env is not None:
                env = os.environ | env
            with subprocess.Popen(
                command,
                env=env,
                stdin=stdin,
                stdout=subprocess.PIPE,
                stderr=sys.stderr,
                encoding="utf-8",
                start_new_session=start_new_session,
            ) as proc:
                try:
                    answer, _ = proc.communicate(timeout=timeout)
                except subprocess.TimeoutExpired:
                    _kill_process(proc, process_group=start_new_session)
                    if sys.platform == "win32":
                        # As subprocess.run does after a timeout on Windows:
                        # let communicate() finish its reader thread before
                        # the pipe is closed.
                        proc.communicate()
                    else:
                        proc.wait()
                    answer = None
                    returncode = None
                except BaseException:
                    # As subprocess.run does: never leave the child running
                    # when communicate() fails; leaving the with-block reaps it.
                    _kill_process(proc, process_group=start_new_session)
                    raise
                else:
                    returncode = proc.returncode
        except Exception as e:
            logger.exception(
                "'%s' is not executable.",
                command,
                extra={"github": GitHubMessageParams()},
            )
            raise CaseExecutionError from e

        end = time.perf_counter()
        return OjExecInfo(
            answer=answer,
            elapsed=end - begin,
            memory=gw.get_memory(),
            returncode=returncode,
        )

106 

107 

@dataclass
class OjTestArguments:
    """Settings for one oj-test style run.

    Ported from ``onlinejudge_command.subcommand.test.add_subparser``.
    """

    # Command under test: an argv list, or a string that is split with shlex.
    command: str | list[str]
    # Provides the system test cases and the optional checker.
    problem: TestCaseProvider
    # Time limit per case in seconds (None: no limit).
    tle: float | None
    # Memory limit per case in megabytes (None: no limit).
    mle: float | None
    # Tolerance for floating-point tokens when comparing outputs (None: none).
    error: float | None
    # Extra environment variables, merged over os.environ when running.
    env: dict[str, str] | None = None
    # time.perf_counter() timestamp after which no further case is started.
    deadline: float = float("inf")

122 

123 

@dataclass
class OjTestcaseResult:
    """Judged result of a single test case."""

    # Name of the test case.
    name: str
    # Path of the file given to the program as standard input.
    input: pathlib.Path
    # Output produced by the program.
    answer: str
    # Path of the file holding the expected output.
    expected: pathlib.Path

    # Verdict for this case.
    status: JudgeStatus
    # Wall-clock running time in seconds.
    elapsed: float
    # Exit code of the program; any non-int value is normalized to None.
    exitcode: int | None

    # Peak memory usage in megabytes, if it was measured.
    memory: float | None = None

    def __post_init__(self):
        self.exitcode = self.exitcode if isinstance(self.exitcode, int) else None

    def __str__(self) -> str:
        """One-line summary: verdict, time, and memory / return code when present."""
        if self.status == JudgeStatus.AC:
            verdict = green("AC")
        else:
            verdict = red(self.status.name)
        fields = [f"{self.name}: {verdict}", f"time: {self.elapsed:f} sec"]
        if self.memory is not None:
            fields.append(f"memory: {self.memory:f} MB")
        # A zero or missing return code is not printed.
        if self.exitcode:
            fields.append(f"return code: {self.exitcode}")
        return ", ".join(fields)

    def log(self):
        """Log diagnostics for this case, then its one-line summary.

        Nothing extra is logged for AC. RE and TLE log the input and the
        expected output; every other verdict also logs the program's answer.
        """
        if self.status != JudgeStatus.AC:
            self._log_input()
            if self.status not in (JudgeStatus.RE, JudgeStatus.TLE):
                self._log_answer()
            self._log_expected()
        logger.info(self)

    def _log_input(self) -> None:
        logger.info("%s:input: %s", self.name, Printer(self.input))

    def _log_expected(self) -> None:
        logger.info("%s:expected: %s", self.name, Printer(self.expected))

    def _log_answer(self) -> None:
        logger.info("%s:answer: %s", self.name, Printer(self.answer))

178 

179 

@dataclass
class OjTestResult:
    """Aggregated outcome of running every test case of a problem."""

    # True when every case was judged AC.
    is_success: bool

    # Total elapsed time over all cases, in seconds.
    elapsed: float

    # Largest elapsed time of a single case, in seconds
    # (summarize leaves it at -1.0 when there were no cases).
    slowest: float

    # Largest peak memory usage of a single case, in megabytes
    # (summarize leaves it at -1.0 when no case reported memory).
    heaviest: float

    # Per-case results, in the order they were run.
    testcases: list[OjTestcaseResult]

195 

196 

197def _try_parse_float(value: str) -> float | None: 

198 try: 

199 return float(value) 

200 except ValueError: 

201 return None 

202 

203 

204def _equal_or_closed_float(actual: str, expected: str, *, error: float) -> bool: 

205 if actual == expected: 

206 return True 

207 

208 x = _try_parse_float(actual) 

209 y = _try_parse_float(expected) 

210 

211 return ( 

212 x is not None 

213 and y is not None 

214 and math.isclose(x, y, rel_tol=error, abs_tol=error) 

215 ) 

216 

217 

218def compare_answer(actual: str, expected: str, *, error: float | None) -> bool: 

219 """Compare two byte strings. 

220 

221 Args: 

222 actual (bytes): Actual output 

223 expected (bytes): Expected output 

224 error (float | None): Margin of error 

225 Returns: 

226 bool: True if they are considered equal 

227 """ 

228 actual = actual.replace("\r\n", "\n") 

229 expected = expected.replace("\r\n", "\n") 

230 

231 # match 

232 if actual == expected: 

233 return True 

234 

235 try: 

236 if error is None: 

237 actual_words = actual.split() 

238 expected_words = expected.split() 

239 if all(x == y for x, y in zip(actual_words, expected_words, strict=True)): 

240 logger.warning("This was AC if spaces and newlines were ignored.") 

241 return False 

242 

243 actual_lines = actual.rstrip("\n").split("\n") 

244 expected_lines = expected.rstrip("\n").split("\n") 

245 

246 for actual_line, expected_line in zip( 

247 actual_lines, expected_lines, strict=True 

248 ): 

249 actual_words = actual_line.split() 

250 expected_words = expected_line.split() 

251 

252 for x, y in zip(actual_words, expected_words, strict=True): 

253 if not _equal_or_closed_float(x, y, error=error): 

254 return False 

255 except ValueError: 

256 return False 

257 

258 return True 

259 

260 

def special_judge(
    judge_command: str,
    output: str,
    *,
    input_path: pathlib.Path,
    expected_output_path: pathlib.Path | None,
) -> bool:
    """Judge ``output`` with an external checker program.

    The checker is run as ``<judge_command...> <input> <actual> <expected>``,
    where ``<actual>`` is a temporary file containing ``output`` and
    ``<expected>`` is an empty string when ``expected_output_path`` is None.

    Args:
        judge_command: Checker command line, split with ``shlex.split``.
        output: The program's output to be judged.
        input_path: Input file of the test case.
        expected_output_path: Expected output file, if any.

    Returns:
        True iff the checker exits with return code 0.

    Raises:
        CaseExecutionError: Propagated from ``measure_command`` when the
            checker cannot be run.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        actual_output_path = pathlib.Path(tempdir) / "actual.out"
        # measure_command decodes the program's stdout as UTF-8, so write it
        # back as UTF-8 too; the locale default encoding could mangle (or fail
        # on) non-ASCII output.
        actual_output_path.write_text(output, encoding="utf-8")

        expected_arg = (
            str(expected_output_path.resolve())
            if expected_output_path is not None
            else ""
        )
        command = [
            *shlex.split(judge_command),
            str(input_path.resolve()),
            str(actual_output_path.resolve()),
            expected_arg,
        ]

        logger.debug("$ %s", command)
        info = measure_command(command)
        logger.debug("judge's output: %s", Printer(info.answer or ""))
        return info.returncode == 0

287 

288 

def determine_status(
    *,
    exitcode: int | None,
    memory: float | None,
    mle: float | None,
    match_result: bool | None,
) -> JudgeStatus:
    """Map the observations of one run to a verdict.

    Precedence is TLE (no exit code) > MLE (memory above ``mle``) >
    RE (nonzero exit code) > WA (``match_result`` is falsy but not None) > AC.
    A ``match_result`` of None means "not judged" and never yields WA.
    """
    if exitcode is None:
        return JudgeStatus.TLE
    if mle is not None and memory is not None and memory > mle:
        return JudgeStatus.MLE
    if exitcode != 0:
        return JudgeStatus.RE
    accepted = match_result is None or match_result
    return JudgeStatus.AC if accepted else JudgeStatus.WA

305 

306 

def single_case(
    test_name: str,
    test_input_path: pathlib.Path,
    test_output_path: pathlib.Path,
    *,
    args: OjTestArguments,
) -> OjTestcaseResult:
    """Run the program on one test case and judge the result.

    Args:
        test_name: Name of the case, used in logs and in the result.
        test_input_path: File fed to the program as standard input.
        test_output_path: Expected output, passed to the checker or compared
            directly with ``compare_answer``.
        args: Run settings (command, env, limits, tolerance, problem).

    Returns:
        The judged result, which is also logged. If the program or the
        checker cannot be run at all, an RE result with exit code 255 and zero
        elapsed time is returned instead of raising.
    """
    try:
        logger.info("%s: start", test_name)

        # run the binary
        with test_input_path.open("rb") as infp:
            info = measure_command(
                args.command,
                env=args.env,
                stdin=infp,
                timeout=args.tle,
                gnu_time=True,
            )
        answer = info.answer or ""
        elapsed: float = info.elapsed
        memory: float | None = info.memory

        # Judge the output only after a normal exit: determine_status returns
        # TLE/MLE/RE regardless of match_result when the return code is None
        # or nonzero. Running the (possibly slow) checker there is wasted work,
        # and a checker failure would otherwise mask the real verdict.
        match_result: bool | None = None
        if info.returncode == 0:
            if args.problem.checker:
                match_result = special_judge(
                    str(args.problem.checker),
                    answer,
                    input_path=test_input_path,
                    expected_output_path=test_output_path,
                )
            else:
                # Decode as UTF-8, like measure_command decodes the program's
                # stdout, rather than with the platform's locale encoding.
                expected = test_output_path.read_text(encoding="utf-8")
                match_result = compare_answer(answer, expected, error=args.error)

        status = determine_status(
            exitcode=info.returncode,
            memory=memory,
            mle=args.mle,
            match_result=match_result,
        )

        result = OjTestcaseResult(
            name=test_name,
            input=test_input_path,
            expected=test_output_path,
            answer=answer,
            status=status,
            exitcode=info.returncode,
            elapsed=elapsed,
            memory=memory,
        )
    except CaseExecutionError:
        logger.exception(
            "Failed to run: %s",
            args,
            extra={"github": GitHubMessageParams()},
        )
        return OjTestcaseResult(
            name=test_name,
            input=test_input_path,
            expected=test_output_path,
            answer="",
            status=JudgeStatus.RE,
            exitcode=255,
            elapsed=0,
            memory=None,
        )
    else:
        result.log()
        return result

377 

378 

def gnu_time_message(args: OjTestArguments):
    """Check whether GNU time is available and log hints when it is not.

    If ``gnu.time_command()`` finds nothing, log an install hint on macOS and
    warn when a memory limit (``args.mle``) was requested. Memory is
    presumably measured through GNU time, so the limit likely cannot be
    enforced without it.
    """
    if gnu.time_command() is not None:
        return

    if platform.system() == "Darwin":
        logger.info(
            "[HINT]: You can install GNU time with: $ brew install gnu-time",
            extra={"github": GitHubMessageParams()},
        )
    if args.mle is not None:
        logger.warning(
            "--mle is used but GNU time does not exist",
            extra={"github": GitHubMessageParams()},
        )

395 

396 

397class _StatusCounter(Counter[JudgeStatus]): 

398 def __str__(self) -> str: 

399 return ", ".join( 

400 f"{cnt} {name}" 

401 for name, cnt in ((st.name, self.get(st)) for st in JudgeStatus) 

402 if cnt 

403 ) 

404 

405 

def summarize(history: list[OjTestcaseResult]):
    """Aggregate per-case results, log a summary, and return an OjTestResult.

    ``slowest`` and ``heaviest`` stay at -1.0 when there is nothing to
    measure (no cases / no memory reported). On ties, the first case wins.
    The run counts as a success when every case is AC.
    """
    counter = _StatusCounter()
    total_elapsed: float = 0.0
    slowest: float = -1.0
    slowest_name: str | None = None
    heaviest: float = -1.0
    heaviest_name: str | None = None

    for case in history:
        counter[case.status] += 1
        total_elapsed += case.elapsed
        if case.elapsed > slowest:
            slowest, slowest_name = case.elapsed, case.name
        if case.memory is not None and case.memory > heaviest:
            heaviest, heaviest_name = case.memory, case.name

    # print the summary
    if slowest_name is not None:
        logger.info("slowest: %f sec (for %s)", slowest, slowest_name)
    if heaviest_name is not None:
        logger.info("max memory: %f MB (for %s)", heaviest, heaviest_name)

    case_count = len(history)
    all_accepted = counter[JudgeStatus.AC] == case_count
    if not all_accepted:
        logger.info("%s %s / %d cases", red("FAILURE"), counter, case_count)
    else:
        logger.info("%s %d cases", green("SUCCESS"), case_count)

    return OjTestResult(
        is_success=all_accepted,
        slowest=slowest,
        elapsed=total_elapsed,
        heaviest=heaviest,
        testcases=history,
    )

444 

445 

def _run(args: OjTestArguments) -> OjTestResult:
    """Run every system test case of ``args.problem`` and summarize the outcome.

    Raises:
        VerifcationTimeoutError: If ``args.deadline`` (compared against
            ``time.perf_counter()``) has passed before a case starts.
    """
    gnu_time_message(args)

    tolerance = args.error
    if tolerance is not None and tolerance > 1:
        logger.warning(
            "the tolerance is too large: relative = %s",
            tolerance,
            extra={"github": GitHubMessageParams()},
        )

    # Collect every case up front, before any of them is executed.
    cases = [*args.problem.iter_system_cases()]

    history: list[OjTestcaseResult] = []
    for case in cases:
        # The deadline is checked between cases, never in the middle of one.
        now = time.perf_counter()
        if now > args.deadline:
            raise VerifcationTimeoutError
        outcome = single_case(case.name, case.input_path, case.output_path, args=args)
        history.append(outcome)

    return summarize(history)

467 

468 

def main(
    *,
    problem: TestCaseProvider,
    command: str | list[str],
    env: dict[str, str] | None,
    tle: float | None,
    mle: float | None,
    error: float | None,
    deadline: float = float("inf"),
) -> VerificationResult:
    """Verify ``command`` against the system test cases of ``problem``.

    Builds an ``OjTestArguments`` from the keyword arguments, runs every case,
    and converts the outcome into a ``VerificationResult``. Any
    ``VerifcationTimeoutError`` raised because ``deadline`` passed propagates
    to the caller.
    """
    outcome = _run(
        OjTestArguments(
            command=command,
            problem=problem,
            env=env,
            tle=tle,
            mle=mle,
            error=error,
            deadline=deadline,
        )
    )

    if outcome.is_success:
        status = ResultStatus.SUCCESS
    else:
        status = ResultStatus.FAILURE

    cases = [
        TestcaseResult(
            name=tc.name,
            elapsed=tc.elapsed,
            memory=tc.memory,
            status=tc.status,
        )
        for tc in outcome.testcases
    ]

    return VerificationResult(
        status=status,
        elapsed=outcome.elapsed,
        slowest=outcome.slowest,
        heaviest=outcome.heaviest,
        testcases=cases,
    )