Coverage for src / competitive_verifier / oj / oj_test.py: 100%
211 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-05 16:00 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-05 16:00 +0000
1import contextlib
2import math
3import os
4import pathlib
5import platform
6import shlex
7import shutil
8import signal
9import subprocess
10import sys
11import tempfile
12import time
13from collections import Counter
14from dataclasses import dataclass
15from logging import getLogger
16from typing import BinaryIO
18from competitive_verifier.log import GitHubMessageParams
19from competitive_verifier.models import (
20 JudgeStatus,
21 ResultStatus,
22 TestCaseProvider,
23 TestcaseResult,
24 VerifcationTimeoutError,
25 VerificationResult,
26)
28from . import gnu
29from .format import Printer, green, red
# Module-level logger; records are emitted under this module's name.
logger = getLogger(name=__name__)
class CaseExecutionError(Exception):
    """Raised when a command needed for a test case cannot be executed."""
38@dataclass
39class OjExecInfo:
40 answer: str | None
41 """The standard output of the executed command"""
42 elapsed: float
43 """The elapsed time of the executed command in seconds"""
44 memory: float | None
45 """The maximum memory usage of the executed command in megabytes"""
46 returncode: int | None
47 """The returncode of the executed command"""
def _signal_process(
    proc: subprocess.Popen[str], *, group: bool, force: bool = False
) -> None:
    """Ask ``proc`` to stop (SIGTERM), or compel it with ``force`` (SIGKILL).

    When ``group`` is true the signal is sent to the child's whole process
    group; this is only used when the child was started with
    ``start_new_session=True`` (POSIX only).
    """
    if group:
        sig = signal.SIGKILL if force else signal.SIGTERM
        with contextlib.suppress(ProcessLookupError):
            os.killpg(os.getpgid(proc.pid), sig)
    elif force:
        proc.kill()
    else:
        proc.terminate()


def _reap_process(
    proc: subprocess.Popen[str], *, group: bool, grace: float = 2.0
) -> None:
    """Wait for an already-signalled ``proc`` to exit and drain its pipes.

    ``Popen.communicate`` does not kill the child when it times out, so the
    child has to be stopped and communication finished afterwards; otherwise
    it lingers as a zombie and its stdout pipe stays open. If it has not
    exited within ``grace`` seconds (e.g. it ignores SIGTERM), it is killed.
    """
    try:
        proc.communicate(timeout=grace)
    except subprocess.TimeoutExpired:
        _signal_process(proc, group=group, force=True)
        # Best effort: something outside the killed group could still hold the
        # pipe open; never block the verifier forever on it.
        with contextlib.suppress(subprocess.TimeoutExpired):
            proc.communicate(timeout=grace)


