Coverage for src/competitive_verifier/oj/gnu.py: 100%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-04-26 12:38 +0900

1import contextlib 

2import os 

3import pathlib 

4import re 

5import subprocess 

6import tempfile 

7from dataclasses import dataclass 

8from functools import cache 

9from logging import getLogger 

10from typing import Protocol 

11 

# Module-level logger; records are emitted under this module's dotted name.
logger = getLogger(name=__name__)

13 

14 

15class GnuTimeRunner(Protocol): 

16 @property 

17 def gnu_time(self) -> str | None: ... 

18 

19 def get_command(self, command: list[str]) -> list[str]: ... 

20 def get_memory(self) -> float | None: 

21 """Return the amount of memory used, in megabytes, if possible.""" 

22 ... 

23 

24 def clean(self) -> None: ... 

25 

26 

27class _GnuTimeRunnerDummy: 

28 @property 

29 def gnu_time(self) -> None: 

30 return None 

31 

32 def get_command(self, command: list[str]) -> list[str]: 

33 return command 

34 

35 def get_memory(self) -> float | None: 

36 pass 

37 

38 def clean(self) -> None: 

39 pass 

40 

41 

42@dataclass 

43class _GnuTimeRunnerImpl: 

44 gnu_time: str 

45 tmpdir: tempfile.TemporaryDirectory[str] 

46 outfile: pathlib.Path 

47 

48 def get_command(self, command: list[str]) -> list[str]: 

49 return [self.gnu_time, "-f", "%M", "-o", str(self.outfile), "--", *command] 

50 

51 def get_memory(self) -> float | None: 

52 if self.outfile.exists() and ( 

53 report := self.outfile.read_text("utf-8").strip() 

54 ): 

55 logger.debug("GNU time says: %s", report) 

56 tail = report.splitlines()[-1] 

57 if tail.isdigit(): 

58 return int(tail) / 1000 

59 return None 

60 

61 def clean(self) -> None: 

62 self.tmpdir.cleanup() 

63 

64 

class GnuTimeWrapper(contextlib.AbstractContextManager["GnuTimeRunner"]):
    """Context manager that hands out a :class:`GnuTimeRunner` for its scope.

    On entry, if a GNU time command was located, a GNU-time-backed runner is
    created with a fresh temporary directory for its report; otherwise the
    pass-through runner is returned. On exit the runner's ``clean()`` is
    called.

    Args:
        enabled: when false, GNU time is not looked up at all and the
            pass-through runner is always used.
    """

    # Located GNU time command, or None when disabled / not found.
    _gnu_time: str | None
    # Runner returned by the most recent __enter__ (pass-through until then).
    _runner: GnuTimeRunner

    def __init__(self, *, enabled: bool = True) -> None:
        super().__init__()
        located: str | None = None
        if enabled:
            located = time_command()
        self._gnu_time = located
        self._runner = _GnuTimeRunnerDummy()

    def __enter__(self) -> GnuTimeRunner:
        gnu_time = self._gnu_time
        if gnu_time:
            tmpdir = tempfile.TemporaryDirectory()
            report_path = pathlib.Path(tmpdir.name).joinpath("gnu_time_report.txt")
            self._runner = _GnuTimeRunnerImpl(
                gnu_time=gnu_time,
                tmpdir=tmpdir,
                outfile=report_path,
            )
        return self._runner

    def __exit__(self, *excinfo: object) -> None:
        # Always clean up; returning None lets any exception propagate.
        self._runner.clean()

86 

87 

@cache
def time_command() -> str | None:
    """Return a working GNU time command, or ``None`` if none is found.

    ``time`` and ``gtime`` are tried first; on POSIX systems the absolute
    paths ``/bin/time`` and ``/usr/bin/time`` follow. The first candidate
    accepted by the probe wins, and the result is memoized by ``@cache``.
    """
    candidates = ["time", "gtime"]
    if os.name == "posix":
        candidates.extend(["/bin/time", "/usr/bin/time"])
    return _find_gnu_time(candidates)

94 

95 

def _find_gnu_time(gnu_time_candidate: list[str]) -> str | None:
    """Return the first candidate that passes :func:`check_gnu_time`, else ``None``.

    Candidates are probed lazily and in order; probing stops at the first hit.
    """
    accepted = (name for name in gnu_time_candidate if check_gnu_time(name))
    return next(accepted, None)

101 

102 

def check_gnu_time(gnu_time: str) -> bool:
    """Return whether *gnu_time* behaves like GNU time.

    The candidate is run as
    ``gnu_time -f "%M KB" -o <tmpfile> -- echo check_gnu_time`` and is
    accepted only if it exits successfully, ``echo``'s output passes through
    on stdout unchanged, and the report file holds exactly one
    ``<digits> KB`` line.

    Environment problems (missing executable, non-zero exit, timeout,
    undecodable output, missing report, ...) are logged at debug level and
    yield ``False``. ``NameError`` and ``AttributeError`` are re-raised
    because they signal a coding mistake, not an unusable candidate.
    """
    # Candidates are looked up by bare name on PATH, so an unrelated program
    # could answer; bound the probe so it can never block indefinitely.
    probe_timeout_seconds = 10
    try:
        with tempfile.TemporaryDirectory() as td:
            tmp = pathlib.Path(td) / "out"
            ret = subprocess.run(
                [
                    gnu_time,
                    "-f",
                    "%M KB",
                    "-o",
                    str(tmp),
                    "--",
                    "echo",
                    "check_gnu_time",
                ],
                encoding="utf-8",
                capture_output=True,
                # The probe needs no input; don't let it read the caller's stdin.
                stdin=subprocess.DEVNULL,
                check=True,
                timeout=probe_timeout_seconds,
            )
            if (
                ret.returncode == 0
                and ret.stdout.rstrip() == "check_gnu_time"
                and re.match(r"^\d+ KB$", tmp.read_text("utf-8").strip())
            ):
                return True
    except (NameError, AttributeError):
        # Not caused by the environment but by a coding mistake: surface it.
        raise
    except Exception:  # noqa: BLE001
        logger.debug("Failed to check gnu_time: %s", gnu_time, exc_info=True)
    return False