Coverage for mindsdb / integrations / libs / llm / utils.py: 71%

205 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Optional, Dict, List, Tuple 

2import json 

3import itertools 

4import re 

5 

6import numpy as np 

7import pandas as pd 

8 

9from mindsdb.integrations.libs.llm.config import ( 

10 AnthropicConfig, 

11 BaseLLMConfig, 

12 GoogleConfig, 

13 LiteLLMConfig, 

14 OllamaConfig, 

15 OpenAIConfig, 

16 NvidiaNIMConfig, 

17 MindsdbConfig, 

18 WriterConfig, 

19 BedrockConfig, 

20) 

21from mindsdb.utilities.config import config 

22from langchain_text_splitters import Language, RecursiveCharacterTextSplitter 

23 

24 

# Default to latest GPT-4 model (https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo)
DEFAULT_OPENAI_MODEL = "gpt-4o"
# Requires more than vanilla OpenAI due to ongoing summarization and 3rd party input.
DEFAULT_OPENAI_MAX_TOKENS = 8096
DEFAULT_OPENAI_MAX_RETRIES = 3

DEFAULT_ANTHROPIC_MODEL = "claude-3-haiku-20240307"

DEFAULT_GOOGLE_MODEL = "gemini-2.5-pro-preview-03-25"

DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo"
DEFAULT_LITELLM_PROVIDER = "openai"
DEFAULT_LITELLM_BASE_URL = "https://ai.dev.mindsdb.com"

DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"
DEFAULT_OLLAMA_MODEL = "llama2"

DEFAULT_NVIDIA_NIM_BASE_URL = "http://localhost:8000/v1"  # Assumes local port forwarding through ssh
DEFAULT_NVIDIA_NIM_MODEL = "meta/llama-3_1-8b-instruct"
# Same default endpoint as NVIDIA NIM: a locally served OpenAI-compatible API.
DEFAULT_VLLM_SERVER_URL = "http://localhost:8000/v1"

45 

46 

def get_completed_prompts(base_template: str, df: pd.DataFrame, strict: bool = True) -> Tuple[List[str], np.ndarray]:
    """
    Produce one formatted prompt per DataFrame row by filling a template.

    :param base_template: string with placeholders for columns in the DataFrame.
        Placeholders use double curly braces, e.g. `{{column_name}}`. Every
        placeholder must have a matching column in `df`; columns not referenced
        by the template are ignored entirely.
    :param df: pd.DataFrame whose rows provide the placeholder values.
    :param strict: raise if `base_template` contains no placeholders.

    :return prompts: list of in-filled prompts, one per row of `df`.
    :return empty_prompt_ids: integer numpy array (shape (n_missing_rows,)) with
        the row positions where every referenced column was missing; those
        prompts are filled with empty strings instead of failing.

    :raises AssertionError: if `strict` is True and no placeholders are found.
    """  # noqa
    matches = list(re.finditer("{{(.*?)}}", base_template))

    if not matches:
        # no placeholders
        if strict:
            raise AssertionError("No placeholders found in the prompt, please provide a valid prompt template.")
        # Every row gets the template verbatim; no rows can be "empty".
        # Fix: previously returned `np.ndarray(0)`, an uninitialized float array,
        # while the contract promises an integer index array.
        return [base_template] * len(df), np.array([], dtype=np.int64)

    # Placeholder column names, and the literal text segments that precede each
    # placeholder (plus the trailing text after the last one).
    columns = [m[0].replace("{", "").replace("}", "") for m in matches]
    segments = []
    cursor = 0
    for m in matches:
        segments.append(base_template[cursor : m.start()])
        cursor = m.end()
    tail = base_template[cursor:]

    # Rows where ALL referenced columns are NaN yield effectively-empty prompts.
    empty_prompt_ids = np.where(df[columns].isna().all(axis=1).values)[0]

    # Assemble prompts column-wise. Missing values become empty strings.
    # Fix: build the result in a local Series instead of writing a temporary
    # `__mdb_prompt` column into the caller's dataframe.
    prompt = pd.Series("", index=df.index, dtype="string")
    for segment, column in zip(segments, columns):
        filled = df[column].replace(to_replace=[None], value="").astype("string")
        prompt = prompt + segment + filled
    prompt = prompt + tail

    return list(prompt), empty_prompt_ids