def measure_command(
    command: list[str] | str,
    *,
    env: dict[str, str] | None = None,
    stdin: BinaryIO | int | None = None,
    timeout: float | None = None,
    gnu_time: bool = False,
) -> OjExecInfo:
    """Run ``command`` and measure its elapsed time and, optionally, memory.

    Args:
        command: The command line, either already split or a string that is
            split with ``shlex.split``.
        env: Extra environment variables merged over ``os.environ``.
        stdin: Passed to ``subprocess.Popen`` as the child's standard input.
        timeout: Seconds to wait for the command; ``None`` waits forever.
        gnu_time: Wrap the command with ``gnu.GnuTimeWrapper`` so that its
            memory usage is reported.

    Returns:
        An ``OjExecInfo``. When the command times out, both ``answer`` and
        ``returncode`` are ``None``.

    Raises:
        CaseExecutionError: The command is empty, its executable is not found
            by ``shutil.which``, or the process cannot be started.
    """
    if isinstance(command, str):
        command = shlex.split(command)

    if len(command) == 0:
        raise CaseExecutionError

    with gnu.GnuTimeWrapper(enabled=gnu_time) as gw:
        if shutil.which(command[0]) is None:
            raise CaseExecutionError

        command = gw.get_command(command)
        begin = time.perf_counter()

        # We need kill processes called from the "time" command using process groups. Without this, orphans spawn. see https://github.com/kmyk/online-judge-tools/issues/640
        start_new_session = gnu.time_command() is not None and os.name == "posix"

        try:
            if env:
                env = os.environ | env
            proc = subprocess.Popen(
                command,
                stdin=stdin,
                stdout=subprocess.PIPE,
                env=env,
                stderr=sys.stderr,
                encoding="utf-8",
                # Invalid UTF-8 in the output must not abort the whole run with
                # UnicodeDecodeError; it is decoded with replacement characters
                # and judged like any other output.
                errors="replace",
                start_new_session=start_new_session,
            )
        except Exception as e:
            logger.exception(
                "'%s' is not executable.",
                command,
                extra={"github": GitHubMessageParams()},
            )
            raise CaseExecutionError from e

        timed_out = False
        try:
            answer, _ = proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            answer = None
            timed_out = True
        finally:  # pragma: no cover
            # Stop the clock before cleanup so that reaping a timed-out child
            # does not inflate its reported elapsed time.
            end = time.perf_counter()
            _signal_process(proc, group=start_new_session)
            if timed_out:
                _reap_process(proc, group=start_new_session)

        return OjExecInfo(
            answer=answer,
            elapsed=end - begin,
            memory=gw.get_memory(),
            # Reaping sets proc.returncode, but a ``None`` returncode is what
            # signals a timeout to callers (see determine_status), so report it
            # from the explicit flag instead of racing on proc.returncode.
            returncode=None if timed_out else proc.returncode,
        )
@dataclass
class OjTestArguments:
    """Parameters for oj-test command.

    Port of onlinejudge_command.subcommand.test.add_subparser.

    Attributes:
        command: The command under test (a string is split with shlex).
        problem: Provider of the test cases and the optional checker.
        tle: Time limit in seconds, used as the execution timeout.
        mle: Memory limit in megabytes; ``None`` disables the check.
        error: Allowed numeric error for output comparison, or ``None`` for
            an exact match.
        env: Extra environment variables for the command.
        deadline: ``time.perf_counter()`` value after which no further test
            case is started.
    """

    command: str | list[str]
    problem: TestCaseProvider
    tle: float | None
    mle: float | None
    error: float | None
    env: dict[str, str] | None = None
    deadline: float = float("inf")
@dataclass
class OjTestcaseResult:
    """Judged result of running the tested command on one test case."""

    name: str
    """Name of the test case."""
    input: pathlib.Path
    """Path of the input file of the test case."""
    answer: str
    """Output produced by the tested command for the test case."""
    expected: pathlib.Path
    """Path of the expected output file of the test case."""

    status: JudgeStatus
    """Verdict of the test case."""
    elapsed: float
    """Elapsed time in seconds."""
    exitcode: int | None
    """Exit code of the command; ``None`` when it is unknown (e.g. timeout)."""

    memory: float | None = None
    """Peak memory usage in megabytes, if it was measured."""

    def __post_init__(self):
        # Normalize anything that is not an int to "unknown".
        self.exitcode = self.exitcode if isinstance(self.exitcode, int) else None

    def __str__(self) -> str:
        verdict = (
            green("AC") if self.status == JudgeStatus.AC else red(self.status.name)
        )
        parts = [f"{self.name}: {verdict}", f"time: {self.elapsed:f} sec"]
        if self.memory is not None:
            parts.append(f"memory: {self.memory:f} MB")
        if self.exitcode:
            parts.append(f"return code: {self.exitcode}")
        return ", ".join(parts)

    def log(self):
        """Log diagnostics for a failed case, then the one-line summary."""
        if self.status != JudgeStatus.AC:
            self._log_input()
            # For RE/TLE the produced output is not worth showing.
            if self.status not in (JudgeStatus.RE, JudgeStatus.TLE):
                self._log_answer()
            self._log_expected()
        logger.info(self)

    def _log_input(self) -> None:
        logger.info("%s:input: %s", self.name, Printer(self.input))

    def _log_expected(self) -> None:
        logger.info("%s:expected: %s", self.name, Printer(self.expected))

    def _log_answer(self) -> None:
        logger.info("%s:answer: %s", self.name, Printer(self.answer))
