1. 基本介绍

Hugging Face 的 datasets 库中的 load_dataset 方法是用于加载数据集的核心工具,它支持从多种来源(如本地文件、Hugging Face Hub、内存数据等)加载数据集,并返回标准的 DatasetDatasetDict 对象,方便进行高效的数据处理和训练。以下是详细介绍:


基本用法

from datasets import load_dataset

dataset = load_dataset(path, name=None, split=None, **kwargs)

主要参数说明

  1. path (必填):
    Hugging Face Hub 数据集: 直接传入 Hub 上的数据集名称,例如 "glue""squad""imdb" 等。
    本地文件/目录: 传入本地文件路径(支持 CSV/JSON/TXT 等格式),例如 "path/to/data.csv"
    自定义脚本: 传入本地数据集生成脚本的路径(需符合 datasets 库格式)。

  2. name (可选):
    • 指定数据集的子配置(例如 "glue" 数据集下有 "cola""sst2" 等子任务)。
    • 示例: load_dataset("glue", name="sst2") 加载 GLUE 的 SST-2 子任务。

  3. split (可选):
    • 指定加载的数据集划分,如 "train""test""validation",或组合(如 "train+test")。
    • 示例: split="train[:10%]" 加载训练集的前 10%。

  4. 其他常用参数:
    data_dir: 数据集文件的存储目录(适用于需要额外数据的场景)。
    data_files: 直接指定文件路径(支持通配符 *),例如 data_files={"train": "train.csv", "test": "test.csv"}
    cache_dir: 自定义缓存目录(默认在 ~/.cache/huggingface/datasets)。
    streaming: 设为 True 时启用流式加载(适用于超大数据集,无需全量加载到内存)。


常见使用场景

1. 从 Hugging Face Hub 加载数据集
# 加载 GLUE 数据集的 SST-2 子任务,训练集
dataset = load_dataset("glue", "sst2", split="train")

# 加载 SQuAD 问答数据集(返回 DatasetDict 包含 train 和 validation)
dataset_dict = load_dataset("squad")
2. 加载本地文件
# 加载 CSV 文件(自动推断格式)
dataset = load_dataset("csv", data_files="data.csv")

# 加载多个 JSON 文件
dataset = load_dataset("json", data_files={"train": "train.json", "test": "test.json"})
3. 流式模式(处理大型数据集)
# 逐样本加载,避免内存不足
streaming_dataset = load_dataset("big_dataset", streaming=True)
for example in iter(streaming_dataset["train"]):
    process(example)
4. 自定义数据集脚本

若数据集未在 Hub 上,可编写脚本定义数据加载逻辑(需符合 Dataset Script 规范):

dataset = load_dataset("path/to/dataset_script.py", name="my_config")

返回对象

Dataset: 单一切片的数据集(如 split="train"),支持类似 Pandas 的索引和操作。
DatasetDict: 包含多个切片的字典(例如 {"train": Dataset, "test": Dataset})。


数据处理特性

  1. 内存高效:基于 Apache Arrow 格式,支持零拷贝读取。
  2. 预处理:可通过 .map() 方法快速应用预处理函数。
  3. 兼容性:可轻松转换为 Pandas DataFrame(.to_pandas())或 NumPy 数组。

示例代码

# 加载 IMDb 电影评论数据集
dataset = load_dataset("imdb")

# 查看训练集前 3 条样本
print(dataset["train"][:3])

# 预处理:分词
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True)

tokenized_dataset = dataset.map(tokenize_function, batched=True)

注意事项

缓存机制:首次加载数据集时会下载或处理数据,后续调用直接读取缓存。
依赖安装:加载某些格式(如 Parquet)需安装额外依赖(pip install datasets[parquet])。
版本控制:通过 revision 参数指定数据集版本(如 Git 分支、commit hash)。

通过 load_dataset,Hugging Face 提供了统一且高效的接口,极大简化了 NLP 任务中的数据加载流程。

2. load_datase返回值

load_dataset 方法返回的数据类型取决于传入的参数,通常是 DatasetDatasetDict 对象。以下是具体规则:


1. 默认情况:返回 DatasetDict

当不指定 split 参数,且数据集包含预定义的多个划分(如 train/test/validation)时,返回一个 DatasetDict 对象。
示例

from datasets import load_dataset

# 加载 SQuAD 数据集(包含 train 和 validation)
dataset_dict = load_dataset("squad")
print(type(dataset_dict))  # 输出: <class 'datasets.dataset_dict.DatasetDict'>

# 访问训练集
train_data = dataset_dict["train"]

2. 指定 split 参数:返回 Dataset

当明确指定 split(如 split="train"),返回单个 Dataset 对象。
示例

# 仅加载 IMDb 数据集的测试集
dataset = load_dataset("imdb", split="test")
print(type(dataset))  # 输出: <class 'datasets.arrow_dataset.Dataset'>

# 访问第一条数据
print(dataset[0])  # 输出: {"text": "Great movie!", "label": 1, ...}

3. 自定义数据文件:返回类型灵活

通过 data_files 参数加载本地文件时:
• 如果指定单个文件(如 data_files="data.csv"),返回 Dataset
• 如果指定多个文件(如 data_files={"train": "train.csv", "test": "test.csv"}),返回 DatasetDict

示例

# 加载单个 CSV 文件
dataset = load_dataset("csv", data_files="data.csv", split="train")

# 加载多个 JSON 文件
dataset_dict = load_dataset("json", data_files={"train": "train.json", "test": "test.json"})

4. 流式模式 (streaming=True):返回迭代器

• 当启用流式加载(streaming=True)时,返回的是 IterableDatasetIterableDatasetDict,适合逐样本处理超大数据集。
示例

# 流式加载维基百科数据集
dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
print(type(dataset))  # 输出: <class 'datasets.iterable_dataset.IterableDataset'>

# 逐样本读取(无需全量加载到内存)
for example in dataset:
    print(example["text"])
    break  # 仅读取第一条

关键区别

返回类型适用场景数据访问方式
DatasetDict多划分数据集(如 train/test)通过字典键访问:dataset_dict["train"]
Dataset单划分数据集直接索引或切片:dataset[0]
IterableDataset(Dict)流式处理超大数据集(避免内存爆炸)通过迭代器逐条访问:for example in dataset: ...

常用操作

转换为其他格式:
# 转 Pandas DataFrame
df = dataset.to_pandas()

# 转 Python 字典列表
data_list = dataset.to_list()
数据预处理:
# 使用 .map() 方法批量处理
def preprocess(example):
    example["text_length"] = len(example["text"])
    return example

processed_dataset = dataset.map(preprocess)
过滤数据:
# 过滤短文本
filtered_dataset = dataset.filter(lambda x: len(x["text"]) > 100)

注意事项

  1. 缓存机制:首次加载数据集会下载或处理数据,之后直接从缓存读取(路径默认在 ~/.cache/huggingface/datasets)。
  2. 内存优化Dataset 基于 Apache Arrow 格式,内存占用低,支持快速随机访问。
  3. 版本控制:通过 revision 参数加载特定版本的数据集(如 Git 分支、commit hash)。

简单总结

• 需要处理多个划分(如训练集、测试集) → DatasetDict
• 只需单个划分 → Dataset
• 处理超大数据集 → IterableDataset + 流式模式

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