97 

98 

def get_llm_config(provider: str, args: Dict) -> BaseLLMConfig:
    """
    Build the configuration object for a given LLM provider.

    :param provider: name of the provider (e.g. "openai", "anthropic", "litellm",
        "ollama", "nvidia_nim", "mindsdb", "vllm", "google", "writer", "bedrock").
    :param args: dictionary with the user-supplied configuration for the provider.
        Providers that need credentials expect an "api_keys" sub-dictionary.

    :return: BaseLLMConfig subclass instance with the configuration for the provider.
    :raises ValueError: if `provider` is not one of the supported providers.
    """
    # Clamp temperature into [0, 1]; all supported providers accept this range.
    temperature = min(1.0, max(0.0, args.get("temperature", 0.0)))
    if provider == "openai":
        if any(x in args.get("model_name", "") for x in ["o1", "o3"]):
            # for o1 and o3, 'temperature' does not support 0.0 with this model. Only the default (1) value is supported
            temperature = 1

        return OpenAIConfig(
            model_name=args.get("model_name", DEFAULT_OPENAI_MODEL),
            temperature=temperature,
            max_retries=args.get("max_retries", DEFAULT_OPENAI_MAX_RETRIES),
            max_tokens=args.get("max_tokens", DEFAULT_OPENAI_MAX_TOKENS),
            openai_api_base=args.get("base_url", None),
            openai_api_key=args["api_keys"].get("openai", None),
            openai_organization=args.get("api_organization", None),
            request_timeout=args.get("request_timeout", None),
        )
    if provider == "anthropic":
        return AnthropicConfig(
            model=args.get("model_name", DEFAULT_ANTHROPIC_MODEL),
            temperature=temperature,
            max_tokens=args.get("max_tokens", None),
            top_p=args.get("top_p", None),
            top_k=args.get("top_k", None),
            default_request_timeout=args.get("default_request_timeout", None),
            anthropic_api_key=args["api_keys"].get("anthropic", None),
            anthropic_api_url=args.get("base_url", None),
        )
    if provider == "litellm":
        # Provider-specific knobs LiteLLM forwards to the underlying model.
        model_kwargs = {
            "api_key": args["api_keys"].get("litellm", None),
            "top_p": args.get("top_p", None),
            "request_timeout": args.get("request_timeout", None),
            "frequency_penalty": args.get("frequency_penalty", None),
            "presence_penalty": args.get("presence_penalty", None),
            "logit_bias": args.get("logit_bias", None),
        }
        return LiteLLMConfig(
            model=args.get("model_name", DEFAULT_LITELLM_MODEL),
            temperature=temperature,
            api_base=args.get("base_url", DEFAULT_LITELLM_BASE_URL),
            max_retries=args.get("max_retries", DEFAULT_OPENAI_MAX_RETRIES),
            max_tokens=args.get("max_tokens", DEFAULT_OPENAI_MAX_TOKENS),
            top_p=args.get("top_p", None),
            top_k=args.get("top_k", None),
            custom_llm_provider=args.get("custom_llm_provider", DEFAULT_LITELLM_PROVIDER),
            model_kwargs=model_kwargs,
        )
    if provider == "ollama":
        return OllamaConfig(
            base_url=args.get("base_url", DEFAULT_OLLAMA_BASE_URL),
            model=args.get("model_name", DEFAULT_OLLAMA_MODEL),
            temperature=temperature,
            top_p=args.get("top_p", None),
            top_k=args.get("top_k", None),
            timeout=args.get("request_timeout", None),
            format=args.get("format", None),
            headers=args.get("headers", None),
            num_predict=args.get("num_predict", None),
            num_ctx=args.get("num_ctx", None),
            num_gpu=args.get("num_gpu", None),
            repeat_penalty=args.get("repeat_penalty", None),
            stop=args.get("stop", None),
            template=args.get("template", None),
        )
    if provider == "nvidia_nim":
        return NvidiaNIMConfig(
            base_url=args.get("base_url", DEFAULT_NVIDIA_NIM_BASE_URL),
            model=args.get("model_name", DEFAULT_NVIDIA_NIM_MODEL),
            temperature=temperature,
            top_p=args.get("top_p", None),
            timeout=args.get("request_timeout", None),
            format=args.get("format", None),
            headers=args.get("headers", None),
            num_predict=args.get("num_predict", None),
            num_ctx=args.get("num_ctx", None),
            num_gpu=args.get("num_gpu", None),
            repeat_penalty=args.get("repeat_penalty", None),
            stop=args.get("stop", None),
            template=args.get("template", None),
            nvidia_api_key=args["api_keys"].get("nvidia_nim", None),
        )
    if provider == "mindsdb":
        return MindsdbConfig(
            model_name=args["model_name"],
            project_name=args.get("project_name", config.get("default_project")),
        )
    if provider == "vllm":
        # vLLM serves an OpenAI-compatible API; reuse the OpenAI config.
        return OpenAIConfig(
            model_name=args.get("model_name"),
            temperature=temperature,
            max_retries=args.get("max_retries", DEFAULT_OPENAI_MAX_RETRIES),
            max_tokens=args.get("max_tokens", DEFAULT_OPENAI_MAX_TOKENS),
            openai_api_base=args.get("base_url", DEFAULT_VLLM_SERVER_URL),
            # Fix: default key was "EMPTY`" with a stray backtick; vLLM's
            # conventional placeholder key is "EMPTY".
            openai_api_key=args["api_keys"].get("vllm", "EMPTY"),
            openai_organization=args.get("api_organization", None),
            request_timeout=args.get("request_timeout", None),
        )
    if provider == "google":
        return GoogleConfig(
            model=args.get("model_name", DEFAULT_GOOGLE_MODEL),
            temperature=temperature,
            top_p=args.get("top_p", None),
            top_k=args.get("top_k", None),
            max_output_tokens=args.get("max_tokens", None),
            google_api_key=args["api_keys"].get("google", None),
        )
    if provider == "writer":
        return WriterConfig(
            model_name=args.get("model_name", "palmyra-x5"),
            temperature=temperature,
            max_tokens=args.get("max_tokens", None),
            top_p=args.get("top_p", None),
            stop=args.get("stop", None),
            best_of=args.get("best_of", None),
            writer_api_key=args["api_keys"].get("writer", None),
            writer_org_id=args.get("writer_org_id", None),
            base_url=args.get("base_url", None),
        )
    if provider == "bedrock":
        return BedrockConfig(
            model_id=args.get("model_name"),
            temperature=temperature,
            max_tokens=args.get("max_tokens", None),
            stop=args.get("stop", None),
            base_url=args.get("endpoint_url", None),
            aws_access_key_id=args.get("aws_access_key_id", None),
            aws_secret_access_key=args.get("aws_secret_access_key", None),
            aws_session_token=args.get("aws_session_token", None),
            region_name=args.get("aws_region_name", None),
            credentials_profile_name=args.get("credentials_profile_name", None),
            model_kwargs=args.get("model_kwargs", None),
        )

    raise ValueError(f"Provider {provider} is not supported.")