@dataclass
class OjTestResult:
    """Aggregated result of running every test case of a problem."""

    is_success: bool
    """True when every test case was AC."""

    elapsed: float
    """Total elapsed time of all test cases [seconds]."""

    slowest: float
    """max time [seconds]; ``-1.0`` when there were no test cases."""

    heaviest: float
    """max memory [MB]; ``-1.0`` when no memory usage was measured."""

    testcases: list[OjTestcaseResult]
    """Per-case results, in execution order."""
203def _try_parse_float(value: str) -> float | None:
204 try:
205 return float(value)
206 except ValueError:
207 return None
210def _equal_or_closed_float(actual: str, expected: str, *, error: float) -> bool:
211 if actual == expected:
212 return True
214 x = _try_parse_float(actual)
215 y = _try_parse_float(expected)
217 return (
218 x is not None
219 and y is not None
220 and math.isclose(x, y, rel_tol=error, abs_tol=error)
221 )
224def compare_answer(actual: str, expected: str, *, error: float | None) -> bool:
225 """Compare two byte strings.
227 Args:
228 actual (bytes): Actual output
229 expected (bytes): Expected output
230 error (float | None): Margin of error
231 Returns:
232 bool: True if they are considered equal
233 """
234 actual = actual.replace("\r\n", "\n")
235 expected = expected.replace("\r\n", "\n")
237 # match
238 if actual == expected:
239 return True
241 try:
242 if error is None:
243 actual_words = actual.split()
244 expected_words = expected.split()
245 if all(x == y for x, y in zip(actual_words, expected_words, strict=True)):
246 logger.warning("This was AC if spaces and newlines were ignored.")
247 return False
249 actual_lines = actual.rstrip("\n").split("\n")
250 expected_lines = expected.rstrip("\n").split("\n")
252 for actual_line, expected_line in zip(
253 actual_lines, expected_lines, strict=True
254 ):
255 actual_words = actual_line.split()
256 expected_words = expected_line.split()
258 for x, y in zip(actual_words, expected_words, strict=True):
259 if not _equal_or_closed_float(x, y, error=error):
260 return False
261 except ValueError:
262 return False
264 return True
def special_judge(
    judge_command: str,
    output: str,
    *,
    input_path: pathlib.Path,
    expected_output_path: pathlib.Path | None,
) -> bool:
    """Run a special judge (checker) on the actual output of a test case.

    The checker is invoked as ``<judge_command...> INPUT ACTUAL EXPECTED``.
    ACTUAL is a temporary file holding ``output``. EXPECTED is an empty string
    when ``expected_output_path`` is ``None``.

    Args:
        judge_command: Checker command line, split with ``shlex.split``.
        output: Output produced by the tested program.
        input_path: Input file of the test case.
        expected_output_path: Expected output file, if any.

    Returns:
        True iff the checker exits with return code 0.

    Raises:
        CaseExecutionError: The checker could not be executed (propagated
            from ``measure_command``).
    """
    with tempfile.TemporaryDirectory() as tempdir:
        actual_output_path = pathlib.Path(tempdir) / "actual.out"
        # measure_command decodes program output as UTF-8, so write it back as
        # UTF-8 too. The locale default (e.g. cp1252/cp932 on Windows) could
        # mangle non-ASCII output or fail to encode it at all.
        actual_output_path.write_text(output, encoding="utf-8")

        command = [
            *shlex.split(judge_command),
            str(input_path.resolve()),
            str(actual_output_path.resolve()),
            str(
                expected_output_path.resolve()
                if expected_output_path is not None
                else ""
            ),
        ]

        logger.debug("$ %s", command)
        info = measure_command(command)
        logger.debug("judge's output: %s", Printer(info.answer or ""))
        return info.returncode == 0
