<div align="center">
**Implementations of various `NLP` tasks built on the [`pytorch-lightning`](https://github.com/Lightning-AI/lightning) and [`transformers`](https://github.com/huggingface/transformers) frameworks**
</div>
## 🔨 Installation
1. Set up a `pytorch` GPU deep-learning environment
```bash
conda create -n pytorch python=3.8
conda activate pytorch
conda install pytorch cudatoolkit -c pytorch
```
2. Install `lightningnlp`
```bash
pip install lightningnlp
```
3. Find the `torch_scatter` wheel matching your `torch` version at `https://pytorch-geometric.com/whl/`, download it, and install it into the environment with `pip`
```python
import torch
print(torch.__version__) # 1.12.0
print(torch.version.cuda) # 11.3
```
```bash
# Example for python=3.8, torch=1.12.0, cuda=11.3
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp38-cp38-linux_x86_64.whl
pip install torch_scatter-2.1.0+pt112cu113-cp38-cp38-linux_x86_64.whl
```
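As a sanity check before downloading, the wheel filename for a given combination can be assembled from the same pattern as the example above. This is only a sketch of the naming scheme observed on the PyG wheel index; verify the final URL there before installing.

```python
def scatter_wheel_name(torch_version: str, cuda: str, py: str) -> str:
    """Build a torch_scatter 2.1.0 wheel filename following the pattern above.

    Sketch only: the authoritative names live on the PyG wheel index.
    """
    pt = "pt" + "".join(torch_version.split(".")[:2])  # "1.12.0" -> "pt112"
    cp = "cp" + py.replace(".", "")                    # "3.8"    -> "cp38"
    return f"torch_scatter-2.1.0+{pt}{cuda}-{cp}-{cp}-linux_x86_64.whl"

print(scatter_wheel_name("1.12.0", "cu113", "3.8"))
# torch_scatter-2.1.0+pt112cu113-cp38-cp38-linux_x86_64.whl
```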
This project also provides a [Docker installation](./docker).
## 🧾 Text Classification
### 1. Data Format
<details>
<summary>Training data example</summary>
```json
{
"text": "以色列大规模空袭开始!伊朗多个军事目标遭遇打击,誓言对等反击",
"label": "news_military"
}
```
</details>
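The training files used in this task (`train.json`, `dev.json`) follow a one-JSON-object-per-line layout like the sample above; a minimal sketch of producing such a file (the exact layout expected by `TextClassificationDataModule` should be confirmed against the project's data loaders):

```python
import json
import os
import tempfile

samples = [
    {"text": "以色列大规模空袭开始!伊朗多个军事目标遭遇打击,誓言对等反击",
     "label": "news_military"},
]

# Write one JSON object per line (JSON Lines); ensure_ascii=False keeps
# Chinese characters readable in the file.
path = os.path.join(tempfile.gettempdir(), "train.json")
with open(path, "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")

# Read it back to verify the round trip.
with open(path, encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f]
print(loaded[0]["label"])  # news_military
```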
### 2. Models
| Model | Paper | Notes |
|-------|-------|-------|
| [fc](lightningnlp/task/text_classification/fc/model.py) | | Fully connected classification head |
| [mdp](lightningnlp/task/text_classification/mdp/model.py) | [Multi-Sample Dropout for Accelerated Training and Better Generalization.](https://arxiv.org/abs/1905.09788) | Uses `MultiSampleDropout`, similar to model ensembling |
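The idea behind `MultiSampleDropout` can be sketched without any framework: draw several independent dropout masks over the same feature vector, score each masked copy, and average the resulting logits, which behaves like an ensemble of subnetworks sharing weights. A toy, framework-free sketch with a hypothetical linear scorer (not the project's actual implementation):

```python
import random

def multi_sample_dropout_logits(features, weights, p=0.5, k=4, seed=0):
    """Average the logits of k independent dropout masks (toy sketch).

    features: input feature vector; weights: a toy linear scorer.
    """
    rng = random.Random(seed)
    scale = 1.0 / (1.0 - p)  # inverted dropout keeps the expected activation
    total = 0.0
    for _ in range(k):
        # Independently drop each feature with probability p.
        masked = [x * scale if rng.random() > p else 0.0 for x in features]
        total += sum(w * x for w, x in zip(weights, masked))
    return total / k  # ensemble-style average over the k masked copies

logit = multi_sample_dropout_logits([1.0, 2.0, 3.0], [0.1, 0.2, 0.3], k=8)
print(round(logit, 4))
```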
<details>
<summary>Training code example</summary>
```python
import os
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger  # requires wandb
from transformers import BertTokenizerFast
from lightningnlp.callbacks import LoggingCallback
from lightningnlp.task.text_classification import (
TextClassificationDataModule,
TextClassificationTransformer,
)
pl.seed_everything(seed=42)
pretrained_model_name_or_path = "hfl/chinese-roberta-wwm-ext"  # pretrained model
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
dm = TextClassificationDataModule(
    tokenizer=tokenizer,
    train_batch_size=16,  # training batch size
    validation_batch_size=16,  # validation batch size
    num_workers=16,  # number of workers for data loading
    dataset_name="datasets/sentiment",  # directory containing the training data
    train_file="train.json",  # training set file name
    validation_file="dev.json",  # validation set file name
    train_max_length=256,
    cache_dir="datasets/sentiment",  # data cache path
)
model = TextClassificationTransformer(
    downstream_model_name="fc",  # model name
    downstream_model_type="bert",  # pretrained model type
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    tokenizer=tokenizer,
    label_map=dm.label_map,
    learning_rate=2e-5,
    output_dir="outputs/sentiment/fc",  # model save path
)
model_ckpt = pl.callbacks.ModelCheckpoint(
dirpath="outputs/sentiment/fc",
filename="best_model",
monitor="val_accuracy",
save_top_k=1,
mode="max",
)
wandb_logger = WandbLogger(project="Text Classification", name="fc")
trainer = pl.Trainer(
logger=wandb_logger,
accelerator="gpu",
devices=1,
max_epochs=12,
val_check_interval=0.5,
gradient_clip_val=1.0,
callbacks=[model_ckpt, LoggingCallback()]
)
trainer.fit(model, dm)
```
</details>
### 3. Prediction
```python
from lightningnlp.task.text_classification import TextClassificationTransformer
model = TextClassificationTransformer.load_from_checkpoint("my_bert_model_path")
text = "以色列大规模空袭开始!伊朗多个军事目标遭遇打击,誓言对等反击"
print(model.predict(text))
```
## 📄 Named Entity Recognition
### 1. Data Format
<details>
<summary>Training data example</summary>
```json
{
"text": "结果上周六他们主场0:3惨败给了中游球队瓦拉多利德,近7个多月以来西甲首次输球。",
"entities": [
{
"id": 0,
"entity": "瓦拉多利德",
"start_offset": 20,
"end_offset": 25,
"label": "organization"
},
{
"id": 1,
"entity": "西甲",
"start_offset": 33,
"end_offset": 35,
"label": "organization"
}
]
}
```
</details>
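In this format `start_offset` and `end_offset` are character offsets into `text`, with the end offset exclusive; a quick consistency check over the sample above:

```python
sample = {
    "text": "结果上周六他们主场0:3惨败给了中游球队瓦拉多利德,近7个多月以来西甲首次输球。",
    "entities": [
        {"id": 0, "entity": "瓦拉多利德", "start_offset": 20, "end_offset": 25,
         "label": "organization"},
        {"id": 1, "entity": "西甲", "start_offset": 33, "end_offset": 35,
         "label": "organization"},
    ],
}

# Each span sliced out of the text must match the annotated entity string.
for ent in sample["entities"]:
    span = sample["text"][ent["start_offset"]:ent["end_offset"]]
    assert span == ent["entity"], (span, ent["entity"])
print("all offsets consistent")
```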
### 2. Models
| Model | Paper | Notes |
|-------|-------|-------|
| [softmax](lightningnlp/task/named_entity_recognition/crf/model.py) | | Fully connected layer for sequence labeling, decoded with the `BIO` scheme |
| [crf](lightningnlp/task/named_entity_recognition/crf/model.py) | [Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data](https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers) | Fully connected layer + conditional random field, decoded with the `BIO` scheme |
| [cascade-crf](lightningnlp/task/named_entity_recognition/crf/model.py) | | Predicts entity spans first, then entity types |
| [span](lightningnlp/task/named_entity_recognition/span/model.py) | | Two pointer networks predict entity start and end positions |
| [global-pointer](lightningnlp/task/named_entity_recognition/global_pointer/model.py) | | [GlobalPointer:用统一的方式处理嵌套和非嵌套NER](https://spaces.ac.cn/archives/8373), [Efficient GlobalPointer:少点参数,多点效果](https://spaces.ac.cn/archives/8877) |
| [mrc](lightningnlp/task/named_entity_recognition/mrc/model.py) | [A Unified MRC Framework for Named Entity Recognition.](https://aclanthology.org/2020.acl-main.519.pdf) | Casts NER as machine reading comprehension: the input is an entity-type template plus the sentence, and the model predicts the start and end positions of the matching entities |
| [tplinker](lightningnlp/task/named_entity_recognition/tplinker/model.py) | [TPLinker: Single-stage Joint Extraction of Entities and Relations Through Token Pair Linking.](https://aclanthology.org/2020.coling-main.138.pdf) | Casts NER as a table-filling problem |
| [lear](lightningnlp/task/named_entity_recognition/lear/model.py) | [Enhanced Language Representation with Label Knowledge for Span Extraction.](https://aclanthology.org/2021.emnlp-main.379.pdf) | Addresses the efficiency issues of the `MRC` approach with a label-knowledge fusion mechanism |
| [w2ner](lightningnlp/task/named_entity_recognition/w2ner/model.py) | [Unified Named Entity Recognition as Word-Word Relation Classification.](https://arxiv.org/pdf/2112.10070.pdf) | A unified solution for extracting nested and discontinuous entities |
| [cnn](lightningnlp/task/named_entity_recognition/cnn/model.py) | [An Embarrassingly Easy but Strong Baseline for Nested Named Entity Recognition.](https://arxiv.org/abs/2208.04534) | Improves on `W2NER` by using convolutional networks to model the relations between tokens inside an entity |
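For the `softmax` and `crf` rows, `BIO` decoding turns a per-token tag sequence into entity spans. A minimal decoder sketch (end offsets exclusive):

```python
def decode_bio(tags):
    """Convert BIO tags into (start, end, label) spans; end is exclusive."""
    spans, start, label = [], None, None
    for i, tag in enumerate(tags):
        if tag.startswith("B-"):
            if start is not None:          # close the previous entity
                spans.append((start, i, label))
            start, label = i, tag[2:]
        elif tag.startswith("I-") and start is not None and label == tag[2:]:
            continue                        # still inside the current entity
        else:                               # "O" or an inconsistent "I-" tag
            if start is not None:
                spans.append((start, i, label))
            start, label = None, None
    if start is not None:                   # entity running to the end
        spans.append((start, len(tags), label))
    return spans

print(decode_bio(["B-ORG", "I-ORG", "O", "B-PER"]))
# [(0, 2, 'ORG'), (3, 4, 'PER')]
```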
<details>
<summary>Training code example</summary>
```python
import os
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from transformers import BertTokenizerFast
from lightningnlp.callbacks import LoggingCallback
from lightningnlp.task.named_entity_recognition import (
CRFNerDataModule,
NamedEntityRecognitionTransformer,
)
pl.seed_everything(seed=42)
pretrained_model_name_or_path = "hfl/chinese-roberta-wwm-ext"  # pretrained model
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
dm = CRFNerDataModule(
    tokenizer=tokenizer,
    train_batch_size=16,  # training batch size
    validation_batch_size=16,  # validation batch size
    num_workers=16,  # number of workers for data loading
    dataset_name="xusenlin/cmeee",  # huggingface dataset
    train_max_length=256,
    validation_max_length=256,
    cache_dir="datasets/cmeee",  # data cache path
    task_name="cmeee-bert-crf",  # custom task name
    is_chinese=True,
)
model = NamedEntityRecognitionTransformer(
    downstream_model_name="crf",  # model name
    downstream_model_type="bert",  # pretrained model type
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    tokenizer=tokenizer,
    labels=dm.label_list,
    learning_rate=2e-5,
    average="micro",
    output_dir="outputs/cmeee/crf",  # model save path
)
model_ckpt = pl.callbacks.ModelCheckpoint(
dirpath="outputs/cmeee/crf",
filename="best_model",
monitor="val_f1_micro",
save_top_k=1,
mode="max",
)
wandb_logger = WandbLogger(project="Named Entity Recognition", name="cmeee-bert-crf")
trainer = pl.Trainer(
logger=wandb_logger,
accelerator="gpu",
devices=1,
max_epochs=12,
val_check_interval=0.5,
gradient_clip_val=1.0,
callbacks=[model_ckpt, LoggingCallback()]
)
trainer.fit(model, dm)
```
</details>
### 3. Prediction
A trained model is available on [huggingface](https://huggingface.co/xusenlin/cmeee-global-pointer) as an example for testing and use; running the code below downloads it automatically and runs prediction.
```python
from pprint import pprint
from lightningnlp.task.named_entity_recognition import NerPipeline
pipeline = NerPipeline(model_name_or_path="xusenlin/cmeee-global-pointer", model_name="global-pointer", model_type="bert")
text = "结果上周六他们主场0:3惨败给了中游球队瓦拉多利德,近7个多月以来西甲首次输球。"
pprint(pipeline(text))
```
### 4. App Demo
![ner](./images/ner.png)
## 🔖 Relation Extraction
### 1. Data Format
<details>
<summary>Training data example</summary>
```json
{
"text": "查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部",
"spo_list": [
{
"predicate": "出生地",
"object_type": "地点",
"subject_type": "人物",
"object": "圣地亚哥",
"subject": "查尔斯·阿兰基斯"
},
{
"predicate": "出生日期",
"object_type": "Date",
"subject_type": "人物",
"object": "1989年4月17日",
"subject": "查尔斯·阿兰基斯"
}
]
}
```
</details>
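Each element of `spo_list` encodes one (subject, predicate, object) triple; flattening the sample annotation above into plain triples:

```python
spo_list = [
    {"predicate": "出生地", "object_type": "地点", "subject_type": "人物",
     "object": "圣地亚哥", "subject": "查尔斯·阿兰基斯"},
    {"predicate": "出生日期", "object_type": "Date", "subject_type": "人物",
     "object": "1989年4月17日", "subject": "查尔斯·阿兰基斯"},
]

# Flatten the annotation into plain (subject, predicate, object) triples.
triples = [(spo["subject"], spo["predicate"], spo["object"]) for spo in spo_list]
for t in triples:
    print(t)
```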
### 2. Models
| Model | Paper | Notes |
|-------|-------|-------|
| [casrel](lightningnlp/task/relation_extraction/casrel/model.py) | [A Novel Cascade Binary Tagging Framework for Relational Triple Extraction.](https://aclanthology.org/2020.acl-main.136.pdf) | Two-stage relation extraction: first extracts the subjects in a sentence, then uses pointer networks to extract each subject's relations and objects |
| [tplinker](lightningnlp/task/relation_extraction/tplinker/model.py) | [TPLinker: Single-stage Joint Extraction of Entities and Relations Through Token Pair Linking.](https://aclanthology.org/2020.coling-main.138.pdf) | Casts relation extraction as linking the head/tail tokens of subject-object pairs |
| [spn](lightningnlp/task/relation_extraction/spn/model.py) | [Joint Entity and Relation Extraction with Set Prediction Networks.](http://xxx.itp.ac.cn/pdf/2011.01675v2) | Casts relation extraction as set prediction over triples, using an `Encoder-Decoder` framework |
| [prgc](lightningnlp/task/relation_extraction/prgc/model.py) | [PRGC: Potential Relation and Global Correspondence Based Joint Relational Triple Extraction.](https://aclanthology.org/2021.acl-long.486.pdf) | First predicts the potential relation types in a sentence, then extracts the subject-object pairs for each relation, and finally filters them with a global correspondence module |
| [pfn](lightningnlp/task/relation_extraction/pfn/model.py) | [A Partition Filter Network for Joint Entity and Relation Extraction.](https://aclanthology.org/2021.emnlp-main.17.pdf) | Uses an `LSTM`-like partition filter mechanism that splits hidden states into entity-recognition, relation-recognition, and shared partitions, so each task draws on different information |
| [grte](lightningnlp/task/relation_extraction/grte/model.py) | [A Novel Global Feature-Oriented Relational Triple Extraction Model based on Table Filling.](https://aclanthology.org/2021.emnlp-main.208.pdf) | Casts relation extraction as word-pair classification, iteratively refining word-pair representations with a global feature extraction module |
| [gplinker](lightningnlp/task/relation_extraction/gplinker/model.py) | | [GPLinker:基于GlobalPointer的实体关系联合抽取](https://kexue.fm/archives/8888) |
<details>
<summary>Training code example</summary>
```python
import os
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from transformers import BertTokenizerFast
from lightningnlp.callbacks import LoggingCallback
from lightningnlp.task.relation_extraction import (
GPLinkerDataModule,
RelationExtractionTransformer,
)
pl.seed_everything(seed=42)
pretrained_model_name_or_path = "hfl/chinese-roberta-wwm-ext"  # pretrained model
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
dm = GPLinkerDataModule(
    tokenizer=tokenizer,
    train_batch_size=16,  # training batch size
    validation_batch_size=16,  # validation batch size
    num_workers=16,  # number of workers for data loading
    dataset_name="xusenlin/duie",  # huggingface dataset
    train_max_length=256,
    validation_max_length=256,
    cache_dir="datasets/duie",  # data cache path
    task_name="duie-bert-gplinker",  # custom task name
    is_chinese=True,
)
model = RelationExtractionTransformer(
    downstream_model_name="gplinker",  # model name
    downstream_model_type="bert",  # pretrained model type
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    tokenizer=tokenizer,
    predicates=dm.predicate_list,
    learning_rate=2e-5,
    average="micro",
    output_dir="outputs/duie/gplinker",  # model save path
)
model_ckpt = pl.callbacks.ModelCheckpoint(
dirpath="outputs/duie/gplinker",
filename="best_model",
monitor="val_f1_micro",
save_top_k=1,
mode="max",
)
wandb_logger = WandbLogger(project="Relation Extraction", name="duie-bert-gplinker")
trainer = pl.Trainer(
logger=wandb_logger,
accelerator="gpu",
devices=1,
max_epochs=12,
val_check_interval=0.5,
gradient_clip_val=1.0,
callbacks=[model_ckpt, LoggingCallback()]
)
trainer.fit(model, dm)
```
</details>
### 3. Prediction
A trained model is available on [huggingface](https://huggingface.co/xusenlin/duie-gplinker) as an example for testing and use; running the code below downloads it automatically and runs prediction.
```python
from pprint import pprint
from lightningnlp.task.relation_extraction import RelationExtractionPipeline
pipeline = RelationExtractionPipeline(model_name_or_path="xusenlin/duie-gplinker", model_name="gplinker", model_type="bert")
text = "查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部。"
pprint(pipeline(text))
```
### 4. App Demo
![re](./images/re.png)
## 🍭 Universal Information Extraction
+ [UIE (Universal Information Extraction)](https://arxiv.org/pdf/2203.12277.pdf): Yaojie Lu et al. proposed the unified information extraction framework `UIE` at ACL 2022.
+ The framework models entity extraction, relation extraction, event extraction, sentiment analysis, and other tasks in a unified way, giving the tasks strong transfer and generalization across each other.
+ To make the power of `UIE` easy to use, [PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP) followed the paper's approach and trained and open-sourced the first Chinese universal information extraction model, `UIE`, based on the knowledge-enhanced pretrained model `ERNIE 3.0`.
+ The model supports key-information extraction without restrictions on domain or extraction targets, enables zero-shot cold starts, and fine-tunes well on few samples to quickly adapt to specific extraction targets.
![uie](./images/uie.png)
<details>
<summary>👉 Named Entity Recognition</summary>
```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor
# Named entity recognition
schema = ['时间', '选手', '赛事名称']
# The uie-base model is hosted on huggingface and downloads automatically; other models only need the model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!")) # Better print results using pprint
```
Output:
```text
[{'时间': [{'end': 6,
'probability': 0.9857378532924486,
'start': 0,
'text': '2月8日上午'}],
'赛事名称': [{'end': 23,
'probability': 0.8503089953268272,
'start': 6,
'text': '北京冬奥会自由式滑雪女子大跳台决赛'}],
'选手': [{'end': 31,
'probability': 0.8981548639781138,
'start': 28,
'text': '谷爱凌'}]}]
```
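Every extracted span carries a `probability`, so downstream code can keep only high-confidence spans. A small filter over the result structure above (the threshold is an arbitrary choice for illustration; probabilities abbreviated):

```python
result = [{"时间": [{"end": 6, "probability": 0.9857, "start": 0,
                   "text": "2月8日上午"}],
           "赛事名称": [{"end": 23, "probability": 0.8503, "start": 6,
                     "text": "北京冬奥会自由式滑雪女子大跳台决赛"}],
           "选手": [{"end": 31, "probability": 0.8982, "start": 28,
                   "text": "谷爱凌"}]}]

def keep_confident(result, threshold=0.9):
    """Drop spans whose probability falls below the threshold."""
    filtered = []
    for record in result:
        kept = {key: [s for s in spans if s["probability"] >= threshold]
                for key, spans in record.items()}
        # Remove schema keys whose span list became empty.
        filtered.append({k: v for k, v in kept.items() if v})
    return filtered

print(sorted(keep_confident(result, threshold=0.88)[0]))  # ['时间', '选手']
```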
</details>
<details>
<summary>👉 Relation Extraction</summary>
```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor
# Relation extraction
schema = {'竞赛名称': ['主办方', '承办方', '已举办次数']}
# The uie-base model is hosted on huggingface and downloads automatically; other models only need the model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。")) # Better print results using pprint
```
Output:
```text
[{'竞赛名称': [{'end': 13,
'probability': 0.7825402622754041,
'relations': {'主办方': [{'end': 22,
'probability': 0.8421710521379353,
'start': 14,
'text': '中国中文信息学会'},
{'end': 30,
'probability': 0.7580801847701935,
'start': 23,
'text': '中国计算机学会'}],
'已举办次数': [{'end': 82,
'probability': 0.4671295049136148,
'start': 80,
'text': '4届'}],
'承办方': [{'end': 39,
'probability': 0.8292706618236352,
'start': 35,
'text': '百度公司'},
{'end': 72,
'probability': 0.6193477885474685,
'start': 56,
'text': '中国计算机学会自然语言处理专委会'},
{'end': 55,
'probability': 0.7000497331473241,
'start': 40,
'text': '中国中文信息学会评测工作委员会'}]},
'start': 0,
'text': '2022语言与智能技术竞赛'}]}]
```
</details>
<details>
<summary>👉 Event Extraction</summary>
```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor
# Event extraction
schema = {"地震触发词": ["地震强度", "时间", "震中位置", "震源深度"]}
# The uie-base model is hosted on huggingface and downloads automatically; other models only need the model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24.34度,东经99.98度)发生3.5级地震,震源深度10千米。")) # Better print results using pprint
```
Output:
```text
[{'地震触发词': {'end': 58,
'probability': 0.9977425932884216,
'relation': {'地震强度': [{'end': 56,
'probability': 0.9980800747871399,
'start': 52,
'text': '3.5级'}],
'时间': [{'end': 22,
'probability': 0.9853301644325256,
'start': 11,
'text': '5月16日06时08分'}],
'震中位置': [{'end': 50,
'probability': 0.7874020934104919,
'start': 23,
'text': '云南临沧市凤庆县(北纬24.34度,东经99.98度)'}],
'震源深度': [{'end': 67,
'probability': 0.9937973618507385,
'start': 63,
'text': '10千米'}]},
'start': 56,
'text': '地震'}}]
```
</details>
<details>
<summary>👉 Opinion Extraction</summary>
```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor
# Opinion extraction
schema = {'评价维度': ['观点词', '情感倾向[正向,负向]']}
# The uie-base model is hosted on huggingface and downloads automatically; other models only need the model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队")) # Better print results using pprint
```
Output:
```text
[{'评价维度': [{'end': 20,
'probability': 0.9817040258681473,
'relations': {'情感倾向[正向,负向]': [{'probability': 0.9966142505350533,
'text': '正向'}],
'观点词': [{'end': 22,
'probability': 0.957396472711558,
'start': 21,
'text': '高'}]},
'start': 17,
'text': '性价比'},
{'end': 2,
'probability': 0.9696849569741168,
'relations': {'情感倾向[正向,负向]': [{'probability': 0.9982153274927796,
'text': '正向'}],
'观点词': [{'end': 4,
'probability': 0.9945318044652538,
'start': 2,
'text': '干净'}]},
'start': 0,
'text': '店面'}]}]
```
</details>
<details>
<summary>👉 Sentiment Classification</summary>
```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor
# Sentiment classification
schema = '情感倾向[正向,负向]'
# The uie-base model is hosted on huggingface and downloads automatically; other models only need the model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("这个产品用起来真的很流畅,我非常喜欢")) # Better print results using pprint
```
Output:
```text
[{'情感倾向[正向,负向]': {'end': 0,
'probability': 0.9990023970603943,
'start': 0,
'text': '正向'}}]
```
</details>
## Citation
If `LightningNLP` helps your research, please consider citing it:
```text
@misc{lightningnlp,
title={LightningNLP: An Easy-to-use NLP Library},
author={senlin xu},
howpublished = {\url{https://github.com/xusenlinzy/lightningblocks}},
year={2022}
}
```
## Acknowledgements
This project borrows the excellent model-usage design of [`Lightning-transformers`](https://github.com/Lightning-AI/lightning-transformers); many thanks to its authors and open-source community.
Raw data
{
"_id": null,
"home_page": "https://github.com/xusenlinzy/lightningblocks",
"name": "lightningnlp",
"maintainer": "",
"docs_url": null,
"requires_python": ">=3.7",
"maintainer_email": "",
"keywords": "",
"author": "xusenlin",
"author_email": "1659821119@qq.com",
"download_url": "https://files.pythonhosted.org/packages/ad/ae/4cbae6316dcb022c204245c1e3e50f321ec575b943532a9bf8bb9c7224d6/lightningnlp-1.0.2.tar.gz",
"platform": null,
"description": "<div align=\"center\">\r\n\r\n**\u57fa\u4e8e [`pytorch-lightning`](https://github.com/Lightning-AI/lightning) \u548c [`transformers`](https://github.com/huggingface/transformers) \u6846\u67b6\u5b9e\u73b0\u5404\u7c7b `NLP` \u4efb\u52a1**\r\n\r\n</div>\r\n\r\n## \ud83d\udd28 \u5b89\u88c5\r\n\r\n1. \u642d\u5efa\u597d `pytorch GPU` \u6df1\u5ea6\u5b66\u4e60\u73af\u5883\r\n\r\n```bash\r\nconda create -n pytorch python=3.8\r\nconda activate pytorch\r\nconda install pytorch cudatoolkit -c pytorch\r\n```\r\n\r\n2. \u5b89\u88c5 `lightningnlp`\r\n\r\n```bash\r\npip install lightningnlp\r\n```\r\n\r\n3. \u5728 `https://pytorch-geometric.com/whl/` \u4e2d\u627e\u5230\u4e0e `torch` \u7248\u672c\u5bf9\u5e94\u7684 `torch_scatter`\uff0c\u4e0b\u8f7d\u540e\u4f7f\u7528 `pip` \u5b89\u88c5\u5230\u73af\u5883\u4e2d \r\n\r\n```python\r\nimport torch\r\nprint(torch.__version__) # 1.12.0\r\nprint(torch.version.cuda) # 11.3\r\n```\r\n\r\n```bash\r\n# \u4ee5python=3.8, torch=1.12.0, cuda=11.3\u4e3a\u4f8b\r\nwget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp38-cp38-linux_x86_64.whl\r\npip install torch_scatter-2.1.0+pt112cu113-cp38-cp38-linux_x86_64.whl\r\n```\r\n\r\n\u672c\u9879\u76ee\u4e5f\u63d0\u4f9b\u4e86[docker\u5b89\u88c5\u65b9\u5f0f](./docker)\r\n\r\n## \ud83e\uddfe \u6587\u672c\u5206\u7c7b\r\n\r\n### 1. \u6570\u636e\u683c\u5f0f\r\n\r\n<details>\r\n<summary>\u8bad\u7ec3\u6570\u636e\u793a\u4f8b</summary>\r\n\r\n```json\r\n{\r\n \"text\": \"\u4ee5\u8272\u5217\u5927\u89c4\u6a21\u7a7a\u88ad\u5f00\u59cb\uff01\u4f0a\u6717\u591a\u4e2a\u519b\u4e8b\u76ee\u6807\u906d\u9047\u6253\u51fb\uff0c\u8a93\u8a00\u5bf9\u7b49\u53cd\u51fb\",\r\n \"label\": \"news_military\"\r\n}\r\n```\r\n\r\n</details>\r\n\r\n### 2. 
\u6a21\u578b\r\n\r\n| \u6a21\u578b | \u8bba\u6587 | \u5907\u6ce8 |\r\n|-----------------------------------------------------------|--------------------------------------------------------------------------------------------------------------|---------------------------------|\r\n| [fc](lightningnlp/task/text_classification/fc/model.py) | | \u5168\u8fde\u63a5\u5c42\u5206\u7c7b |\r\n| [mdp](lightningnlp/task/text_classification/mdp/model.py) | [Multi-Sample Dropout for Accelerated Training and Better Generalization.](https://arxiv.org/abs/1905.09788) | \u4f7f\u7528 `MultiSampleDropout`\uff0c\u7c7b\u4f3c\u4e8e\u6a21\u578b\u878d\u5408 |\r\n\r\n<details>\r\n<summary>\u8bad\u7ec3\u4ee3\u7801\u793a\u4f8b</summary>\r\n\r\n```python\r\nimport os\r\nos.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'\r\n\r\nimport pytorch_lightning as pl\r\nfrom pytorch_lightning.loggers import WandbLogger # \u9700\u8981\u5b89\u88c5wandb\r\nfrom transformers import BertTokenizerFast\r\n\r\nfrom lightningnlp.callbacks import LoggingCallback\r\nfrom lightningnlp.task.text_classification import (\r\n TextClassificationDataModule,\r\n TextClassificationTransformer,\r\n)\r\n\r\npl.seed_everything(seed=42)\r\npretrained_model_name_or_path = \"hfl/chinese-roberta-wwm-ext\" # \u9884\u8bad\u7ec3\u6a21\u578b\r\ntokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)\r\n\r\ndm = TextClassificationDataModule(\r\n tokenizer=tokenizer,\r\n train_batch_size=16, # \u8bad\u7ec3\u96c6batch_size\r\n validation_batch_size=16, # \u9a8c\u8bc1\u96c6batch_size\r\n num_workers=16, # \u591a\u8fdb\u7a0b\u52a0\u8f7d\u6570\u636e\r\n dataset_name=\"datasets/sentiment\", # \u8bad\u7ec3\u6570\u636e\u6240\u5728\u76ee\u5f55\r\n train_file=\"train.json\", # \u8bad\u7ec3\u96c6\u6587\u4ef6\u540d\r\n validation_file=\"dev.json\", # \u9a8c\u8bc1\u96c6\u6587\u4ef6\u540d\r\n train_max_length=256,\r\n cache_dir=\"datasets/sentiment\", # \u6570\u636e\u7f13\u5b58\u8def\u5f84\r\n)\r\n\r\nmodel = 
TextClassificationTransformer(\r\n downstream_model_name=\"fc\", # \u6a21\u578b\u540d\u79f0\r\n downstream_model_type=\"bert\", # \u9884\u8bad\u7ec3\u6a21\u578b\u7c7b\u578b\r\n pretrained_model_name_or_path=pretrained_model_name_or_path,\r\n tokenizer=tokenizer,\r\n label_map=dm.label_map,\r\n learning_rate=2e-5,\r\n output_dir=\"outputs/sentiment/fc\", # \u6a21\u578b\u4fdd\u5b58\u8def\u5f84\r\n)\r\n\r\nmodel_ckpt = pl.callbacks.ModelCheckpoint(\r\n dirpath=\"outputs/sentiment/fc\",\r\n filename=\"best_model\",\r\n monitor=\"val_accuracy\",\r\n save_top_k=1,\r\n mode=\"max\",\r\n)\r\n\r\nwandb_logger = WandbLogger(project=\"Text Classification\", name=\"fc\")\r\n\r\ntrainer = pl.Trainer(\r\n logger=wandb_logger,\r\n accelerator=\"gpu\",\r\n devices=1,\r\n max_epochs=12,\r\n val_check_interval=0.5,\r\n gradient_clip_val=1.0,\r\n callbacks=[model_ckpt, LoggingCallback()]\r\n)\r\n\r\ntrainer.fit(model, dm)\r\n```\r\n\r\n</details>\r\n\r\n### 3. \u9884\u6d4b\r\n\r\n```python\r\nfrom lightningnlp.task.text_classification import TextClassificationTransformer\r\n\r\nmodel = TextClassificationTransformer.load_from_checkpoint(\"my_bert_model_path\")\r\ntext = \"\u4ee5\u8272\u5217\u5927\u89c4\u6a21\u7a7a\u88ad\u5f00\u59cb\uff01\u4f0a\u6717\u591a\u4e2a\u519b\u4e8b\u76ee\u6807\u906d\u9047\u6253\u51fb\uff0c\u8a93\u8a00\u5bf9\u7b49\u53cd\u51fb\"\r\nprint(model.predict(text))\r\n```\r\n\r\n\r\n## \ud83d\udcc4 \u547d\u540d\u5b9e\u4f53\u8bc6\u522b\r\n\r\n### 1. 
\u6570\u636e\u683c\u5f0f\r\n\r\n<details>\r\n<summary>\u8bad\u7ec3\u6570\u636e\u793a\u4f8b</summary>\r\n\r\n```json\r\n{\r\n \"text\": \"\u7ed3\u679c\u4e0a\u5468\u516d\u4ed6\u4eec\u4e3b\u573a0\uff1a3\u60e8\u8d25\u7ed9\u4e86\u4e2d\u6e38\u7403\u961f\u74e6\u62c9\u591a\u5229\u5fb7\uff0c\u8fd17\u4e2a\u591a\u6708\u4ee5\u6765\u897f\u7532\u9996\u6b21\u8f93\u7403\u3002\", \r\n \"entities\": [\r\n {\r\n \"id\": 0, \r\n \"entity\": \"\u74e6\u62c9\u591a\u5229\u5fb7\", \r\n \"start_offset\": 20, \r\n \"end_offset\": 25, \r\n \"label\": \"organization\"\r\n }, \r\n {\r\n \"id\": 1, \r\n \"entity\": \"\u897f\u7532\", \r\n \"start_offset\": 33, \r\n \"end_offset\": 35, \r\n \"label\": \"organization\"\r\n }\r\n ]\r\n}\r\n```\r\n</details>\r\n\r\n\r\n### 2. \u6a21\u578b\r\n\r\n\r\n| \u6a21\u578b | \u8bba\u6587 | \u5907\u6ce8 |\r\n|--------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------|\r\n| [softmax](lightningnlp/task/named_entity_recognition/crf/model.py) | | \u5168\u8fde\u63a5\u5c42\u5e8f\u5217\u6807\u6ce8\u5e76\u4f7f\u7528 `BIO` \u89e3\u7801 |\r\n| [crf](lightningnlp/task/named_entity_recognition/crf/model.py) | [Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data](https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers) | \u5168\u8fde\u63a5\u5c42+\u6761\u4ef6\u968f\u673a\u573a\uff0c\u5e76\u4f7f\u7528 `BIO` \u89e3\u7801 |\r\n| [cascade-crf](lightningnlp/task/named_entity_recognition/crf/model.py) | | \u5148\u9884\u6d4b\u5b9e\u4f53\u518d\u9884\u6d4b\u5b9e\u4f53\u7c7b\u578b |\r\n| [span](lightningnlp/task/named_entity_recognition/span/model.py) | | 
\u4f7f\u7528\u4e24\u4e2a\u6307\u9488\u7f51\u7edc\u9884\u6d4b\u5b9e\u4f53\u8d77\u59cb\u4f4d\u7f6e |\r\n| [global-pointer](lightningnlp/task/named_entity_recognition/global_pointer/model.py) | | [GlobalPointer\uff1a\u7528\u7edf\u4e00\u7684\u65b9\u5f0f\u5904\u7406\u5d4c\u5957\u548c\u975e\u5d4c\u5957NER](https://spaces.ac.cn/archives/8373)\u3001[Efficient GlobalPointer\uff1a\u5c11\u70b9\u53c2\u6570\uff0c\u591a\u70b9\u6548\u679c](https://spaces.ac.cn/archives/8877) |\r\n| [mrc](lightningnlp/task/named_entity_recognition/mrc/model.py) | [A Unified MRC Framework for Named Entity Recognition.](https://aclanthology.org/2020.acl-main.519.pdf) | \u5c06\u5b9e\u4f53\u8bc6\u522b\u4efb\u52a1\u8f6c\u6362\u4e3a\u9605\u8bfb\u7406\u89e3\u95ee\u9898\uff0c\u8f93\u5165\u4e3a\u5b9e\u4f53\u7c7b\u578b\u6a21\u677f+\u53e5\u5b50\uff0c\u9884\u6d4b\u5bf9\u5e94\u5b9e\u4f53\u7684\u8d77\u59cb\u4f4d\u7f6e |\r\n| [tplinker](lightningnlp/task/named_entity_recognition/tplinker/model.py) | [TPLinker: Single-stage Joint Extraction of Entities and Relations Through Token Pair Linking.](https://aclanthology.org/2020.coling-main.138.pdf) | \u5c06\u5b9e\u4f53\u8bc6\u522b\u4efb\u52a1\u8f6c\u6362\u4e3a\u8868\u683c\u586b\u5145\u95ee\u9898 |\r\n| [lear](lightningnlp/task/named_entity_recognition/lear/model.py) | [Enhanced Language Representation with Label Knowledge for Span Extraction.](https://aclanthology.org/2021.emnlp-main.379.pdf) | \u6539\u8fdb `MRC` \u65b9\u6cd5\u6548\u7387\u95ee\u9898\uff0c\u91c7\u7528\u6807\u7b7e\u878d\u5408\u673a\u5236 |\r\n| [w2ner](lightningnlp/task/named_entity_recognition/w2ner/model.py) | [Unified Named Entity Recognition as Word-Word Relation Classification.](https://arxiv.org/pdf/2112.10070.pdf) | \u7edf\u4e00\u89e3\u51b3\u5d4c\u5957\u5b9e\u4f53\u3001\u4e0d\u8fde\u7eed\u5b9e\u4f53\u7684\u62bd\u53d6\u95ee\u9898 |\r\n| [cnn](lightningnlp/task/named_entity_recognition/cnn/model.py) | [An Embarrassingly Easy but Strong Baseline for Nested Named Entity 
Recognition.](https://arxiv.org/abs/2208.04534) | \u6539\u8fdb `W2NER` \u65b9\u6cd5\uff0c\u91c7\u7528\u5377\u79ef\u7f51\u7edc\u63d0\u53d6\u5b9e\u4f53\u5185\u90e8token\u4e4b\u95f4\u7684\u5173\u7cfb |\r\n\r\n<details>\r\n<summary>\u8bad\u7ec3\u4ee3\u7801\u793a\u4f8b</summary>\r\n\r\n```python\r\nimport os\r\nos.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'\r\n\r\nimport pytorch_lightning as pl\r\nfrom pytorch_lightning.loggers import WandbLogger\r\nfrom transformers import BertTokenizerFast\r\n\r\nfrom lightningnlp.callbacks import LoggingCallback\r\nfrom lightningnlp.task.named_entity_recognition import (\r\n CRFNerDataModule,\r\n NamedEntityRecognitionTransformer,\r\n)\r\n\r\npl.seed_everything(seed=42)\r\npretrained_model_name_or_path = \"hfl/chinese-roberta-wwm-ext\" # \u9884\u8bad\u7ec3\u6a21\u578b\r\ntokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)\r\n\r\ndm = CRFNerDataModule(\r\n tokenizer=tokenizer,\r\n train_batch_size=16, # \u8bad\u7ec3\u96c6batch_size\r\n validation_batch_size=16, # \u9a8c\u8bc1\u96c6batch_size\r\n num_workers=16, # \u591a\u8fdb\u7a0b\u52a0\u8f7d\u6570\u636e\r\n dataset_name=\"xusenlin/cmeee\", # huggingface\u6570\u636e\u96c6\r\n train_max_length=256,\r\n validation_max_length=256,\r\n cache_dir=\"datasets/cmeee\", # \u6570\u636e\u7f13\u5b58\u8def\u5f84\r\n task_name=\"cmeee-bert-crf\", # \u81ea\u5b9a\u4e49\u4efb\u52a1\u540d\u79f0\r\n is_chinese=True,\r\n)\r\n\r\nmodel = NamedEntityRecognitionTransformer(\r\n downstream_model_name=\"crf\", # \u6a21\u578b\u540d\u79f0\r\n downstream_model_type=\"bert\", # \u9884\u8bad\u7ec3\u6a21\u578b\u7c7b\u578b\r\n pretrained_model_name_or_path=pretrained_model_name_or_path,\r\n tokenizer=tokenizer,\r\n labels=dm.label_list,\r\n learning_rate=2e-5,\r\n average=\"micro\",\r\n output_dir=\"outputs/cmeee/crf\", # \u6a21\u578b\u4fdd\u5b58\u8def\u5f84\r\n)\r\n\r\nmodel_ckpt = pl.callbacks.ModelCheckpoint(\r\n dirpath=\"outputs/cmeee/crf\",\r\n filename=\"best_model\",\r\n 
monitor=\"val_f1_micro\",\r\n save_top_k=1,\r\n mode=\"max\",\r\n)\r\n\r\nwandb_logger = WandbLogger(project=\"Named Entity Recognition\", name=\"cmeee-bert-crf\")\r\n\r\ntrainer = pl.Trainer(\r\n logger=wandb_logger,\r\n accelerator=\"gpu\",\r\n devices=1,\r\n max_epochs=12,\r\n val_check_interval=0.5,\r\n gradient_clip_val=1.0,\r\n callbacks=[model_ckpt, LoggingCallback()]\r\n)\r\n\r\ntrainer.fit(model, dm)\r\n```\r\n\r\n</details>\r\n\r\n### 3. \u9884\u6d4b\r\n\r\n\u672c\u9879\u76ee\u5728 [huggingface](https://huggingface.co/xusenlin/cmeee-global-pointer) \u4e0a\u63d0\u4f9b\u4e86\u4e00\u4e2a\u8bad\u7ec3\u597d\u7684\u6a21\u578b\u4f5c\u4e3a\u793a\u4f8b\u53ef\u4f9b\u6d4b\u8bd5\u548c\u4f7f\u7528\uff0c\u8fd0\u884c\u4ee5\u4e0b\u4ee3\u7801\u4f1a\u81ea\u52a8\u4e0b\u8f7d\u6a21\u578b\u5e76\u8fdb\u884c\u9884\u6d4b\r\n\r\n```python\r\nfrom pprint import pprint\r\nfrom lightningnlp.task.named_entity_recognition import NerPipeline\r\n\r\npipline = NerPipeline(model_name_or_path=\"xusenlin/cmeee-global-pointer\", model_name=\"global-pointer\", model_type=\"bert\")\r\ntext = \"\u7ed3\u679c\u4e0a\u5468\u516d\u4ed6\u4eec\u4e3b\u573a0\uff1a3\u60e8\u8d25\u7ed9\u4e86\u4e2d\u6e38\u7403\u961f\u74e6\u62c9\u591a\u5229\u5fb7\uff0c\u8fd17\u4e2a\u591a\u6708\u4ee5\u6765\u897f\u7532\u9996\u6b21\u8f93\u7403\u3002\"\r\npprint(pipline(text))\r\n```\r\n\r\n### 4. APP\u5e94\u7528\r\n\r\n![ner](./images/ner.png)\r\n\r\n\r\n## \ud83d\udd16 \u5b9e\u4f53\u5173\u7cfb\u62bd\u53d6\r\n\r\n### 1. 
\u6570\u636e\u683c\u5f0f\r\n\r\n<details>\r\n<summary>\u8bad\u7ec3\u6570\u636e\u793a\u4f8b</summary>\r\n\r\n```json\r\n{\r\n \"text\": \"\u67e5\u5c14\u65af\u00b7\u963f\u5170\u57fa\u65af\uff08Charles Ar\u00e1nguiz\uff09\uff0c1989\u5e744\u670817\u65e5\u51fa\u751f\u4e8e\u667a\u5229\u5723\u5730\u4e9a\u54e5\uff0c\u667a\u5229\u804c\u4e1a\u8db3\u7403\u8fd0\u52a8\u5458\uff0c\u53f8\u804c\u4e2d\u573a\uff0c\u6548\u529b\u4e8e\u5fb7\u56fd\u8db3\u7403\u7532\u7ea7\u8054\u8d5b\u52d2\u6c83\u5e93\u68ee\u8db3\u7403\u4ff1\u4e50\u90e8\", \r\n \"spo_list\": [\r\n {\r\n \"predicate\": \"\u51fa\u751f\u5730\", \r\n \"object_type\": \"\u5730\u70b9\", \r\n \"subject_type\": \"\u4eba\u7269\", \r\n \"object\": \"\u5723\u5730\u4e9a\u54e5\", \r\n \"subject\": \"\u67e5\u5c14\u65af\u00b7\u963f\u5170\u57fa\u65af\"\r\n }, \r\n {\r\n \"predicate\": \"\u51fa\u751f\u65e5\u671f\", \r\n \"object_type\": \"Date\", \r\n \"subject_type\": \"\u4eba\u7269\", \r\n \"object\": \"1989\u5e744\u670817\u65e5\",\r\n \"subject\": \"\u67e5\u5c14\u65af\u00b7\u963f\u5170\u57fa\u65af\"\r\n }\r\n ]\r\n}\r\n```\r\n\r\n</details>\r\n\r\n\r\n### 2. 
### 2. Models

| Model | Paper | Notes |
|-------|-------|-------|
| [casrel](lightningnlp/task/relation_extraction/casrel/model.py) | [A Novel Cascade Binary Tagging Framework for Relational Triple Extraction.](https://aclanthology.org/2020.acl-main.136.pdf) | Two-stage relation extraction: first extracts the subjects in a sentence, then uses pointer networks to extract each subject's relations and objects |
| [tplinker](lightningnlp/task/relation_extraction/tplinker/model.py) | [TPLinker: Single-stage Joint Extraction of Entities and Relations Through Token Pair Linking.](https://aclanthology.org/2020.coling-main.138.pdf) | Casts relation extraction as a subject-object head-to-tail token linking problem |
| [spn](lightningnlp/task/relation_extraction/spn/model.py) | [Joint Entity and Relation Extraction with Set Prediction Networks.](http://xxx.itp.ac.cn/pdf/2011.01675v2) | Casts relation extraction as set prediction over triples, using an `Encoder-Decoder` framework |
| [prgc](lightningnlp/task/relation_extraction/prgc/model.py) | [PRGC: Potential Relation and Global Correspondence Based Joint Relational Triple Extraction.](https://aclanthology.org/2021.acl-long.486.pdf) | First predicts a sentence's potential relation types, then extracts the subject-object pairs for each given relation, and finally filters them with a global correspondence module |
| [pfn](lightningnlp/task/relation_extraction/pfn/model.py) | [A Partition Filter Network for Joint Entity and Relation Extraction.](https://aclanthology.org/2021.emnlp-main.17.pdf) | Uses an `LSTM`-like partition filter mechanism that splits the hidden states into entity-recognition, relation-recognition, and shared parts, so each task uses only the relevant information |
| [grte](lightningnlp/task/relation_extraction/grte/model.py) | [A Novel Global Feature-Oriented Relational Triple Extraction Model based on Table Filling.](https://aclanthology.org/2021.emnlp-main.208.pdf) | Casts relation extraction as word-pair classification, iteratively refining the word-pair representations with a global feature extraction module |
| [gplinker](lightningnlp/task/relation_extraction/gplinker/model.py) | | [GPLinker: joint entity-relation extraction based on GlobalPointer](https://kexue.fm/archives/8888) |


<details>
<summary>Training code example</summary>

```python
import os
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger  # requires wandb
from transformers import BertTokenizerFast

from lightningnlp.callbacks import LoggingCallback
from lightningnlp.task.relation_extraction import (
    GPLinkerDataModule,
    RelationExtractionTransformer,
)

pl.seed_everything(seed=42)
pretrained_model_name_or_path = "hfl/chinese-roberta-wwm-ext"  # pretrained model
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)

dm = GPLinkerDataModule(
    tokenizer=tokenizer,
    train_batch_size=16,  # training batch size
    validation_batch_size=16,  # validation batch size
    num_workers=16,  # workers for multiprocess data loading
    dataset_name="xusenlin/duie",  # huggingface dataset
    train_max_length=256,
    validation_max_length=256,
    cache_dir="datasets/duie",  # dataset cache directory
    task_name="duie-bert-gplinker",  # custom task name
    is_chinese=True,
)

model = RelationExtractionTransformer(
    downstream_model_name="gplinker",  # model name
    downstream_model_type="bert",  # pretrained model type
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    tokenizer=tokenizer,
    predicates=dm.predicate_list,
    learning_rate=2e-5,
    average="micro",
    output_dir="outputs/duie/gplinker",  # model save directory
)

model_ckpt = pl.callbacks.ModelCheckpoint(
    dirpath="outputs/duie/gplinker",
    filename="best_model",
    monitor="val_f1_micro",
    save_top_k=1,
    mode="max",
)

wandb_logger = WandbLogger(project="Relation Extraction", name="duie-bert-gplinker")

trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator="gpu",
    devices=1,
    max_epochs=12,
    val_check_interval=0.5,
    gradient_clip_val=1.0,
    callbacks=[model_ckpt, LoggingCallback()],
)

trainer.fit(model, dm)
```

</details>
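The `val_f1_micro` metric monitored by the checkpoint callback is micro-averaged F1 over extracted triples: the correct, predicted, and gold counts are pooled across the whole validation set before precision and recall are computed. A sketch of that computation (illustrative only, not the library's internal code):

```python
def micro_f1(pred_triples, gold_triples):
    """Micro-averaged F1 over per-sentence sets of (subject, predicate, object) triples."""
    correct = predicted = gold = 0
    for pred, true in zip(pred_triples, gold_triples):
        correct += len(pred & true)  # triples present in both prediction and gold
        predicted += len(pred)
        gold += len(true)
    precision = correct / predicted if predicted else 0.0
    recall = correct / gold if gold else 0.0
    if precision + recall == 0.0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# two validation sentences; the second prediction contains one spurious triple
pred = [{("查尔斯·阿兰基斯", "出生地", "圣地亚哥")},
        {("查尔斯·阿兰基斯", "出生日期", "1989年4月17日"), ("查尔斯·阿兰基斯", "效力球队", "勒沃库森")}]
gold = [{("查尔斯·阿兰基斯", "出生地", "圣地亚哥")},
        {("查尔斯·阿兰基斯", "出生日期", "1989年4月17日")}]
score = micro_f1(pred, gold)  # precision = 2/3, recall = 1.0, so F1 = 0.8
```

Because counts are pooled globally, sentences with many triples weigh more than sparse ones, which is the usual convention for triple-extraction benchmarks such as DuIE.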
### 3. Prediction

This project provides a trained model on [huggingface](https://huggingface.co/xusenlin/duie-gplinker) as an example for testing and use; running the code below downloads the model automatically and runs prediction.

```python
from pprint import pprint
from lightningnlp.task.relation_extraction import RelationExtractionPipeline

pipeline = RelationExtractionPipeline(model_name_or_path="xusenlin/duie-gplinker", model_name="gplinker", model_type="bert")
text = "查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部。"
pprint(pipeline(text))
```
### 4. App

![re](./images/re.png)


## 🍭 Universal Information Extraction

+ [UIE (Universal Information Extraction)](https://arxiv.org/pdf/2203.12277.pdf): Yaojie Lu et al. proposed `UIE`, a unified framework for universal information extraction, at ACL 2022.

+ The framework unifies entity extraction, relation extraction, event extraction, sentiment analysis, and other tasks under a single model, giving good transfer and generalization across tasks.

+ To make the power of `UIE` easy to use, [PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP) followed the paper's method and, building on the knowledge-enhanced pretrained model `ERNIE 3.0`, trained and open-sourced the first Chinese universal information-extraction model, `UIE`.

+ The model supports key-information extraction without restricting the industry domain or extraction targets, enables zero-shot cold starts, and offers strong few-shot fine-tuning for quickly adapting to specific extraction targets.

![uie](./images/uie.png)

<details>
<summary>👉 Named Entity Recognition</summary>

```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor

# named entity recognition
schema = ['时间', '选手', '赛事名称']
# the uie-base model is on huggingface and downloads automatically; other models only need a model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!"))  # better print results using pprint
```
Output:
```text
[{'时间': [{'end': 6,
          'probability': 0.9857378532924486,
          'start': 0,
          'text': '2月8日上午'}],
  '赛事名称': [{'end': 23,
            'probability': 0.8503089953268272,
            'start': 6,
            'text': '北京冬奥会自由式滑雪女子大跳台决赛'}],
  '选手': [{'end': 31,
          'probability': 0.8981548639781138,
          'start': 28,
          'text': '谷爱凌'}]}]
```
</details>

<details>
<summary>👉 Relation Extraction</summary>

```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor

# relation extraction
schema = {'竞赛名称': ['主办方', '承办方', '已举办次数']}
# the uie-base model is on huggingface and downloads automatically; other models only need a model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。"))  # better print results using pprint
```
Output:
```text
[{'竞赛名称': [{'end': 13,
            'probability': 0.7825402622754041,
            'relations': {'主办方': [{'end': 22,
                                   'probability': 0.8421710521379353,
                                   'start': 14,
                                   'text': '中国中文信息学会'},
                                  {'end': 30,
                                   'probability': 0.7580801847701935,
                                   'start': 23,
                                   'text': '中国计算机学会'}],
                          '已举办次数': [{'end': 82,
                                     'probability': 0.4671295049136148,
                                     'start': 80,
                                     'text': '4届'}],
                          '承办方': [{'end': 39,
                                   'probability': 0.8292706618236352,
                                   'start': 35,
                                   'text': '百度公司'},
                                  {'end': 72,
                                   'probability': 0.6193477885474685,
                                   'start': 56,
                                   'text': '中国计算机学会自然语言处理专委会'},
                                  {'end': 55,
                                   'probability': 0.7000497331473241,
                                   'start': 40,
                                   'text': '中国中文信息学会评测工作委员会'}]},
            'start': 0,
            'text': '2022语言与智能技术竞赛'}]}]
```
</details>


<details>
<summary>👉 Event Extraction</summary>

```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor

# event extraction
schema = {"地震触发词": ["地震强度", "时间", "震中位置", "震源深度"]}
# the uie-base model is on huggingface and downloads automatically; other models only need a model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24.34度,东经99.98度)发生3.5级地震,震源深度10千米。"))  # better print results using pprint
```
Output:
```text
[{'地震触发词': {'end': 58,
            'probability': 0.9977425932884216,
            'relation': {'地震强度': [{'end': 56,
                                   'probability': 0.9980800747871399,
                                   'start': 52,
                                   'text': '3.5级'}],
                         '时间': [{'end': 22,
                                 'probability': 0.9853301644325256,
                                 'start': 11,
                                 'text': '5月16日06时08分'}],
                         '震中位置': [{'end': 50,
                                   'probability': 0.7874020934104919,
                                   'start': 23,
                                   'text': '云南临沧市凤庆县(北纬24.34度,东经99.98度)'}],
                         '震源深度': [{'end': 67,
                                   'probability': 0.9937973618507385,
                                   'start': 63,
                                   'text': '10千米'}]},
            'start': 56,
            'text': '地震'}}]
```
</details>

<details>
<summary>👉 Opinion Extraction</summary>

```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor

# opinion extraction
schema = {'评价维度': ['观点词', '情感倾向[正向,负向]']}
# the uie-base model is on huggingface and downloads automatically; other models only need a model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队"))  # better print results using pprint
```
Output:
```text
[{'评价维度': [{'end': 20,
            'probability': 0.9817040258681473,
            'relations': {'情感倾向[正向,负向]': [{'probability': 0.9966142505350533,
                                          'text': '正向'}],
                          '观点词': [{'end': 22,
                                   'probability': 0.957396472711558,
                                   'start': 21,
                                   'text': '高'}]},
            'start': 17,
            'text': '性价比'},
           {'end': 2,
            'probability': 0.9696849569741168,
            'relations': {'情感倾向[正向,负向]': [{'probability': 0.9982153274927796,
                                          'text': '正向'}],
                          '观点词': [{'end': 4,
                                   'probability': 0.9945318044652538,
                                   'start': 2,
                                   'text': '干净'}]},
            'start': 0,
            'text': '店面'}]}]
```
</details>


<details>
<summary>👉 Sentiment Classification</summary>


```python
from pprint import pprint
from lightningnlp.task.uie import UIEPredictor

# sentiment classification
schema = '情感倾向[正向,负向]'
# the uie-base model is on huggingface and downloads automatically; other models only need a model name and are converted automatically
uie = UIEPredictor("xusenlin/uie-base", schema=schema)
pprint(uie("这个产品用起来真的很流畅,我非常喜欢"))  # better print results using pprint
```
Output:
```text
[{'情感倾向[正向,负向]': {'end': 0,
                   'probability': 0.9990023970603943,
                   'start': 0,
                   'text': '正向'}}]
```
</details>


## Citation

If `LightningNLP` helps your research, please cite:

```text
@misc{lightningnlp,
  title={LightningNLP: An Easy-to-use NLP Library},
  author={senlin xu},
  howpublished = {\url{https://github.com/xusenlinzy/lightningblocks}},
  year={2022}
}
```

## Acknowledgements

We borrowed the excellent model-usage design of [`Lightning-transformers`](https://github.com/Lightning-AI/lightning-transformers); our thanks go to its authors and their open-source community.
"bugtrack_url": null,
"license": "MIT Licence",
"summary": "Pytorch-lightning Code Blocks for NLP",
"version": "1.0.2",
"split_keywords": [],
"urls": [
{
"comment_text": "",
"digests": {
"md5": "99c6b5908b1e7e374a4ee27c56930837",
"sha256": "26ad129026620edd4410bf95685af04a3b4299a0e6e0e19882e19049d3fbf3d4"
},
"downloads": -1,
"filename": "lightningnlp-1.0.2.tar.gz",
"has_sig": false,
"md5_digest": "99c6b5908b1e7e374a4ee27c56930837",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.7",
"size": 230256,
"upload_time": "2022-12-14T07:30:04",
"upload_time_iso_8601": "2022-12-14T07:30:04.962891Z",
"url": "https://files.pythonhosted.org/packages/ad/ae/4cbae6316dcb022c204245c1e3e50f321ec575b943532a9bf8bb9c7224d6/lightningnlp-1.0.2.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2022-12-14 07:30:04",
"github": true,
"gitlab": false,
"bitbucket": false,
"github_user": "xusenlinzy",
"github_project": "lightningblocks",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"requirements": [],
"lcname": "lightningnlp"
}