Coverage for src/competitive_verifier/verify/main.py: 100%

85 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-05 16:00 +0000

1import math 

2import pathlib 

3from argparse import ArgumentParser 

4from functools import cached_property 

5from logging import getLogger 

6from typing import Literal 

7 

8from pydantic import Field, field_validator 

9 

10from competitive_verifier import github 

11from competitive_verifier.arg import ( 

12 IgnoreErrorArguments, 

13 VerboseArguments, 

14 VerifyFilesJsonArguments, 

15 WriteSummaryArguments, 

16) 

17from competitive_verifier.log import GitHubMessageParams 

18from competitive_verifier.models import VerificationInput, VerifyCommandResult 

19 

20from .verifier import SplitState, Verifier 

21 

# Module-level logger, named after this module.
logger = getLogger(__name__)

23 

24 

25class Verify( 

26 WriteSummaryArguments, 

27 IgnoreErrorArguments, 

28 VerifyFilesJsonArguments, 

29 VerboseArguments, 

30): 

31 subcommand: Literal["verify"] = Field( 

32 default="verify", 

33 description="Verify library", 

34 ) 

35 timeout: float = math.inf 

36 default_tle: float | None = None 

37 default_mle: float | None = None 

38 

39 prev_result: pathlib.Path | None = None 

40 

41 download: bool = True 

42 

43 output: pathlib.Path | None = None 

44 

45 split: int | None = None 

46 split_index: int | None = None 

47 

48 def read_prev_result(self) -> VerifyCommandResult | None: 

49 if not self.prev_result: 

50 return None 

51 try: 

52 return VerifyCommandResult.parse_file_relative(self.prev_result) 

53 except Exception: 

54 logger.warning( 

55 "Failed to parse prev_result: %s", 

56 self.prev_result, 

57 extra={"github": GitHubMessageParams(file=self.prev_result)}, 

58 ) 

59 

60 def write_result(self, result: VerifyCommandResult): 

61 super().write_result(result) 

62 

63 result_json = result.model_dump_json(exclude_none=True) 

64 print(result_json) 

65 

66 if self.output: 

67 self.output.parent.mkdir(parents=True, exist_ok=True) 

68 self.output.write_text(result_json, encoding="utf-8") 

69 

70 @field_validator("timeout", mode="after") 

71 @classmethod 

72 def timeout_zero_equals_inf(cls, value: float) -> float: 

73 if value == 0: 

74 return math.inf 

75 return value 

76 

77 @cached_property 

78 def split_state(self) -> SplitState | None: 

79 split = self.split 

80 split_index = self.split_index 

81 match (split_index, split): 

82 case (int(), int()): 

83 if split <= 0: 

84 raise ValueError("--split must be greater than 0.") 

85 if not (0 <= split_index < split): 

86 raise ValueError( 

87 "--split-index must be greater than 0 and less than --split." 

88 ) 

89 return SplitState(size=split, index=split_index) 

90 case (None, int()): 

91 raise ValueError("--split argument requires --split-index argument.") 

92 case (int(), None): 

93 raise ValueError("--split-index argument requires --split argument.") 

94 case _: 

95 return None 

96 

97 @classmethod 

98 def add_parser(cls, parser: ArgumentParser): 

99 super().add_parser(parser) 

100 parser.add_argument( 

101 "--timeout", 

102 type=float, 

103 default=math.inf, 

104 help="Timeout seconds. if value is zero, it is same to math.inf.", 

105 ) 

106 parser.add_argument( 

107 "--tle", 

108 dest="default_tle", 

109 type=float, 

110 default=None, 

111 help="Threshold seconds to be TLE", 

112 ) 

113 parser.add_argument( 

114 "--mle", 

115 dest="default_mle", 

116 type=float, 

117 default=None, 

118 help="Threshold memory usage (MB) to be MLE", 

119 ) 

120 parser.add_argument( 

121 "--prev-result", 

122 type=pathlib.Path, 

123 required=False, 

124 help="Previous result json file", 

125 ) 

126 

127 parser.add_argument( 

128 "--no-download", 

129 action="store_false", 

130 dest="download", 

131 help="Suppress `oj download`", 

132 ) 

133 parser.add_argument( 

134 "--output", 

135 "-o", 

136 type=pathlib.Path, 

137 required=False, 

138 help="The output file for which verifier saves the result json.", 

139 ) 

140 parallel_group = parser.add_argument_group("parallel") 

141 parallel_group.add_argument( 

142 "--split", 

143 type=int, 

144 help="Parallel job size", 

145 required=False, 

146 ) 

147 parallel_group.add_argument( 

148 "--split-index", 

149 type=int, 

150 help="Parallel job index", 

151 required=False, 

152 ) 

153 

154 def _run(self) -> bool: 

155 logger.debug("arguments:%s", self) 

156 logger.info("verify_files_json=%s", self.verify_files_json) 

157 verifications = VerificationInput.parse_file_relative(self.verify_files_json) 

158 prev_result = self.read_prev_result() 

159 

160 verifier = Verifier( 

161 verifications, 

162 use_git_timestamp=github.env.is_in_github_actions(), 

163 timeout=self.timeout, 

164 default_tle=self.default_tle, 

165 default_mle=self.default_mle, 

166 prev_result=prev_result, 

167 split_state=self.split_state, 

168 ) 

169 result = verifier.verify(download=self.download) 

170 self.write_result(result) 

171 

172 is_success = result.is_success() 

173 

174 if is_success: 

175 logger.info("success!") 

176 else: 

177 logger.warning("not success!") 

178 

179 return is_success or self.ignore_error