242 

243 

def ft_jsonl_validation(
    items: list,  # read from a JSONL file
    messages_col: str = "messages",
    # valid keys for each chat message
    role_key: str = "role",
    content_key: str = "content",
    name_key: str = "name",
    # valid roles for each chat message
    system_key: str = "system",
    user_key: str = "user",
    assistant_key: str = "assistant",
):
    """
    Check a list of dictionaries for compliance with the format usually expected by LLM providers
    (such as OpenAI or AnyscaleEndpoints) for fine-tuning LLMs that generate chat completions.

    Defaults are set according to the expected format, but these can be changed if needed by any given provider.

    :param items: list of JSON lines, each dictionary containing a chat sequence. Should be read from a JSONL file.
    :param messages_col: key in each dictionary to access a sequence of chat messages

    For chat-level checks, this method defers to `ft_chat_format_validation()` below. Relevant parameters for it are:

    For each chat:
    :param role_key: key that defines the role of each message (e.g. system, user, or LLM)
    :param content_key: key that defines the content of each message
    :param name_key: key that defines the name of each message

    For each message:
    :param system_key: valid role for each chat message
    :param user_key: valid role for each chat message
    :param assistant_key: valid role for each chat message

    :return: None, raises an Exception if validation fails.
    """  # noqa
    try:
        if not all([isinstance(m, dict) for m in items]):
            raise Exception("Each line in the provided data should be a dictionary")

        for line_num, batch in enumerate(items):
            prefix = f"error in chat #{line_num + 1}, "

            # Fix: check that the key exists BEFORE indexing into the dict.
            # Previously the isinstance check ran first, so a missing key
            # surfaced as a raw KeyError instead of this informative message.
            if messages_col not in batch:
                raise Exception(f"{prefix}Each line in the provided data should have a '{messages_col}' key")

            if not isinstance(batch[messages_col], list):
                raise Exception(
                    f"{prefix}Each line in the provided data should have a '{messages_col}' key with a list of messages"
                )  # noqa

            messages = batch[messages_col]
            try:
                ft_chat_format_validation(
                    messages,
                    role_key=role_key,
                    content_key=content_key,
                    name_key=name_key,
                    system_key=system_key,
                    user_key=user_key,
                    assistant_key=assistant_key,
                )
            except Exception as e:
                # Re-raise with chat position context for easier debugging.
                raise Exception(f"{prefix}{e}") from e

    except Exception as e:
        raise Exception(f"Fine-tuning data format is not valid. Got {e}") from e