def determine_status(
    *,
    exitcode: int | None,
    memory: float | None,
    mle: float | None,
    match_result: bool | None,
) -> JudgeStatus:
    """Decide the verdict of one test case.

    The first failing check wins, in the order TLE, MLE, RE, WA:

    * ``exitcode is None`` (the command timed out) is TLE;
    * ``memory`` above ``mle`` (both known) is MLE;
    * a non-zero ``exitcode`` is RE;
    * a ``match_result`` that is not ``None`` and falsy is WA.

    Otherwise the verdict is AC.
    """
    failed_checks = (
        (exitcode is None, JudgeStatus.TLE),
        (memory is not None and mle is not None and memory > mle, JudgeStatus.MLE),
        (exitcode != 0, JudgeStatus.RE),
        (match_result is not None and not match_result, JudgeStatus.WA),
    )
    return next(
        (verdict for failed, verdict in failed_checks if failed), JudgeStatus.AC
    )
def single_case(
    test_name: str,
    test_input_path: pathlib.Path,
    test_output_path: pathlib.Path,
    *,
    args: OjTestArguments,
) -> OjTestcaseResult:
    """Run the tested command on one test case and judge its output.

    Args:
        test_name: Name of the test case, used in logs and in the result.
        test_input_path: File fed to the command's standard input.
        test_output_path: Expected output. It is passed to the special judge
            when the problem has a checker, otherwise it is compared with
            ``compare_answer``.
        args: Test settings (command, limits, tolerance, problem).

    Returns:
        The judged result; it is logged via ``OjTestcaseResult.log``. If the
        command or the special judge cannot be executed, an RE result with
        exit code 255 is returned instead.
    """
    try:
        logger.info("%s: start", test_name)

        # run the binary
        with test_input_path.open("rb") as infp:
            info = measure_command(
                args.command,
                env=args.env,
                stdin=infp,
                timeout=args.tle,
                gnu_time=True,
            )
        answer = info.answer or ""
        elapsed: float = info.elapsed
        memory: float | None = info.memory

        # determine_status reports TLE/RE before it ever looks at
        # match_result. So the output of a run that did not exit cleanly is
        # never judged, and the checker is not spawned for it.
        match_result: bool | None = None
        if info.returncode == 0:
            if args.problem.checker:
                match_result = special_judge(
                    str(args.problem.checker),
                    answer,
                    input_path=test_input_path,
                    expected_output_path=test_output_path,
                )
            else:
                # The program output is decoded as UTF-8 by measure_command,
                # so read the expected output with the same encoding rather
                # than the locale default.
                match_result = compare_answer(
                    answer,
                    test_output_path.read_text(encoding="utf-8"),
                    error=args.error,
                )

        status = determine_status(
            exitcode=info.returncode,
            memory=memory,
            mle=args.mle,
            match_result=match_result,
        )

        result = OjTestcaseResult(
            name=test_name,
            input=test_input_path,
            expected=test_output_path,
            answer=answer,
            status=status,
            exitcode=info.returncode,
            elapsed=elapsed,
            memory=memory,
        )
    except CaseExecutionError:
        logger.exception(
            "Failed to run: %s",
            args,
            extra={"github": GitHubMessageParams()},
        )
        return OjTestcaseResult(
            name=test_name,
            input=test_input_path,
            expected=test_output_path,
            answer="",
            status=JudgeStatus.RE,
            exitcode=255,
            elapsed=0,
            memory=None,
        )
    else:
        result.log()
        return result
def gnu_time_message(args: OjTestArguments):
    """Check whether GNU time is available.

    When it is not, log an installation hint on macOS, and warn if a memory
    limit (``--mle``) was requested anyway.
    """
    if gnu.time_command() is not None:
        return

    if platform.system() == "Darwin":
        logger.info(
            "[HINT]: You can install GNU time with: $ brew install gnu-time",
            extra={"github": GitHubMessageParams()},
        )
    if args.mle is not None:
        logger.warning(
            "--mle is used but GNU time does not exist",
            extra={"github": GitHubMessageParams()},
        )
