【Smol Course】1 - Instruction Tuning

Chat template

Message and model

The following code defines a conversation (a list of messages) with two roles: user and assistant (the model).

messages = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing well, thank you! How can I assist you today?"},
]

The model consists of two parts: the model itself (the decoder) and the tokenizer.

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
import torch

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# use a raw string so the backslashes in the Windows path are not treated as escapes
model_name = r"D:\study\smol-course\data\SmolLM2-135M"
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_name
).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)

Note the setup_chat_format method here: it sets up the model's chat template by adding special tokens to the tokenizer and resizing the model's embeddings to account for them.
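As a quick sanity check (my own addition, not part of the original notes), you can inspect what setup_chat_format attached to the tokenizer; chat_template, special_tokens_map, and len() are standard transformers tokenizer attributes:

# Minimal sketch to inspect what setup_chat_format added (assumes the model and
# tokenizer from the snippet above are already loaded).
print(tokenizer.chat_template)       # Jinja template used by apply_chat_template
print(tokenizer.special_tokens_map)  # should now include the ChatML-style markers
print(len(tokenizer))                # vocabulary size after the new special tokens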

Apply chat template

Use the tokenizer's apply_chat_template method to convert the messages list defined above into a string with special tokens:

input_text = tokenizer.apply_chat_template(messages, tokenize=False)

print("Conversation with template:\n", input_text)

The result is:

<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing well, thank you! How can I assist you today?<|im_end|>

If the tokenize parameter is set to True, the tokens are mapped to their IDs in the vocabulary. The tokenizer's decode method can be used to turn the IDs back into tokens.

input_text = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)

print("Conversation decoded:", tokenizer.decode(token_ids=input_text))

In addition, apply_chat_template has an add_generation_prompt parameter, which appends the opening of the next message, i.e. the final <|im_start|>assistant line in the output below.

Conversation decoded: <|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing well, thank you! How can I assist you today?<|im_end|>
<|im_start|>assistant
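
Building on this, a minimal generation sketch (my addition, not from the course code) passes the templated prompt to the model and decodes only the newly generated tokens:

# Prompt the model with the templated conversation.
# return_tensors="pt" makes apply_chat_template return a tensor of token ids.
inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(device)

with torch.no_grad():
    outputs = model.generate(inputs, max_new_tokens=50)

# slice off the prompt so only the assistant's reply is decoded
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))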

Process dataset for SFT

Apply the chat template to an existing dataset.

Example 1

First, load the dataset and print it:

from datasets import load_dataset

# raw string for the Windows cache path
ds = load_dataset("HuggingFaceTB/smoltalk", "everyday-conversations", cache_dir=r"D:\study\smol-course\data")

print(ds)

The result is shown below; the dataset has two splits, a training set and a test set.

DatasetDict({
     train: Dataset({
         features: ['full_topic', 'messages'],
         num_rows: 2260
     })
     test: Dataset({
         features: ['full_topic', 'messages'],
         num_rows: 119
     })
})

Here, messages is already a list of user/assistant turns, shown below, so the chat template can be applied to it directly.

[{'content': 'Hi there', 'role': 'user'}, {'content': 'Hello! How can I help you today?', 'role': 'assistant'}, {'content': "I'm looking for a beach resort for my next vacation. Can you recommend some popular ones?", 'role': 'user'}, {'content': "Some popular beach resorts include Maui in Hawaii, the Maldives, and the Bahamas. They're known for their beautiful beaches and crystal-clear waters.", 'role': 'assistant'}, {'content': 'That sounds great. Are there any resorts in the Caribbean that are good for families?', 'role': 'user'}, {'content': 'Yes, the Turks and Caicos Islands and Barbados are excellent choices for family-friendly resorts in the Caribbean. They offer a range of activities and amenities suitable for all ages.', 'role': 'assistant'}, {'content': "Okay, I'll look into those. Thanks for the recommendations!", 'role': 'user'}, {'content': "You're welcome. I hope you find the perfect resort for your vacation.", 'role': 'assistant'}]

The code that applies the chat template is:

def process_dataset(sample):
    # TODO: 🐢 Convert the sample into a chat format
    # use the tokenizer's method to apply the chat template
    sample = tokenizer.apply_chat_template(sample["messages"], tokenize=False)

    return {"chat": sample}

ds = ds.map(process_dataset)
print(ds)

The code above defines a process_dataset function for processing the dataset. The map method runs the given function on every sample in the dataset. Note that map is an update-style mapping: the function is expected to return a dictionary, whose keys are merged into the existing sample fields. The result of the code above is:

DatasetDict({
    train: Dataset({
        features: ['full_topic', 'messages', 'chat'],
        num_rows: 2260
    })
    test: Dataset({
        features: ['full_topic', 'messages', 'chat'],
        num_rows: 119
    })
})

As you can see, a new field has been added to the dataset, containing the chat-templated string.
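
To confirm, you can print one processed training sample (a small check of my own, not from the original notes):

# The "chat" field should now hold the full conversation wrapped in
# <|im_start|>/<|im_end|> markers.
print(ds["train"][0]["chat"])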

Example 2

This dataset contains a question and its corresponding answer, so process_dataset must first build a messages list and then convert it into a chat-templated string.

ds = load_dataset("openai/gsm8k", "main", cache_dir="D:\study\smol-course\data")
print(ds)

def process_dataset(sample):
# TODO: 🐕 Convert the sample into a chat format

# 1. create a message format with the role and content
message = [
{'role': 'user', 'content': sample['question']},
{'role': 'assistant', 'content': sample['answer']}
]
# 2. apply the chat template to the samples using the tokenizer's method

sample = tokenizer.apply_chat_template(message, tokenize=False)

return {"chat": sample}


ds = ds.map(process_dataset)
print(ds)

The result is:

DatasetDict({
    train: Dataset({
        features: ['question', 'answer'],
        num_rows: 7473
    })
    test: Dataset({
        features: ['question', 'answer'],
        num_rows: 1319
    })
})
DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'chat'],
        num_rows: 7473
    })
    test: Dataset({
        features: ['question', 'answer', 'chat'],
        num_rows: 1319
    })
})
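
From here, a typical next step for SFT would be to tokenize the templated chat column. The sketch below is an illustration of mine rather than part of the course code, and the max_length of 512 is an arbitrary choice:

# Tokenize the templated strings so they can be fed to a trainer.
def tokenize_function(sample):
    return tokenizer(sample["chat"], truncation=True, max_length=512)

tokenized_ds = ds.map(tokenize_function, batched=True)
print(tokenized_ds)  # input_ids and attention_mask columns are added to each split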