311 

312 

def ft_chat_format_validation(
    chat: list,
    transitions: Optional[Dict] = None,
    system_key: str = "system",
    user_key: str = "user",
    assistant_key: str = "assistant",
    role_key: str = "role",
    content_key: str = "content",
    name_key: str = "name",
):
    """
    Validate that a single chat follows the OpenAI ChatCompletion fine-tuning format.

    The role sequence is checked with a small finite state machine: a chat must
    open with a system or user message, a system message must be followed by a
    user message, and user/assistant messages must alternate thereafter.
    Reference: https://cookbook.openai.com/examples/chat_finetuning_data_prep

    See the unit tests in `test_llm_utils.py` for valid and invalid examples.

    :param chat: list of message dictionaries.
    :param transitions: optional mapping from a role to the list of roles allowed
        next; the `None` key describes the allowed opening roles. Defaults to the
        machine described above.

    For each chat:
    :param role_key: key that defines the role of each message
    :param content_key: key that defines the content of each message
    :param name_key: key that defines the name of each message

    For each message:
    :param system_key: valid role for each chat message
    :param user_key: valid role for each chat message
    :param assistant_key: valid role for each chat message

    :return: None if the chat is valid, otherwise raise an informative Exception.
    """  # noqa

    allowed_keys = (role_key, content_key, name_key)
    allowed_roles = (system_key, user_key, assistant_key)

    # Reject messages carrying any unexpected keys.
    for message in chat:
        if any(key not in allowed_keys for key in message.keys()):
            raise Exception(f"Each message should only have these keys: `{allowed_keys}`. Found: `{message.keys()}`")

    roles = [message[role_key] for message in chat]
    contents = [message[content_key] for message in chat]

    if len(roles) != len(contents):
        raise Exception(f"Each message should contain both `{role_key}` and `{content_key}` fields")

    if not roles:
        raise Exception("Chat should have at least one message")

    # A chat with no assistant turn is useless for fine-tuning.
    if assistant_key not in roles:
        raise Exception("Chat should have at least one assistant message")

    if user_key not in roles:
        raise Exception("Chat should have at least one user message")

    # Default transition table for the finite state machine.
    if transitions is None:
        transitions = {
            None: [system_key, user_key],
            system_key: [user_key],
            user_key: [assistant_key],
            assistant_key: [user_key],
        }

    # Walk the chat, advancing the machine one message at a time.
    previous = None
    for idx, (role, content) in enumerate(zip(roles, contents)):
        where = f"message #{idx + 1}: "

        if role not in allowed_roles:
            raise Exception(f"{where}Invalid role (found `{role}`, expected one of `{allowed_roles}`)")

        if not isinstance(content, str):
            raise Exception(f"{where}Content should be a string, got type `{type(content)}`")

        if role not in transitions[previous]:
            raise Exception(f"{where}Invalid transition from `{previous}` to `{role}`")
        previous = role

395 

396 

