Coverage for src / competitive_verifier / oj / oj_test.py: 100%
210 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-04-26 12:38 +0900
« prev ^ index » next coverage.py v7.13.1, created at 2026-04-26 12:38 +0900
1import math
2import os
3import pathlib
4import platform
5import shlex
6import shutil
7import subprocess
8import sys
9import tempfile
10import time
11from collections import Counter
12from dataclasses import dataclass
13from logging import getLogger
14from typing import BinaryIO
16from competitive_verifier.log import GitHubMessageParams
17from competitive_verifier.models import (
18 JudgeStatus,
19 ResultStatus,
20 TestCaseProvider,
21 TestcaseResult,
22 VerifcationTimeoutError,
23 VerificationResult,
24)
26from . import gnu
27from .format import Printer, green, red
29logger = getLogger(__name__)
class CaseExecutionError(Exception):
    """Raised when a test command is empty, cannot be found, or fails to start."""
36@dataclass
37class OjExecInfo:
38 answer: str | None
39 """The standard output of the executed command"""
40 elapsed: float
41 """The elapsed time of the executed command in seconds"""
42 memory: float | None
43 """The maximum memory usage of the executed command in megabytes"""
44 returncode: int | None
45 """The returncode of the executed command"""
def _kill_process_tree(proc: "subprocess.Popen[str]", *, whole_group: bool) -> None:
    """Forcefully stop ``proc`` and, when requested, its whole process group.

    Args:
        proc: The child process to stop.
        whole_group: True when the child was started with
            ``start_new_session=True`` and therefore leads its own process
            group, so every descendant can be killed together with it.
    """
    import contextlib
    import signal

    if proc.poll() is not None:
        # Already reaped: its pid (and group id) may have been recycled.
        return

    if whole_group and hasattr(os, "killpg"):
        # A session leader's process group id equals its pid.
        with contextlib.suppress(ProcessLookupError):
            os.killpg(proc.pid, signal.SIGKILL)
    else:
        proc.kill()


def measure_command(
    command: list[str] | str,
    *,
    env: dict[str, str] | None = None,
    stdin: BinaryIO | int | None = None,
    timeout: float | None = None,
    gnu_time: bool = False,
) -> OjExecInfo:
    """Run a command and measure its output, run time, memory and exit code.

    Args:
        command: The command to run. A string is split with ``shlex.split``.
        env: Extra environment variables, merged on top of ``os.environ``.
        stdin: Passed to the child process as its standard input.
        timeout: Time limit in seconds; ``None`` means no limit.
        gnu_time: Forwarded to ``gnu.GnuTimeWrapper(enabled=...)``.

    Returns:
        The measurements. ``answer`` and ``returncode`` are ``None`` when the
        command exceeded ``timeout``.

    Raises:
        CaseExecutionError: If the command is empty, its executable cannot be
            found, or running it fails for any reason other than a timeout.
    """
    if isinstance(command, str):
        command = shlex.split(command)

    if not command:
        raise CaseExecutionError

    with gnu.GnuTimeWrapper(enabled=gnu_time) as gw:
        if shutil.which(command[0]) is None:
            raise CaseExecutionError

        command = gw.get_command(command)
        begin = time.perf_counter()

        # We need to kill processes called from the "time" command using
        # process groups. Without this, orphans spawn.
        # See https://github.com/kmyk/online-judge-tools/issues/640
        start_new_session = gw.gnu_time is not None

        answer: str | None
        returncode: int | None
        try:
            if env is not None:
                env = os.environ | env
            with subprocess.Popen(
                command,
                env=env,
                stdin=stdin,
                stdout=subprocess.PIPE,
                stderr=sys.stderr,
                encoding="utf-8",
                start_new_session=start_new_session,
            ) as proc:
                try:
                    answer, _ = proc.communicate(timeout=timeout)
                    returncode = proc.returncode
                except subprocess.TimeoutExpired:
                    # Killing only the direct child would leave the program
                    # started by the "time" wrapper running as an orphan.
                    _kill_process_tree(proc, whole_group=start_new_session)
                    proc.wait()
                    answer = None
                    returncode = None
                except BaseException:
                    # Never leave the child running when we bail out.
                    _kill_process_tree(proc, whole_group=start_new_session)
                    raise
        except Exception as e:
            logger.exception(
                "'%s' is not executable.",
                command,
                extra={"github": GitHubMessageParams()},
            )
            raise CaseExecutionError from e

        end = time.perf_counter()
        return OjExecInfo(
            answer=answer,
            elapsed=end - begin,
            memory=gw.get_memory(),
            returncode=returncode,
        )
