Coverage for src / competitive_verifier / oj / gnu.py: 96%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-05 16:00 +0000

1import abc 

2import contextlib 

3import os 

4import pathlib 

5import re 

6import subprocess 

7import tempfile 

8from dataclasses import dataclass 

9from functools import cache 

10from logging import getLogger 

11 

12logger = getLogger(__name__) 

13 

14 

15class GnuTimeRunner(abc.ABC): 

16 @abc.abstractmethod 

17 def get_command(self, command: list[str]) -> list[str]: ... 

18 @abc.abstractmethod 

19 def get_memory(self) -> float | None: 

20 """Return the amount of memory used, in megabytes, if possible.""" 

21 ... 

22 

23 @abc.abstractmethod 

24 def clean(self) -> None: ... 

25 

26 

class _GnuTimeRunnerDummy(GnuTimeRunner):
    """No-op runner: commands run unwrapped and no memory figure is reported.

    This is the runner handed out when no GNU time executable is in use.
    """

    def get_command(self, command: list[str]) -> list[str]:
        # Pass-through: the very same list object is handed back unchanged.
        return command

    def get_memory(self) -> float | None:
        # Nothing was measured, so there is never a value to report.
        return None

    def clean(self) -> None:
        # No resources were acquired; nothing to release.
        return None

36 

37 

@dataclass
class _GnuTimeRunnerImpl(GnuTimeRunner):
    """Runner that wraps commands with GNU time to record their memory usage.

    GNU time is asked to write its ``%M`` figure (reported in kilobytes) to
    *outfile*, a file inside *tmpdir*; ``get_memory`` converts that figure to
    megabytes.
    """

    gnu_time: str  # executable name or path of the GNU time binary
    tmpdir: tempfile.TemporaryDirectory[str]  # owns the directory holding the report
    outfile: pathlib.Path  # report file GNU time writes to (via ``-o``)

    def get_command(self, command: list[str]) -> list[str]:
        # "--" keeps GNU time from interpreting the wrapped command's flags.
        time_prefix = [self.gnu_time, "-f", "%M", "-o", str(self.outfile), "--"]
        return [*time_prefix, *command]

    def get_memory(self) -> float | None:
        if not self.outfile.exists():
            return None
        report = self.outfile.read_text("utf-8").strip()
        if not report:
            return None
        logger.debug("GNU time says: %s", report)
        # Only the final line is interpreted; any earlier lines are ignored.
        *_, last_line = report.splitlines()
        if not last_line.isdigit():
            return None
        # Kilobytes -> megabytes.
        return int(last_line) / 1000

    def clean(self) -> None:
        # Removes the temporary directory together with the report file.
        self.tmpdir.cleanup()

59 

60 

class GnuTimeWrapper(contextlib.AbstractContextManager["GnuTimeRunner"]):
    """Context manager that hands out a :class:`GnuTimeRunner`.

    If a usable GNU time executable was found at construction (and *enabled*
    was true), entering creates a fresh temporary directory and returns a
    runner that wraps commands with GNU time.  Otherwise the no-op runner is
    returned.  Leaving the context calls the current runner's ``clean``.
    """

    _gnu_time: str | None
    _runner: GnuTimeRunner

    def __init__(self, *, enabled: bool = True) -> None:
        super().__init__()
        # When disabled, the (cached) GNU time lookup is skipped entirely.
        self._gnu_time = None
        if enabled:
            self._gnu_time = time_command()
        self._runner = _GnuTimeRunnerDummy()

    def __enter__(self) -> "GnuTimeRunner":
        gnu_time = self._gnu_time
        if not gnu_time:
            return self._runner
        workdir = tempfile.TemporaryDirectory()
        self._runner = _GnuTimeRunnerImpl(
            gnu_time=gnu_time,
            tmpdir=workdir,
            outfile=pathlib.Path(workdir.name) / "gnu_time_report.txt",
        )
        return self._runner

    def __exit__(self, *excinfo: object) -> None:
        # Returns None, so exceptions raised inside the block propagate.
        self._runner.clean()

82 

83 

@cache
def time_command() -> str | None:
    """Locate a usable GNU time executable; the result is cached after the first call.

    The bare names ``time`` and ``gtime`` are tried first.  On POSIX systems
    the absolute paths ``/bin/time`` and ``/usr/bin/time`` are tried after
    them.

    Returns:
        The first candidate accepted by the GNU time check, or ``None``.
    """
    candidates = ["time", "gtime"]
    if os.name == "posix":
        candidates.extend(["/bin/time", "/usr/bin/time"])
    return _find_gnu_time(candidates)

90 

91 

def _find_gnu_time(gnu_time_candidate: list[str]) -> str | None:
    """Return the first entry of *gnu_time_candidate* that passes ``check_gnu_time``.

    Candidates are probed in order and probing stops at the first success.
    ``None`` is returned when no candidate qualifies.
    """
    return next(
        (candidate for candidate in gnu_time_candidate if check_gnu_time(candidate)),
        None,
    )

97 

98 

def check_gnu_time(gnu_time: str, *, timeout: float | None = 10.0) -> bool:
    """Return whether *gnu_time* behaves like GNU time.

    The candidate is asked to run ``echo check_gnu_time`` while writing a
    ``%M KB`` report to a temporary file.  It is accepted only when the
    wrapped command's output comes through unchanged and the report is a
    single ``<digits> KB`` line.

    Args:
        gnu_time: Command name or path of the candidate executable.
        timeout: Seconds to wait for the probe before giving up (``None``
            waits indefinitely).  This bounds start-up time when a candidate
            found on PATH misbehaves.

    Returns:
        ``True`` if the probe succeeded.  ``False`` for any environment-related
        failure: missing executable, unsupported options, non-zero exit,
        timeout, or unexpected output.

    Raises:
        NameError: re-raised, since it signals a coding mistake.
        AttributeError: re-raised, since it signals a coding mistake.
    """
    try:
        with tempfile.TemporaryDirectory() as td:
            report_path = pathlib.Path(td) / "out"
            ret = subprocess.run(
                [
                    gnu_time,
                    "-f",
                    "%M KB",
                    "-o",
                    str(report_path),
                    "--",
                    "echo",
                    "check_gnu_time",
                ],
                encoding="utf-8",
                capture_output=True,
                # A non-zero exit raises CalledProcessError (-> False below),
                # so no separate returncode check is needed.
                check=True,
                # The probe must never block start-up: give the candidate no
                # terminal input to wait on, and bound how long it may run
                # (TimeoutExpired -> False below).
                stdin=subprocess.DEVNULL,
                timeout=timeout,
            )
            if ret.stdout.rstrip() != "check_gnu_time":
                return False
            report = report_path.read_text("utf-8").strip()
            return re.fullmatch(r"\d+ KB", report) is not None
    except NameError:
        raise  # NameError is not a runtime error caused by the environment, but a coding mistake
    except AttributeError:
        raise  # AttributeError is also a mistake
    except Exception:  # noqa: BLE001
        logger.debug("Failed to check gnu_time: %s", gnu_time, exc_info=True)
    return False