Coverage for src/competitive_verifier/oj/gnu.py: 100%
71 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-04-26 12:38 +0900
1import contextlib
2import os
3import pathlib
4import re
5import subprocess
6import tempfile
7from dataclasses import dataclass
8from functools import cache
9from logging import getLogger
10from typing import Protocol
# Module-level logger; GNU time report parsing and the GNU time probe below
# emit their diagnostics through it at DEBUG level.
logger = getLogger(__name__)
15class GnuTimeRunner(Protocol):
16 @property
17 def gnu_time(self) -> str | None: ...
19 def get_command(self, command: list[str]) -> list[str]: ...
20 def get_memory(self) -> float | None:
21 """Return the amount of memory used, in megabytes, if possible."""
22 ...
24 def clean(self) -> None: ...
27class _GnuTimeRunnerDummy:
28 @property
29 def gnu_time(self) -> None:
30 return None
32 def get_command(self, command: list[str]) -> list[str]:
33 return command
35 def get_memory(self) -> float | None:
36 pass
38 def clean(self) -> None:
39 pass
42@dataclass
43class _GnuTimeRunnerImpl:
44 gnu_time: str
45 tmpdir: tempfile.TemporaryDirectory[str]
46 outfile: pathlib.Path
48 def get_command(self, command: list[str]) -> list[str]:
49 return [self.gnu_time, "-f", "%M", "-o", str(self.outfile), "--", *command]
51 def get_memory(self) -> float | None:
52 if self.outfile.exists() and (
53 report := self.outfile.read_text("utf-8").strip()
54 ):
55 logger.debug("GNU time says: %s", report)
56 tail = report.splitlines()[-1]
57 if tail.isdigit():
58 return int(tail) / 1000
59 return None
61 def clean(self) -> None:
62 self.tmpdir.cleanup()
class GnuTimeWrapper(contextlib.AbstractContextManager["GnuTimeRunner"]):
    """Context manager that hands out a GnuTimeRunner for one ``with`` block.

    When a GNU time executable is in use, entering creates a fresh temporary
    directory and returns a runner that wraps commands with GNU time, writing
    the report to ``gnu_time_report.txt`` inside that directory. Otherwise the
    pass-through runner created in ``__init__`` is returned. Exiting calls
    ``clean()`` on the current runner and never suppresses exceptions.
    """

    _gnu_time: str | None
    _runner: GnuTimeRunner

    def __init__(self, *, enabled: bool = True) -> None:
        """Set up the wrapper.

        Args:
            enabled: When False, GNU time is not looked up at all and the
                wrapper always hands out the pass-through runner.
        """
        super().__init__()
        gnu_time: str | None
        if enabled:
            gnu_time = time_command()
        else:
            gnu_time = None
        self._gnu_time = gnu_time
        self._runner = _GnuTimeRunnerDummy()

    def __enter__(self):
        gnu_time = self._gnu_time
        if not gnu_time:
            # No executable: keep handing out the pass-through runner.
            return self._runner
        report_dir = tempfile.TemporaryDirectory()
        self._runner = _GnuTimeRunnerImpl(
            gnu_time=gnu_time,
            tmpdir=report_dir,
            outfile=pathlib.Path(report_dir.name, "gnu_time_report.txt"),
        )
        return self._runner

    def __exit__(self, *excinfo: object):
        # Implicitly returns None, so exceptions from the block propagate.
        self._runner.clean()
@cache
def time_command() -> str | None:
    """Find a usable GNU time executable, probing at most once per process.

    Candidates are tried in this order: ``time``, ``gtime``, and on POSIX
    systems additionally ``/bin/time`` then ``/usr/bin/time``.

    Returns:
        The first candidate that ``_find_gnu_time`` accepts, or None when none
        qualifies. ``@cache`` memoizes the result, so later calls do not spawn
        any probe processes.
    """
    posix_only = ["/bin/time", "/usr/bin/time"] if os.name == "posix" else []
    return _find_gnu_time(["time", "gtime", *posix_only])
def _find_gnu_time(gnu_time_candidate: list[str]) -> str | None:
    """Return the first entry of *gnu_time_candidate* that passes ``check_gnu_time``.

    Candidates are probed lazily and in order; once one passes, the remaining
    ones are not run. Returns None if no candidate passes.
    """
    return next(
        (candidate for candidate in gnu_time_candidate if check_gnu_time(candidate)),
        None,
    )
def check_gnu_time(gnu_time: str) -> bool:
    """Return True if *gnu_time* behaves like GNU time.

    The candidate is probed with
    ``<gnu_time> -f "%M KB" -o <tmpdir>/out -- echo check_gnu_time``. It is
    accepted only if the process exits with status 0, ``echo``'s output reaches
    stdout unchanged, and the (stripped) report file is exactly one
    ``<digits> KB`` entry.

    NameError and AttributeError are re-raised because they signal a coding
    mistake rather than an unsuitable environment. Every other exception
    (missing executable, non-zero exit, unreadable report, ...) is logged at
    DEBUG level and the candidate is rejected.
    """
    try:
        with tempfile.TemporaryDirectory() as workdir:
            report_path = pathlib.Path(workdir) / "out"
            probe = [gnu_time, "-f", "%M KB", "-o", str(report_path)]
            probe += ["--", "echo", "check_gnu_time"]
            completed = subprocess.run(
                probe,
                encoding="utf-8",
                capture_output=True,
                check=True,
            )
            # check=True already raises on a non-zero exit; this is a redundant guard.
            if completed.returncode != 0:
                return False
            if completed.stdout.rstrip() != "check_gnu_time":
                return False
            report = report_path.read_text("utf-8").strip()
            return re.match(r"^\d+ KB$", report) is not None
    except (NameError, AttributeError):
        # Programming errors, not environmental failures: let them surface.
        raise
    except Exception:  # noqa: BLE001
        logger.debug("Failed to check gnu_time: %s", gnu_time, exc_info=True)
    return False