def ft_formatter(df: pd.DataFrame) -> List[Dict]:
    """
    Entry point for preparing chat-LLM fine-tuning data.

    Dispatches to a format-specific pre-formatter based on the columns present,
    then aggregates everything into chats via `ft_chat_formatter()`.

    Supported formats:
    - code: long tabular format with a `code` column
    - chat: long tabular format with `role` and `content` columns, or a JSON format with a `chat_json` column.
    """
    available = set(df.columns)
    if "code" in available:
        df = ft_code_formatter(df)
    elif available >= {"question", "context", "answer"}:
        # TODO: handler user-specified names for these columns
        df = ft_cqa_formatter(df)

    return ft_chat_formatter(df)

413 

414 

def ft_chat_formatter(df: pd.DataFrame) -> List[Dict]:
    """
    Aggregate a dataframe of chat data into a list of chat dictionaries.

    For more details, check `FineTuning -> Data Format` in the Anyscale API reference, or the OpenAI equivalent.
    Additionally, the unit test in `test_llm_utils.py` provides example usage.

    :param df: chats in one of two layouts:
        1) long tabular: at least `role` and `content` columns; rows hold one or
           more chats stacked vertically. A new `system` row starts a new chat.
        2) JSON: a `chat_json` column where each row is one complete chat, e.g.
           `{"messages": [{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hi!"}]}`

        Optional columns:
        - chat_id: unique identifier for each chat
        - message_id: unique identifier for each message within each chat

        Rows are sorted by both IDs when present. With only `chat_id`, a stable
        sort preserves the original message order within each chat. With only
        `message_id`, the whole dataset is treated as a single chat and the IDs
        must be unique, otherwise an exception is raised.

    :return: list of chats; each is a dictionary with a single top-level
        'messages' key holding a list of `{role, content}` dictionaries in
        OpenAI ChatEndpoint format.
    """  # noqa
    # 1. pre-sort on the optional identifier columns
    has_chat_id = "chat_id" in df.columns
    has_message_id = "message_id" in df.columns
    if has_chat_id and has_message_id:
        df = df.sort_values(["chat_id", "message_id"])
    elif has_chat_id:
        df = df.sort_values(["chat_id"], kind="stable")
    elif has_message_id:
        if df["message_id"].duplicated().any():
            raise Exception("If `message_id` is provided, it must not contain duplicate IDs.")
        df = df.sort_values(["message_id"])

    # 2. build chats
    chats = []

    if "chat_json" in df.columns:
        # 2a. one complete JSON chat per row
        for raw in df["chat_json"]:
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                continue  # TODO: add logger info here, prompt user to clean dataset carefully
            assert list(parsed.keys()) == ["messages"], "Each chat should have a 'messages' key, and nothing else."
            ft_chat_format_validation(parsed["messages"])  # will raise Exception if chat is invalid
            chats.append(parsed)
    else:
        # 2b. tabular layout - fold stacked rows into chats, splitting on each
        # new `system` message
        current = []
        for role, content in zip(df["role"], df["content"]):
            if role == "system" and current:
                ft_chat_format_validation(current)  # will raise Exception if chat is invalid
                chats.append({"messages": current})
                current = []
            current.append({"role": role, "content": content})

        ft_chat_format_validation(current)  # will raise Exception if chat is invalid
        chats.append({"messages": current})

    return chats

484 

485 