@dataclass
class OjTestArguments:
    """Settings for one run of the oj-test command.

    Ported from onlinejudge_command.subcommand.test.add_subparser.
    """

    # Command to execute for every test case.
    command: str | list[str]
    # Source of the test cases (and of the optional checker).
    problem: TestCaseProvider
    # Per-case timeout handed to the command runner; None disables it.
    tle: float | None
    # Memory limit compared against the measured memory; None disables it.
    mle: float | None
    # Tolerance for floating point comparison; None requires an exact match.
    error: float | None
    # Extra environment variables for the command.
    env: dict[str, str] | None = None
    # time.perf_counter() value after which no further case is started.
    deadline: float = math.inf
@dataclass
class OjTestcaseResult:
    """Outcome of running the command against a single test case."""

    name: str
    """Name of the test case."""
    input: pathlib.Path
    """Path of the input file of the test case."""
    answer: str
    """Output produced for the test case."""
    expected: pathlib.Path
    """Path of the expected output file of the test case."""

    status: JudgeStatus
    elapsed: float
    exitcode: int | None

    memory: float | None = None

    def __post_init__(self):
        """Normalize any non-integer exit code to ``None``."""
        if isinstance(self.exitcode, int):
            return
        self.exitcode = None

    def __str__(self) -> str:
        """Render a one-line, comma separated summary of the result."""
        if self.status == JudgeStatus.AC:
            verdict = green("AC")
        else:
            verdict = red(self.status.name)

        parts = [f"{self.name}: {verdict}", f"time: {self.elapsed:f} sec"]
        if self.memory is not None:
            parts.append(f"memory: {self.memory:f} MB")
        # A zero (or missing) exit code is not worth mentioning.
        if self.exitcode:
            parts.append(f"return code: {self.exitcode}")
        return ", ".join(parts)

    def log(self):
        """Log the details relevant to the status, then the summary line."""
        if self.status == JudgeStatus.AC:
            sections = ()
        elif self.status in (JudgeStatus.RE, JudgeStatus.TLE):
            # The produced answer is not logged for crashes and timeouts.
            sections = (self._log_input, self._log_expected)
        else:
            sections = (self._log_input, self._log_answer, self._log_expected)

        for emit in sections:
            emit()
        logger.info(self)

    def _log_input(self) -> None:
        """Log the test case input."""
        logger.info("%s:input: %s", self.name, Printer(self.input))

    def _log_expected(self) -> None:
        """Log the expected output."""
        logger.info("%s:expected: %s", self.name, Printer(self.expected))

    def _log_answer(self) -> None:
        """Log the output that was actually produced."""
        logger.info("%s:answer: %s", self.name, Printer(self.answer))
@dataclass
class OjTestResult:
    """Aggregated outcome of a whole test run."""

    is_success: bool
    """Whether every test case was accepted."""

    elapsed: float
    """Sum of the elapsed times of all test cases [seconds]."""

    slowest: float
    """Longest elapsed time of a single test case [seconds]."""

    heaviest: float
    """Largest memory usage of a single test case [MB]."""

    testcases: list[OjTestcaseResult]
    """Per-case results."""
197def _try_parse_float(value: str) -> float | None:
198 try:
199 return float(value)
200 except ValueError:
201 return None
def _equal_or_closed_float(actual: str, expected: str, *, error: float) -> bool:
    """Tell whether two tokens are identical or numerically close.

    Tokens that differ textually match only when both parse as floats and
    are within ``error`` of each other, relatively or absolutely.
    """
    if actual == expected:
        return True

    actual_value = _try_parse_float(actual)
    if actual_value is None:
        return False

    expected_value = _try_parse_float(expected)
    if expected_value is None:
        return False

    return math.isclose(actual_value, expected_value, rel_tol=error, abs_tol=error)