class _StatusCounter(Counter[JudgeStatus]):
    """Counter of judge statuses that renders as e.g. ``"3 AC, 1 WA"``.

    Statuses are listed in ``JudgeStatus`` declaration order; statuses with
    no (or a zero) count are omitted.
    """

    def __str__(self) -> str:
        parts: list[str] = []
        for status in JudgeStatus:
            count = self.get(status)
            if count:
                parts.append(f"{count} {status.name}")
        return ", ".join(parts)
def summarize(history: list[OjTestcaseResult]):
    """Aggregate per-case results, log a summary and build an ``OjTestResult``.

    ``slowest``/``heaviest`` stay at ``-1.0`` when nothing was measured, and
    the run is a success only if every case is AC (an empty ``history``
    therefore counts as a success).
    """
    counter = _StatusCounter()
    total_elapsed: float = 0.0
    max_time: float = -1.0
    max_time_case: str | None = None
    max_memory: float = -1.0
    max_memory_case: str | None = None

    for item in history:
        counter[item.status] += 1
        total_elapsed += item.elapsed
        if item.elapsed > max_time:
            max_time, max_time_case = item.elapsed, item.name
        if item.memory is not None and item.memory > max_memory:
            max_memory, max_memory_case = item.memory, item.name

    # print the summary
    if max_time_case is not None:
        logger.info("slowest: %f sec (for %s)", max_time, max_time_case)
    if max_memory_case is not None:
        logger.info("max memory: %f MB (for %s)", max_memory, max_memory_case)

    case_count = len(history)
    all_accepted = counter[JudgeStatus.AC] == case_count
    if all_accepted:
        logger.info("%s %d cases", green("SUCCESS"), case_count)
    else:
        logger.info("%s %s / %d cases", red("FAILURE"), counter, case_count)

    return OjTestResult(
        is_success=all_accepted,
        slowest=max_time,
        elapsed=total_elapsed,
        heaviest=max_memory,
        testcases=history,
    )
def _run(args: OjTestArguments) -> OjTestResult:
    """Run every system test case of ``args.problem`` and summarize them.

    Raises:
        VerifcationTimeoutError: ``time.perf_counter()`` passed
            ``args.deadline`` before all test cases were started.
    """
    gnu_time_message(args)

    tolerance = args.error
    if tolerance is not None and tolerance > 1:
        logger.warning(
            "the tolerance is too large: relative = %s",
            tolerance,
            extra={"github": GitHubMessageParams()},
        )

    cases = list(args.problem.iter_system_cases())

    # run tests
    results: list[OjTestcaseResult] = []
    for tc in cases:
        if time.perf_counter() > args.deadline:
            raise VerifcationTimeoutError
        results.append(single_case(tc.name, tc.input_path, tc.output_path, args=args))

    return summarize(results)
def main(
    *,
    problem: TestCaseProvider,
    command: str | list[str],
    env: dict[str, str] | None,
    tle: float | None,
    mle: float | None,
    error: float | None,
    deadline: float = float("inf"),
) -> VerificationResult:
    """Run the oj-test port for ``problem`` and convert it to a ``VerificationResult``.

    Parameters mirror the fields of ``OjTestArguments``.
    """
    outcome = _run(
        OjTestArguments(
            command=command,
            problem=problem,
            env=env,
            tle=tle,
            mle=mle,
            error=error,
            deadline=deadline,
        )
    )

    overall = ResultStatus.SUCCESS if outcome.is_success else ResultStatus.FAILURE
    per_case = [
        TestcaseResult(
            name=tc.name,
            elapsed=tc.elapsed,
            memory=tc.memory,
            status=tc.status,
        )
        for tc in outcome.testcases
    ]
    return VerificationResult(
        status=overall,
        elapsed=outcome.elapsed,
        slowest=outcome.slowest,
        heaviest=outcome.heaviest,
        testcases=per_case,
    )