def ft_code_formatter(
    df: pd.DataFrame,
    format="chat",
    language="python",
    chunk_size=100,
    chunk_overlap=0,
    chat_sections=("Code prefix", "Code suffix", "Completion"),
    fim_tokens=("<PRE>", "<SUF>", "<MID>"),
) -> pd.DataFrame:
    """
    Process a raw codebase stored as a dataframe with a `code` column, where
    every row may be an entire file or some portion of it.
    Code is chunked into triples made of a prefix, middle, and suffix.

    Depending on the target LLM, these triples are then formatted into a chat-like prompt, or a
    fill-in-the-middle (FIM) prompt. The latter is used for fine-tuning models like codellama,
    while the former is more generic and should work with any LLM that supports the ChatCompletion
    format, as the rest of our tools do.

    :param df: input dataframe with a `code` column.
    :param format: "chat" or "fim" prompt layout.
    :param language: any language supported by langchain's `Language` enum.
    :param chunk_size: target character count for each triplet element.
    :param chunk_overlap: overlap between top-level chunks (triplet elements never overlap).
    :param chat_sections: section headers used in the "chat" layout.
    :param fim_tokens: sentinel tokens used in the "fim" layout.
    :return: dataframe with `role`/`content` columns for `ft_chat_formatter()`.
    """

    # input and setup validation
    assert len(df) > 0, "Input dataframe should not be empty"
    assert "code" in df.columns, "Input dataframe should have a 'code' column"
    assert chunk_size > 0 and isinstance(chunk_size, int), "`chunk_size` should be a positive integer"

    supported_formats = ["chat", "fim"]
    supported_langs = [e.value for e in Language]
    assert language.lower() in supported_langs, f"Invalid language. Valid choices are: {supported_langs}"

    # Fix: operate on a copy so the caller's dataframe (and its `code` column)
    # is never mutated in place.
    df = df.copy()

    # ensure correct encoding
    df["code"] = df["code"].map(lambda x: x.encode("utf8").decode("unicode_escape"))

    # set prompt templates
    system_prompt = "You are a powerful text to code model. Your job is to provide great code completions. As context, you are given code that is found immediately before and after the code you must generate.\n\nYou must output the code that should go in between the prefix and suffix.\n\n"
    if format == "chat":
        templates = [f"### {c}:" for c in chat_sections]
    elif format == "fim":
        templates = fim_tokens
    else:
        raise Exception(f"Invalid format. Please choose one of {supported_formats}")

    # split code into chunks, each three times the size of one triplet element
    code_splitter = RecursiveCharacterTextSplitter.from_language(
        language=getattr(Language, language.upper()),
        chunk_size=3 * chunk_size,  # each triplet element has `chunk_size`
        chunk_overlap=chunk_overlap,  # some overlap here is fine
    )
    chunk_docs = code_splitter.create_documents(list(df["code"]))
    chunks = [c.page_content for c in chunk_docs]

    # split each chunk into a triplet, with no overlap
    triplet_splitter = RecursiveCharacterTextSplitter.from_language(
        language=getattr(Language, language.upper()),
        chunk_size=chunk_size,
        chunk_overlap=0,  # no overlap admitted, otherwise context may leak into answer
    )
    triplet_chunk_docs = triplet_splitter.create_documents(chunks)
    chunks = [c.page_content for c in triplet_chunk_docs]
    chunks = chunks[: len(chunks) - len(chunks) % 3]  # should be a multiple of 3

    # format chunks into prompts: user sees prefix + suffix, assistant fills the middle
    roles = []
    contents = []
    for idx in range(0, len(chunks), 3):
        pre, mid, suf = chunks[idx : idx + 3]
        interleaved = list(itertools.chain(*zip(templates, (pre, suf, mid))))
        user = "\n".join(interleaved[:-1])
        assistant = "\n".join(interleaved[-1:])
        roles.extend(["system", "user", "assistant"])
        contents.extend([system_prompt, user, assistant])

    # return formatted prompts in a dataframe to be processed by `ft_chat_formatter()`
    df = pd.DataFrame({"role": roles, "content": contents})
    return df

560 

561 

def ft_cqa_formatter(
    df: pd.DataFrame,
    question_col="question",
    answer_col="answer",
    instruction_col="instruction",
    context_col="context",
    default_instruction="You are a helpful assistant.",
    default_context="",
) -> pd.DataFrame:
    """
    Convert a context/question/answer dataset into the long chat format
    (`role`/`content` rows) consumed by `ft_chat_formatter()`.

    Every input row becomes three messages: a system message joining the
    instruction and context, a user message with the question, and an
    assistant message with the answer. Missing instruction/context columns
    are backfilled with the given defaults.
    """
    # input and setup validation
    assert len(df) > 0, "Input dataframe should not be empty"
    assert {question_col, answer_col}.issubset(set(df.columns)), (
        f"Input dataframe must have columns `{question_col}`, and `{answer_col}`"
    )  # noqa

    # backfill optional columns with defaults so every row can be formatted
    if instruction_col not in df.columns:
        df[instruction_col] = default_instruction
    if context_col not in df.columns:
        df[context_col] = default_context

    # emit (role, content) pairs, three per input row
    records = []
    for instruction, context, question, answer in zip(
        df[instruction_col], df[context_col], df[question_col], df[answer_col]
    ):
        records.append(("system", f"{instruction}\n{context}"))
        records.append(("user", question))
        records.append(("assistant", answer))

    return pd.DataFrame(records, columns=["role", "content"])