Coverage for mindsdb / api / mysql / mysql_proxy / data_types / mysql_datum.py: 11%

135 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1""" 

2******************************************************* 

3 * Copyright (C) 2017 MindsDB Inc. <copyright@mindsdb.com> 

4 * 

5 * This file is part of MindsDB Server. 

6 * 

7 * MindsDB Server can not be copied and/or distributed without the express 

8 * permission of MindsDB Inc 

9 ******************************************************* 

10""" 

11 

12import struct 

13 

14from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import ( 

15 DEFAULT_CAPABILITIES, 

16 NULL_VALUE, 

17 ONE_BYTE_ENC, 

18 THREE_BYTE_ENC, 

19 TWO_BYTE_ENC, 

20) 

21from mindsdb.utilities import log 

22 

# Logger named after this module, obtained through the project's logging helper.
logger = log.getLogger(__name__)

# Integer code point of the NULL marker, for comparison against integer values.
NULL_VALUE_INT = ord(NULL_VALUE)

26 

27 

class Datum:
    """A single typed field that can be decoded from / encoded to a byte buffer.

    The field type is given as a spec string ``"<type><<len>>"`` (for example
    ``"int<8>"``, ``"string<NUL>"`` or ``"string<lenenc>"``), or as separate
    ``var_type`` and ``var_len`` arguments.

    Types handled by this class are ``"int"``, ``"string"`` and ``"byte"``.
    Lengths handled are ``"lenenc"`` (length-encoded), ``"EOF"`` (rest of the
    buffer), ``"NUL"`` (zero-terminated string), ``"packet"`` (value exposes
    ``get_packet_string()``) or a decimal byte count.

    All multi-byte integers are read and written little-endian.
    """

    __slots__ = ["value", "var_type", "var_len"]

    # Prefix byte for an 8-byte length-encoded integer.  It must differ from
    # THREE_BYTE_ENC, otherwise setFromBuff would read only 3 payload bytes.
    _EIGHT_BYTE_ENC = b"\xfe"

    def __init__(self, var_type, value=None, var_len=None):
        """Create a field.

        Args:
            var_type: Either a full spec such as ``"int<8>"`` (when ``var_len``
                is None) or just the type name.
            value: Optional initial value; when None the value stays ``b""``.
            var_len: Optional length spec. When None it is parsed out of
                ``var_type`` as the text between ``"<"`` and the last character.
        """
        # TODO other types: float, timestamp
        self.value = b""

        if var_len is None:
            # Split "type<len>" into its two parts.
            idx = var_type.find("<")
            var_len = var_type[idx + 1 : -1]  # noqa: E203
            var_type = var_type[:idx]
        self.var_type = var_type
        self.var_len = var_len

        if value is not None:
            self.set(value)

    def set(self, value):
        """Store ``value`` as the field's current value."""
        self.value = value

    @staticmethod
    def _as_bytes(value):
        """Return ``value`` as bytes: bytes-like is passed through, str is UTF-8 encoded."""
        if isinstance(value, (bytes, bytearray)):
            return bytes(value)
        return bytes(value, "utf-8")

    def setFromBuff(self, buff):
        """Decode this field from the start of ``buff`` and store it in ``self.value``.

        Args:
            buff: Bytes-like buffer positioned at the start of the field.

        Returns:
            The part of ``buff`` that follows the decoded field.

        Raises:
            ValueError: If a NUL-terminated string has no terminator, or the
                length spec is not a number where one is required.
            struct.error: If ``buff`` is too short for an integer field.
        """
        if self.var_len == "lenenc":
            # The first byte either is the value itself or announces how many
            # bytes of integer payload follow it.
            prefix = int(buff[0])
            if prefix <= ONE_BYTE_ENC[0]:
                start, end = 0, 1
            elif prefix == TWO_BYTE_ENC[0]:
                start, end = 1, 3
            elif prefix == THREE_BYTE_ENC[0]:
                start, end = 1, 4
            else:
                start, end = 1, 9

            # Zero-pad to 8 bytes so one unsigned 64-bit unpack covers all
            # widths.  The pad count is based on the expected width, so a
            # truncated buffer still fails in struct.unpack.
            padded = buff[start:end] + b"\0" * (8 - (end - start))
            number = struct.unpack("<Q", padded)[0]

            if self.var_type == "int":
                self.value = number
                return buff[end:]

            if self.var_type in ("byte", "string"):
                # For strings the decoded integer is the payload length.
                self.value = buff[end : (number + end)]  # noqa: E203
                return buff[(number + end) :]  # noqa: E203

        if self.var_len == "EOF":
            # The field is everything that is left; record the real length.
            self.var_len = str(len(buff))
            self.value = buff
            return b""

        length = self.var_len
        if self.var_type == "string" and self.var_len == "NUL":
            terminator = next((pos for pos, byte in enumerate(buff) if int(byte) == 0), None)
            if terminator is None:
                raise ValueError("NUL-terminated string has no terminating zero byte")
            # Consume the terminator too; it is stripped from the value below.
            length = terminator + 1

        length = int(length)
        if self.var_type in ("byte", "string"):
            self.value = buff[:length]
        else:  # any other type is treated as a fixed-width integer
            if length > 8:
                logger.error("cant decode integer greater than 8 bytes")
                return buff[length:]
            padded = buff[:length] + b"\0" * (8 - length)
            self.value = struct.unpack("<Q", padded)[0]
        if str(self.var_len) == "NUL":
            self.value = self.value[:-1]
        return buff[length:]

    @classmethod
    def serialize_int(cls, value):
        """Encode ``value`` as a length-encoded integer.

        Args:
            value: Non-negative int, or None.

        Returns:
            ``NULL_VALUE`` for None; one byte for values below
            ``NULL_VALUE_INT``; otherwise a prefix byte followed by 2, 3 or 8
            little-endian bytes.  None is returned when the value needs more
            than 8 bytes.
        """
        if value is None:
            return NULL_VALUE

        # Number of whole bytes needed to hold the value (ceiling division).
        byte_count = -(value.bit_length() // (-8))

        if byte_count == 0:
            return b"\0"
        if value < NULL_VALUE_INT:
            return struct.pack("B", value)
        if byte_count <= 2:
            return TWO_BYTE_ENC + struct.pack("<H", value)
        if byte_count <= 3:
            return THREE_BYTE_ENC + struct.pack("<I", value)[:3]
        if byte_count <= 8:
            return cls._EIGHT_BYTE_ENC + struct.pack("<Q", value)
        return None

    def toStringPacket(self):
        """Return the current value encoded according to the field's type and length."""
        return self.get_serializer()(self.value)

    def get_serializer(self):
        """Pick the callable that encodes a value for this field's type/length.

        Returns:
            A one-argument callable returning bytes, or None when the type is
            not ``"string"``, ``"byte"`` or ``"int"``.
        """
        if self.var_type in ("string", "byte"):
            if self.var_len == "lenenc":
                if isinstance(self.value, bytes):
                    return self.serialize_bytes
                return self.serialize_str
            if self.var_len == "EOF":
                return self.serialize_str_eof
            if self.var_len == "NUL":
                return lambda v: self._as_bytes(v) + b"\0"
            if self.var_len == "packet":
                return lambda v: v.get_packet_string()
            # Fixed width: struct's "<n>s" pads with zero bytes or truncates
            # to exactly n bytes.
            return lambda v: struct.pack("{}s".format(int(self.var_len)), self._as_bytes(v))

        if self.var_type == "int":
            if self.var_len == "lenenc":
                return self.serialize_int
            # Fixed width: keep the low-order bytes of the 64-bit encoding.
            return lambda v: struct.pack("<Q", v)[: int(self.var_len)]

        return None

    @classmethod
    def serialize_str_eof(cls, value):
        """Encode a rest-of-packet string: the UTF-8 bytes with no length or terminator."""
        # Return the full encoding; sizing by character count would cut off
        # multi-byte characters.
        return cls._as_bytes(value)

    # def serialize_obj(self, value):
    #     return self.serialize_str(str(value))

    @classmethod
    def serialize_str(cls, value):
        """Encode a str as a length-encoded string of its UTF-8 bytes."""
        return cls.serialize_bytes(value.encode("utf8"))

    @classmethod
    def serialize_bytes(cls, value):
        """Encode ``value`` as a length-encoded string: encoded length, then the bytes.

        Returns None when the length needs more than 8 bytes.
        """
        val_len = len(value)

        if val_len == 0:
            return b"\0"

        if val_len < NULL_VALUE_INT:
            return struct.pack("B", val_len) + value

        # Number of whole bytes needed to hold the length (ceiling division).
        byte_count = -(val_len.bit_length() // (-8))
        if byte_count <= 2:
            return TWO_BYTE_ENC + struct.pack("<H", val_len) + value
        if byte_count <= 3:
            return THREE_BYTE_ENC + struct.pack("<I", val_len)[:3] + value
        if byte_count <= 8:
            return cls._EIGHT_BYTE_ENC + struct.pack("<Q", val_len) + value
        return None

185 

186 

def test():
    """Manual smoke check: pretty-print the encoding of an 8-byte int Datum."""
    from pprint import pprint

    datum = Datum("int<8>", DEFAULT_CAPABILITIES >> 16)
    pprint(datum.toStringPacket())

192 

193 

# Run the smoke check only when this module is executed directly as a script.
if __name__ == "__main__":
    test()