Coverage for src / competitive_verifier / oj / gnu.py: 96%
71 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-05 16:00 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-05 16:00 +0000
1import abc
2import contextlib
3import os
4import pathlib
5import re
6import subprocess
7import tempfile
8from dataclasses import dataclass
9from functools import cache
10from logging import getLogger
# Module-level logger, named after this module via __name__.
logger = getLogger(__name__)
15class GnuTimeRunner(abc.ABC):
16 @abc.abstractmethod
17 def get_command(self, command: list[str]) -> list[str]: ...
18 @abc.abstractmethod
19 def get_memory(self) -> float | None:
20 """Return the amount of memory used, in megabytes, if possible."""
21 ...
23 @abc.abstractmethod
24 def clean(self) -> None: ...
class _GnuTimeRunnerDummy(GnuTimeRunner):
    """No-op runner: leaves the command untouched and measures nothing."""

    def get_command(self, command: list[str]) -> list[str]:
        """Return *command* unchanged."""
        return command

    def get_memory(self) -> float | None:
        """Return None, since no measurement is ever taken."""
        return None

    def clean(self) -> None:
        """Do nothing; this runner holds no resources."""
        return None
@dataclass
class _GnuTimeRunnerImpl(GnuTimeRunner):
    """Runner that wraps a command in GNU time and reads back its ``%M`` report."""

    # Name or path of the GNU time executable to invoke.
    gnu_time: str
    # Temporary directory removed by clean().
    # NOTE(review): presumably the directory containing outfile; confirm at caller.
    tmpdir: tempfile.TemporaryDirectory[str]
    # File that GNU time is told (via -o) to write its report to.
    outfile: pathlib.Path

    def get_command(self, command: list[str]) -> list[str]:
        """Return *command* prefixed with a GNU time invocation.

        GNU time is asked to write the ``%M`` format to ``outfile``; ``--``
        separates its own options from the wrapped command.
        """
        return [self.gnu_time, "-f", "%M", "-o", str(self.outfile), "--", *command]

    def get_memory(self) -> float | None:
        """Return the last reported ``%M`` value divided by 1000.

        Returns:
            The value as a float, or None when the report file is missing,
            empty, or its last line is not a plain decimal number.
        """
        if not self.outfile.exists():
            return None

        report = self.outfile.read_text("utf-8").strip()
        if not report:
            return None

        logger.debug("GNU time says: %s", report)
        # Only the final line is parsed; presumably GNU time may write status
        # messages on earlier lines (TODO confirm against GNU time's manual).
        tail = report.splitlines()[-1]
        # isdecimal(), not isdigit(): isdigit() also accepts characters such as
        # "²" that int() rejects, which would raise ValueError instead of
        # reporting "no measurement".
        if tail.isdecimal():
            return int(tail) / 1000
        return None

    def clean(self) -> None:
        """Remove the temporary directory."""
        self.tmpdir.cleanup()
class GnuTimeWrapper(contextlib.AbstractContextManager["GnuTimeRunner"]):
    """Context manager that yields a GnuTimeRunner and cleans it up on exit.

    When a GNU time command is available (and measurement is enabled) the
    yielded runner writes its report into a fresh temporary directory;
    otherwise a no-op runner is yielded.
    """

    # GNU time command to use; None when disabled or none was found.
    _gnu_time: str | None
    # Runner handed out by __enter__ and cleaned by __exit__.
    _runner: GnuTimeRunner

    def __init__(self, *, enabled: bool = True) -> None:
        """Look up the GNU time command unless *enabled* is false."""
        super().__init__()
        if enabled:
            self._gnu_time = time_command()
        else:
            self._gnu_time = None
        # Start with the no-op runner; __enter__ swaps in a real one if possible.
        self._runner = _GnuTimeRunnerDummy()

    def __enter__(self):
        """Return the runner to use inside the ``with`` block."""
        executable = self._gnu_time
        if not executable:
            return self._runner

        workdir = tempfile.TemporaryDirectory()
        report_path = pathlib.Path(workdir.name) / "gnu_time_report.txt"
        self._runner = _GnuTimeRunnerImpl(
            gnu_time=executable,
            tmpdir=workdir,
            outfile=report_path,
        )
        return self._runner

    def __exit__(self, *excinfo: object):
        """Clean up the current runner; exceptions are not suppressed."""
        self._runner.clean()
@cache
def time_command() -> str | None:
    """Return the first candidate that passes the GNU time check, or None.

    The result is cached, so the candidates are probed at most once per
    process.
    """
    candidates = ["time", "gtime"]
    if os.name == "posix":
        # On POSIX, also try the usual absolute locations.
        candidates.extend(["/bin/time", "/usr/bin/time"])
    return _find_gnu_time(candidates)
92def _find_gnu_time(gnu_time_candidate: list[str]) -> str | None:
93 for gnu_time in gnu_time_candidate: 93 ↛ 96line 93 didn't jump to line 96 because the loop on line 93 didn't complete
94 if check_gnu_time(gnu_time): 94 ↛ 93line 94 didn't jump to line 93 because the condition on line 94 was always true
95 return gnu_time
96 return None
def check_gnu_time(gnu_time: str) -> bool:
    """Return True if *gnu_time* behaves like GNU time on a trial run.

    The candidate is run around ``echo check_gnu_time`` with ``-f "%M KB"``
    and ``-o <file>``. It is accepted only when the run succeeds, the echoed
    text appears on stdout, and the report file contains ``<digits> KB``.
    Any other outcome, including failure to launch, yields False.

    Raises:
        NameError, AttributeError: re-raised, because they indicate a coding
            mistake rather than a problem with the environment.
    """
    try:
        with tempfile.TemporaryDirectory() as workdir:
            report_path = pathlib.Path(workdir) / "out"
            probe = [
                gnu_time,
                "-f",
                "%M KB",
                "-o",
                str(report_path),
                "--",
                "echo",
                "check_gnu_time",
            ]
            completed = subprocess.run(
                probe,
                encoding="utf-8",
                capture_output=True,
                check=True,
            )
            echoed = (
                completed.returncode == 0
                and completed.stdout.rstrip() == "check_gnu_time"
            )
            # The report file is read only after the run itself looks right.
            if echoed and re.match(
                r"^\d+ KB$", report_path.read_text("utf-8").strip()
            ):
                return True
    except (NameError, AttributeError):
        # These are coding mistakes, not runtime errors caused by the
        # environment, so they must not be swallowed below.
        raise
    except Exception:  # noqa: BLE001
        logger.debug("Failed to check gnu_time: %s", gnu_time, exc_info=True)
    return False