218def compare_answer(actual: str, expected: str, *, error: float | None) -> bool:
219 """Compare two byte strings.
221 Args:
222 actual (bytes): Actual output
223 expected (bytes): Expected output
224 error (float | None): Margin of error
225 Returns:
226 bool: True if they are considered equal
227 """
228 actual = actual.replace("\r\n", "\n")
229 expected = expected.replace("\r\n", "\n")
231 # match
232 if actual == expected:
233 return True
235 try:
236 if error is None:
237 actual_words = actual.split()
238 expected_words = expected.split()
239 if all(x == y for x, y in zip(actual_words, expected_words, strict=True)):
240 logger.warning("This was AC if spaces and newlines were ignored.")
241 return False
243 actual_lines = actual.rstrip("\n").split("\n")
244 expected_lines = expected.rstrip("\n").split("\n")
246 for actual_line, expected_line in zip(
247 actual_lines, expected_lines, strict=True
248 ):
249 actual_words = actual_line.split()
250 expected_words = expected_line.split()
252 for x, y in zip(actual_words, expected_words, strict=True):
253 if not _equal_or_closed_float(x, y, error=error):
254 return False
255 except ValueError:
256 return False
258 return True
def special_judge(
    judge_command: str,
    output: str,
    *,
    input_path: pathlib.Path,
    expected_output_path: pathlib.Path | None,
) -> bool:
    """Let an external judge program decide whether ``output`` is accepted.

    The judge is invoked as
    ``<judge_command> <input file> <actual output file> <expected output file>``
    where the last argument is an empty string when no expected output exists.

    Args:
        judge_command: Command line of the judge, split with ``shlex.split``.
        output: Output of the tested program; written to a temporary file.
        input_path: Input file of the test case.
        expected_output_path: Expected output file of the test case, if any.

    Returns:
        True if the judge exited with status 0.

    Raises:
        CaseExecutionError: Propagated from ``measure_command`` when the judge
            cannot be run.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        actual_output_path = pathlib.Path(tempdir) / "actual.out"
        # The output was decoded as UTF-8 by measure_command; write it back
        # with the same encoding instead of the locale-dependent default.
        actual_output_path.write_text(output, encoding="utf-8")

        command = [
            *shlex.split(judge_command),
            str(input_path.resolve()),
            str(actual_output_path.resolve()),
            str(
                expected_output_path.resolve()
                if expected_output_path is not None
                else ""
            ),
        ]

        logger.debug("$ %s", command)
        info = measure_command(command)
    logger.debug("judge's output: %s", Printer(info.answer or ""))
    return info.returncode == 0
def determine_status(
    *,
    exitcode: int | None,
    memory: float | None,
    mle: float | None,
    match_result: bool | None,
) -> JudgeStatus:
    """Map the measurements of one execution to a judge status.

    The checks are ordered: a missing exit code means TLE, then exceeding the
    memory limit means MLE, then a non-zero exit code means RE, then a failed
    comparison means WA. Anything else is AC.
    """
    over_memory_limit = memory is not None and mle is not None and memory > mle
    wrong_answer = match_result is not None and not match_result

    if exitcode is None:
        status = JudgeStatus.TLE
    elif over_memory_limit:
        status = JudgeStatus.MLE
    elif exitcode != 0:
        status = JudgeStatus.RE
    elif wrong_answer:
        status = JudgeStatus.WA
    else:
        status = JudgeStatus.AC
    return status
def single_case(
    test_name: str,
    test_input_path: pathlib.Path,
    test_output_path: pathlib.Path,
    *,
    args: OjTestArguments,
) -> OjTestcaseResult:
    """Run the command on one test case and judge what it produced.

    The input file is fed through standard input. The output is checked by
    the problem's checker when it has one, otherwise it is compared with the
    expected output file. If the command (or the checker) cannot be executed,
    an RE result with exit code 255 is returned instead.
    """
    try:
        logger.info("%s: start", test_name)

        # Execute the program with the test input on stdin.
        with test_input_path.open("rb") as stdin_file:
            info = measure_command(
                args.command,
                env=args.env,
                stdin=stdin_file,
                timeout=args.tle,
                gnu_time=True,
            )

        answer = info.answer or ""

        if args.problem.checker:
            match_result = special_judge(
                str(args.problem.checker),
                answer,
                input_path=test_input_path,
                expected_output_path=test_output_path,
            )
        else:
            match_result = compare_answer(
                answer,
                test_output_path.read_text(),
                error=args.error,
            )

        result = OjTestcaseResult(
            name=test_name,
            input=test_input_path,
            expected=test_output_path,
            answer=answer,
            status=determine_status(
                exitcode=info.returncode,
                memory=info.memory,
                mle=args.mle,
                match_result=match_result,
            ),
            exitcode=info.returncode,
            elapsed=info.elapsed,
            memory=info.memory,
        )
    except CaseExecutionError:
        logger.exception(
            "Failed to run: %s",
            args,
            extra={"github": GitHubMessageParams()},
        )
        return OjTestcaseResult(
            name=test_name,
            input=test_input_path,
            expected=test_output_path,
            answer="",
            status=JudgeStatus.RE,
            exitcode=255,
            elapsed=0,
            memory=None,
        )

    # Only reached when the case was executed; log its details.
    result.log()
    return result
def gnu_time_message(args: OjTestArguments):
    """Check whether GNU time is available.

    Log an installation hint on macOS, and a warning when a memory limit
    was requested, if GNU time is not available.
    """
    if gnu.time_command() is not None:
        return

    on_macos = platform.system() == "Darwin"
    if on_macos:
        logger.info(
            "[HINT]: You can install GNU time with: $ brew install gnu-time",
            extra={"github": GitHubMessageParams()},
        )

    if args.mle is not None:
        logger.warning(
            "--mle is used but GNU time does not exist",
            extra={"github": GitHubMessageParams()},
        )
class _StatusCounter(Counter[JudgeStatus]):
    """Counter of judge statuses with a compact textual form."""

    def __str__(self) -> str:
        """List the non-zero counts in ``JudgeStatus`` order, e.g. ``2 WA, 1 RE``."""
        parts: list[str] = []
        for status in JudgeStatus:
            count = self.get(status)
            if count:
                parts.append(f"{count} {status.name}")
        return ", ".join(parts)
def summarize(history: list[OjTestcaseResult]) -> OjTestResult:
    """Aggregate per-case results, log a summary and return the totals.

    ``slowest`` and ``heaviest`` stay at -1.0 when no case provided a larger
    value. The run is successful when every case is AC.
    """
    counter = _StatusCounter()
    total_elapsed: float = 0.0
    slowest: float = -1.0
    heaviest: float = -1.0
    slowest_name: str | None = None
    heaviest_name: str | None = None

    for case in history:
        counter[case.status] += 1
        total_elapsed += case.elapsed
        if case.elapsed > slowest:
            slowest, slowest_name = case.elapsed, case.name
        if case.memory is not None and case.memory > heaviest:
            heaviest, heaviest_name = case.memory, case.name

    # Report the extremes that were actually observed.
    if slowest_name is not None:
        logger.info("slowest: %f sec (for %s)", slowest, slowest_name)
    if heaviest_name is not None:
        logger.info("max memory: %f MB (for %s)", heaviest, heaviest_name)

    total = len(history)
    is_success = counter[JudgeStatus.AC] == total
    if is_success:
        logger.info("%s %d cases", green("SUCCESS"), total)
    else:
        logger.info("%s %s / %d cases", red("FAILURE"), counter, total)

    return OjTestResult(
        is_success=is_success,
        slowest=slowest,
        elapsed=total_elapsed,
        heaviest=heaviest,
        testcases=history,
    )
def _run(args: OjTestArguments) -> OjTestResult:
    """Run every system test case of the problem and summarize the results.

    Raises:
        VerifcationTimeoutError: If the deadline has passed before a test
            case is started.
    """
    gnu_time_message(args)

    tolerance = args.error
    if tolerance is not None and tolerance > 1:
        logger.warning(
            "the tolerance is too large: relative = %s",
            tolerance,
            extra={"github": GitHubMessageParams()},
        )

    # Collect all cases up front, before any of them is executed.
    tests = list(args.problem.iter_system_cases())

    history: list[OjTestcaseResult] = []
    for case in tests:
        if time.perf_counter() > args.deadline:
            raise VerifcationTimeoutError

        outcome = single_case(case.name, case.input_path, case.output_path, args=args)
        history.append(outcome)

    return summarize(history)
def main(
    *,
    problem: TestCaseProvider,
    command: str | list[str],
    env: dict[str, str] | None,
    tle: float | None,
    mle: float | None,
    error: float | None,
    deadline: float = float("inf"),
) -> VerificationResult:
    """Run the command against the problem's test cases.

    Bundles the parameters into ``OjTestArguments``, runs the tests and
    converts the outcome into a ``VerificationResult``.
    """
    outcome = _run(
        OjTestArguments(
            command=command,
            problem=problem,
            env=env,
            tle=tle,
            mle=mle,
            error=error,
            deadline=deadline,
        )
    )

    status = ResultStatus.SUCCESS if outcome.is_success else ResultStatus.FAILURE
    testcases = [
        TestcaseResult(
            name=case.name,
            elapsed=case.elapsed,
            memory=case.memory,
            status=case.status,
        )
        for case in outcome.testcases
    ]
    return VerificationResult(
        status=status,
        elapsed=outcome.elapsed,
        slowest=outcome.slowest,
        heaviest=outcome.heaviest,
        testcases=testcases,
    )