<div align="center">
<a href="https://github.com/shibing624/pytextclassifier">
<img src="https://github.com/shibing624/pytextclassifier/blob/master/docs/logo.png" alt="Logo" height="156">
</a>
</div>
-----------------
# PyTextClassifier: Python Text Classifier
[![PyPI version](https://badge.fury.io/py/pytextclassifier.svg)](https://badge.fury.io/py/pytextclassifier)
[![Downloads](https://static.pepy.tech/badge/pytextclassifier)](https://pepy.tech/project/pytextclassifier)
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![python_version](https://img.shields.io/badge/Python-3.5%2B-green.svg)](requirements.txt)
[![GitHub issues](https://img.shields.io/github/issues/shibing624/pytextclassifier.svg)](https://github.com/shibing624/pytextclassifier/issues)
[![Wechat Group](https://img.shields.io/badge/wechat-group-green.svg?logo=wechat)](#Contact)
## Introduction
PyTextClassifier: Python Text Classifier. It can be applied to tasks such as sentiment polarity analysis and text risk classification,
and it supports multiple classification and clustering algorithms.
**pytextclassifier** is an open-source Python toolkit for text classification. Its goal is to implement
text analysis algorithms that can be used in production environments.
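A text classifier that provides a variety of text classification and clustering algorithms. It supports sentence-level and document-level classification tasks, including binary, multi-class, multi-label and hierarchical (multi-level) classification, as well as KMeans clustering, and works out of the box. Developed in Python 3.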
**Guide**
- [Feature](#Feature)
- [Install](#install)
- [Usage](#usage)
- [Dataset](#Dataset)
- [Contact](#Contact)
- [Citation](#Citation)
- [Reference](#reference)
## Feature
**pytextclassifier** features clear algorithm implementations, high performance,
and support for custom corpora.
Functions:
### Classifier
- [x] LogisticRegression
- [x] Random Forest
- [x] Decision Tree
- [x] K-Nearest Neighbours
- [x] Naive Bayes
- [x] XGBoost
- [x] Support Vector Machine (SVM)
- [x] TextCNN
- [x] TextRNN
- [x] FastText
- [x] BERT
### Cluster
- [x] MiniBatchKMeans
While providing rich functionality, **pytextclassifier** keeps its internal modules loosely coupled, loads models lazily, makes its dictionaries openly available, and is easy to use.
## Install
- Requirements and Installation
```
pip3 install torch # conda install pytorch
pip3 install pytextclassifier
```
or
```
git clone https://github.com/shibing624/pytextclassifier.git
cd pytextclassifier
python3 setup.py install
```
## Usage
### Text Classifier
### English Text Classifier
Covers model training, saving, prediction and evaluation; see [examples/lr_en_classification_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/lr_en_classification_demo.py):
```python
import sys

sys.path.append('..')
from pytextclassifier import ClassicClassifier

if __name__ == '__main__':
    m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr')
    # ClassicClassifier supports model_name: lr, random_forest, decision_tree, knn, bayes, svm, xgboost
    print(m)
    data = [
        ('education', 'Student debt to cost Britain billions within decades'),
        ('education', 'Chinese education for TV experiment'),
        ('sports', 'Middle East and Asia boost investment in top level sports'),
        ('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')
    ]
    # train and save best model
    m.train(data)
    # load best model from model_dir
    m.load_model()
    predict_label, predict_proba = m.predict([
        'Abbott government spends $8 million on higher education media blitz'])
    print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')

    test_data = [
        ('education', 'Abbott government spends $8 million on higher education media blitz'),
        ('sports', 'Middle East and Asia boost investment in top level sports'),
    ]
    acc_score = m.evaluate_model(test_data)
    print(f'acc_score: {acc_score}')
```
output:
```
ClassicClassifier instance (LogisticRegression(fit_intercept=False), stopwords size: 2438)
predict_label: ['education'], predict_proba: [0.5378236358492112]
acc_score: 1.0
```
### Chinese Text Classifier
Text classification works with both Chinese and English corpora.
Example: [examples/lr_classification_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/lr_classification_demo.py)
```python
import sys

sys.path.append('..')
from pytextclassifier import ClassicClassifier

if __name__ == '__main__':
    m = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr')
    # classic classifiers; supported models: lr, random_forest, decision_tree, knn, bayes, svm, xgboost
    data = [
        ('education', '名师指导托福语法技巧:名词的复数形式'),
        ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
        ('education', '公务员考虑越来越吃香,这是怎么回事?'),
        ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
        ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
        ('sports', '米兰客场8战不败国米10年连胜'),
    ]
    m.train(data)
    print(m)
    # load best model from model_dir
    m.load_model()
    predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
                                              '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
    print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')

    test_data = [
        ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
        ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
    ]
    acc_score = m.evaluate_model(test_data)
    print(f'acc_score: {acc_score}')  # 1.0

    # train model with 10k samples (thucnews_train_1w.txt, "1w" = 10,000)
    print('-' * 42)
    m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr')
    data_file = 'thucnews_train_1w.txt'
    m.train(data_file)
    m.load_model()
    predict_label, predict_proba = m.predict(
        ['顺义北京苏活88平米起精装房在售',
         '美EB-5项目“15日快速移民”将推迟'])
    print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
```
output:
```
ClassicClassifier instance (LogisticRegression(fit_intercept=False), stopwords size: 2438)
predict_label: ['education' 'sports'], predict_proba: [0.5, 0.598941806741534]
acc_score: 1.0
------------------------------------------
predict_label: ['realty' 'education'], predict_proba: [0.7302956923617372, 0.2565005445322923]
```
### Visual Feature Importance
Shows the feature weights of a model and the per-word weights behind a prediction; see [examples/visual_feature_importance.ipynb](https://github.com/shibing624/pytextclassifier/blob/master/examples/visual_feature_importance.ipynb)
```python
import sys

sys.path.append('..')
import eli5
import jieba
from pytextclassifier import ClassicClassifier

tc = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr')
data = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
    ('sports', '米兰客场8战不败国米10年连胜')
]
tc.train(data)

infer_data = ['高考指导托福语法技巧国际认可',
              '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜']
# global feature weights of the trained model (eli5 renders these as HTML in Jupyter)
eli5.show_weights(tc.model, vec=tc.feature)
# segment the Chinese text with jieba and join the tokens with spaces
seg_infer_data = [' '.join(jieba.lcut(i)) for i in infer_data]
# per-word contributions to the prediction for the first sample
eli5.show_prediction(tc.model, seg_infer_data[0], vec=tc.feature,
                     target_names=['education', 'sports'])
```
output:
![img.png](docs/img.png)
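For a linear model such as logistic regression, the per-word weights that `eli5.show_prediction` highlights can be read as each feature's contribution (coefficient × feature value) to the decision score, alongside a bias term. The pure-Python sketch below illustrates only that arithmetic; the tokens, weights, bias and feature values are made up and are not taken from a trained `ClassicClassifier`, and `explain_linear` is a hypothetical helper.

```python
import math


def explain_linear(weights, bias, features):
    """Score one document with a binary linear model and list per-token contributions.

    weights:  token -> coefficient (hypothetical values, not from a trained model)
    bias:     intercept term
    features: token -> feature value for this document (e.g. a TF-IDF weight)
    """
    # Each token contributes weight * value; tokens the model has no weight for contribute 0.
    contributions = {tok: weights.get(tok, 0.0) * val for tok, val in features.items()}
    score = bias + sum(contributions.values())
    # The logistic function maps the score to P(positive class).
    probability = 1.0 / (1.0 + math.exp(-score))
    # Rank tokens by the magnitude of their contribution, largest first.
    ranked = sorted(contributions.items(), key=lambda kv: abs(kv[1]), reverse=True)
    return score, probability, ranked


# Hypothetical weights: positive values favour "sports", negative values favour "education".
weights = {"milan": 1.2, "streak": 0.9, "exam": -1.5, "toefl": -1.1}
features = {"milan": 0.5, "streak": 0.4, "match": 0.3}
score, prob, ranked = explain_linear(weights, bias=-0.1, features=features)
print(f"score={score:.2f}, P(sports)={prob:.3f}")
for tok, contrib in ranked:
    print(f"{tok:>8}: {contrib:+.2f}")
```

Tokens unknown to the model (here `match`) contribute nothing, which is why they do not show up as highlighted words in a prediction explanation.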
### Deep Classification Models
This project supports the following deep classification models: FastText, TextCNN, TextRNN and BERT. Import the corresponding class to use a model:
```python
from pytextclassifier import FastTextClassifier, TextCNNClassifier, TextRNNClassifier, BertClassifier
```
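The FastText model is used as the example below; the other models are used in a similar way.
### FastText Model
Example of training and predicting with a `FastText` model: [examples/fasttext_classification_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/fasttext_classification_demo.py)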
```python
import sys

sys.path.append('..')
from pytextclassifier import FastTextClassifier, load_data

if __name__ == '__main__':
    m = FastTextClassifier(output_dir='models/fasttext-toy')
    data = [
        ('education', '名师指导托福语法技巧:名词的复数形式'),
        ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
        ('education', '公务员考虑越来越吃香,这是怎么回事?'),
        ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
        ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
        ('sports', '米兰客场8战不败保持连胜'),
    ]
    m.train(data, num_epochs=3)
    print(m)
    # load trained best model
    m.load_model()
    predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
                                              '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
    print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
    test_data = [
        ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
        ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
    ]
    acc_score = m.evaluate_model(test_data)
    print(f'acc_score: {acc_score}')  # 1.0

    # train model with 10k samples (thucnews_train_1w.txt, "1w" = 10,000)
    print('-' * 42)
    data_file = 'thucnews_train_1w.txt'
    m = FastTextClassifier(output_dir='models/fasttext')
    m.train(data_file, names=('labels', 'text'), num_epochs=3)
    # load best trained model from model_dir
    m.load_model()
    predict_label, predict_proba = m.predict(
        ['顺义北京苏活88平米起精装房在售',
         '美EB-5项目“15日快速移民”将推迟']
    )
    print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
    x, y, df = load_data(data_file)
    # note: these rows come from the training file, so they may overlap with the training data
    test_data = df[:100]
    acc_score = m.evaluate_model(test_data)
    print(f'acc_score: {acc_score}')
```
### BERT-family Models
#### Multi-class Classification
Example of training and predicting with a `BERT` multi-class model: [examples/bert_classification_zh_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_classification_zh_demo.py)
```python
import sys

sys.path.append('..')
from pytextclassifier import BertClassifier

if __name__ == '__main__':
    m = BertClassifier(output_dir='models/bert-chinese-toy', num_classes=2,
                       model_type='bert', model_name='bert-base-chinese', num_epochs=2)
    # model_type: support 'bert', 'albert', 'roberta', 'xlnet'
    # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
    data = [
        ('education', '名师指导托福语法技巧:名词的复数形式'),
        ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
        ('education', '公务员考虑越来越吃香,这是怎么回事?'),
        ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
        ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
        ('sports', '米兰客场8战不败国米10年连胜'),
    ]
    m.train(data)
    print(m)
    # load trained best model from model_dir
    m.load_model()
    predict_label, predict_proba = m.predict(['福建春季公务员考试报名18日截止 2月6日考试',
                                              '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'])
    print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')

    test_data = [
        ('education', '福建春季公务员考试报名18日截止 2月6日考试'),
        ('sports', '意甲首轮补赛交战记录:米兰客场8战不败国米10年连胜'),
    ]
    acc_score = m.evaluate_model(test_data)
    print(f'acc_score: {acc_score}')

    # train model on the 10k-sample (1w) data file with 10 classes
    print('-' * 42)
    m = BertClassifier(output_dir='models/bert-chinese', num_classes=10,
                       model_type='bert', model_name='bert-base-chinese', num_epochs=2,
                       args={"no_cache": True, "lazy_loading": True, "lazy_text_column": 1, "lazy_labels_column": 0})
    data_file = 'thucnews_train_1w.txt'
    # with more than one million training samples, lazy_loading mode is recommended to reduce memory usage
    m.train(data_file, test_size=0, names=('labels', 'text'))
    m.load_model()
    predict_label, predict_proba = m.predict(
        ['顺义北京苏活88平米起精装房在售',
         '美EB-5项目“15日快速移民”将推迟',
         '恒生AH溢指收平 A股对H股折价1.95%'])
    print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
```
PS: If the training data exceeds one million samples, lazy_loading mode is recommended to reduce memory usage.
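The idea behind lazy loading is to read training examples from disk on demand instead of materializing the whole file in memory. The sketch below illustrates that idea with a plain Python generator over a tab-separated `label<TAB>text` file (the column order mirrors `lazy_labels_column: 0` and `lazy_text_column: 1` in the example above). It is an assumption-based illustration, not how `BertClassifier` implements lazy loading internally, and `iter_examples` and `batched` are hypothetical helper names.

```python
import os
import tempfile
from itertools import islice


def iter_examples(path, label_col=0, text_col=1, sep="\t"):
    """Yield (label, text) pairs one line at a time; only the current line is held in memory."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split(sep)
            if len(parts) <= max(label_col, text_col):
                continue  # skip empty or malformed lines
            yield parts[label_col], parts[text_col]


def batched(iterable, batch_size):
    """Group an iterator into lists of at most batch_size items."""
    it = iter(iterable)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


# Demo with a small temporary file standing in for a large training set.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
    tmp.write("education\tsome exam news\nsports\tmatch report\n\nsports\tanother match\n")
    path = tmp.name

batches = list(batched(iter_examples(path), batch_size=2))
os.remove(path)
print(batches)
```

In real training the batches would be consumed one at a time rather than collected with `list(...)`; the list is only there to make the demo output visible.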
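#### Multi-label Classification
Classification can be divided into multi-class and multi-label classification. In multi-class classification the labels are mutually exclusive, whereas in multi-label classification they are not.
An intuitive way to understand multi-label classification is that a single sample can carry several category labels at the same time;
for example, a song can be tagged as both pop and upbeat, and a movie can be tagged as action, comedy and humor. These are all multi-label classification cases.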
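In practice the difference shows up in how labels are encoded and decoded: a multi-class sample has exactly one label, while a multi-label sample is a multi-hot vector, and prediction thresholds each label's independent (sigmoid) probability instead of taking a single argmax. A minimal pure-Python sketch follows; the label set, the logits and the 0.5 threshold are illustrative assumptions, not values used by `BertClassifier`.

```python
import math

LABELS = ["pop", "upbeat", "action", "comedy"]  # hypothetical label set


def to_multi_hot(label_names, labels=LABELS):
    """Encode a set of label names as a multi-hot list, e.g. ['pop', 'upbeat'] -> [1, 1, 0, 0]."""
    wanted = set(label_names)
    return [1 if name in wanted else 0 for name in labels]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def decode_multi_label(logits, threshold=0.5, labels=LABELS):
    """Multi-label: keep every label whose sigmoid probability reaches the threshold."""
    return [name for name, z in zip(labels, logits) if sigmoid(z) >= threshold]


def decode_multi_class(logits, labels=LABELS):
    """Multi-class: exactly one label; softmax is monotonic, so the argmax of the logits suffices."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return labels[best]


logits = [2.0, 0.3, -1.0, -3.0]
print(to_multi_hot(["pop", "upbeat"]))  # [1, 1, 0, 0]
print(decode_multi_label(logits))       # ['pop', 'upbeat']
print(decode_multi_class(logits))       # pop
```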
Example of training and predicting with a `BERT` multi-label model: [examples/bert_multilabel_classification_zh_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_multilabel_classification_zh_demo.py)
```python
import sys

import pandas as pd

sys.path.append('..')
from pytextclassifier import BertClassifier


def load_jd_data(file_path):
    """
    Load JD comment data from file.
    @param file_path: comma-separated file with one text column followed by 15 label columns:
        content,其他,互联互通,产品功耗,滑轮提手,声音,APP操控性,呼吸灯,外观,底座,制热范围,遥控器电池,味道,制热效果,衣物烘干,体积大小
    @return: list of [text, multi-hot label list]
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line.startswith('#'):
                continue
            if not line:
                continue
            terms = line.split(',')
            if len(terms) != 16:
                continue
            val = [int(i) for i in terms[1:]]
            data.append([terms[0], val])
    return data


if __name__ == '__main__':
    # model_type: support 'bert', 'albert', 'roberta', 'xlnet'
    # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...
    m = BertClassifier(output_dir='models/multilabel-bert-zh-model', num_classes=15,
                       model_type='bert', model_name='bert-base-chinese', num_epochs=2, multi_label=True)
    # Train and evaluation data must be a pandas DataFrame with at least two columns, 'text' and 'labels'.
    # The 'labels' column should contain multi-hot encoded lists.
    train_data = [
        ["一个小时房间仍然没暖和", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
        ["耗电情况:这个没有注意", [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
    ]
    data = load_jd_data('multilabel_jd_comments.csv')
    train_data.extend(data)
    print(train_data[:5])
    train_df = pd.DataFrame(train_data, columns=["text", "labels"])
    print(train_df.head())
    m.train(train_df)
    print(m)
    # Evaluate the model
    acc_score = m.evaluate_model(train_df[:20])
    print(f'acc_score: {acc_score}')

    # load trained best model from model_dir
    m.load_model()
    predict_label, predict_proba = m.predict(['一个小时房间仍然没暖和', '耗电情况:这个没有注意'])
    print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
```
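#### Hierarchical Classification
**Hierarchical label classification tasks**, such as industry classification (level-1 industries are divided into level-2 sub-industries, and those into level-3) or product categorization, can be handled with the multi-label classification model by converting the hierarchical labels into multi-label form;
see [examples/bert_hierarchical_classification_zh_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_hierarchical_classification_zh_demo.py)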
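One common way to put a hierarchical label into multi-label form is to expand each leaf path into all of its ancestor nodes and multi-hot encode the result, so a single sample carries, for example, both its level-1 and level-2 labels. The sketch below assumes labels are written as `/`-separated paths; the taxonomy and the helper names are hypothetical and may differ from what the linked demo does.

```python
def expand_path(path, sep="/"):
    """'sports/football' -> ['sports', 'sports/football'] (every ancestor plus the leaf)."""
    parts = path.split(sep)
    return [sep.join(parts[: i + 1]) for i in range(len(parts))]


def build_label_space(paths, sep="/"):
    """Collect every node that appears in any expanded path, in sorted order."""
    nodes = set()
    for p in paths:
        nodes.update(expand_path(p, sep))
    return sorted(nodes)


def encode(path, label_space, sep="/"):
    """Multi-hot vector over label_space for one hierarchical label."""
    active = set(expand_path(path, sep))
    return [1 if node in active else 0 for node in label_space]


train_paths = ["sports/football", "sports/tennis", "education/exam"]
space = build_label_space(train_paths)
print(space)
print(encode("sports/tennis", space))
```

The resulting vectors can be fed to the multi-label model with `num_classes=len(space)`; at prediction time, the active nodes can be mapped back to a path.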
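#### ONNX Inference Acceleration
Trained models can be exported to ONNX format for faster inference, or for deployment in other environments such as C++.
- On GPU, exporting to ONNX and running inference with the ONNX model can give more than a 10x inference speedup; this requires the `onnxruntime-gpu` package: `pip install onnxruntime-gpu`
- On CPU, exporting to ONNX and running inference with the ONNX model can give more than a 6x inference speedup; this requires the `onnxruntime` package: `pip install onnxruntime`
Example: [examples/onnx_predict_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/onnx_predict_demo.py)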
```python
import os
import shutil
import sys
import time

import torch

sys.path.append('..')
from pytextclassifier import BertClassifier

m = BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2,
                   model_type='bert', model_name='bert-base-chinese', num_epochs=1)
data = [
    ('education', '名师指导托福语法技巧:名词的复数形式'),
    ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
    ('education', '公务员考虑越来越吃香,这是怎么回事?'),
    ('education', '公务员考虑越来越吃香,这是怎么回事1?'),
    ('education', '公务员考虑越来越吃香,这是怎么回事2?'),
    ('education', '公务员考虑越来越吃香,这是怎么回事3?'),
    ('education', '公务员考虑越来越吃香,这是怎么回事4?'),
    ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
    ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
    ('sports', '米兰客场8战不败国米10年连胜1'),
    ('sports', '米兰客场8战不败国米10年连胜2'),
    ('sports', '米兰客场8战不败国米10年连胜3'),
    ('sports', '米兰客场8战不败国米10年连胜4'),
    ('sports', '米兰客场8战不败国米10年连胜5'),
]
m.train(data * 10)
m.load_model()

samples = ['名师指导托福语法技巧',
           '米兰客场8战不败',
           '恒生AH溢指收平 A股对H股折价1.95%'] * 100

# time the standard PyTorch model (stop the clock before printing)
start_time = time.time()
predict_label_bert, predict_proba_bert = m.predict(samples)
end_time = time.time()
print(f'predict_label_bert size: {len(predict_label_bert)}')
elapsed_time_bert = end_time - start_time
print(f'Standard BERT model prediction time: {elapsed_time_bert} seconds')

# convert to onnx, and load onnx model to predict, speed up 10x
save_onnx_dir = 'models/bert-chinese-v1/onnx'
m.model.convert_to_onnx(save_onnx_dir)
# copy label_vocab.json to save_onnx_dir
if os.path.exists(m.label_vocab_path):
    shutil.copy(m.label_vocab_path, save_onnx_dir)

# Manually delete the model and clear CUDA cache
del m
torch.cuda.empty_cache()

m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir,
                   args={"onnx": True})
m.load_model()
start_time = time.time()
predict_label_onnx, predict_proba_onnx = m.predict(samples)
end_time = time.time()
print(f'predict_label_onnx size: {len(predict_label_onnx)}')
elapsed_time_onnx = end_time - start_time
print(f'ONNX model prediction time: {elapsed_time_onnx} seconds')
```
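A single wall-clock measurement, as in the demo above, can be noisy, and the first call may include one-off warm-up work. A minimal sketch of a fairer comparison runs an untimed warm-up, repeats each call, and keeps the best time using `time.perf_counter`. The two dummy functions are assumptions that stand in for the PyTorch and ONNX `predict` calls so the sketch runs on its own; `best_time` is a hypothetical helper.

```python
import time


def best_time(fn, *args, repeat=5, warmup=1):
    """Run fn(*args) `warmup` times untimed, then return the fastest of `repeat` timed runs (seconds)."""
    for _ in range(warmup):
        fn(*args)
    timings = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return min(timings)


def slow_predict(samples):  # stand-in for the standard model's predict
    return [sum(range(2000)) for _ in samples]


def fast_predict(samples):  # stand-in for the ONNX model's predict
    return [sum(range(200)) for _ in samples]


samples = ["text"] * 300
t_slow = best_time(slow_predict, samples)
t_fast = best_time(fast_predict, samples)
print(f"speedup: {t_slow / t_fast:.1f}x")
```

With the real models, pass `m.predict` and the same `samples` list to `best_time` for each model in turn.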
## Evaluation
### Dataset
1. THUCNews Chinese text dataset (1.56 GB): official [download page](http://thuctc.thunlp.org/). A 100k-sample, 10-class subset of THUCNews (6 MB) is available at [examples/thucnews_train_10w.txt](https://github.com/shibing624/pytextclassifier/blob/master/examples/thucnews_train_10w.txt).
2. TNEWS, Toutiao Chinese news (short text) classification (Short Text Classification for News): this dataset (5.1 MB) comes from the news section of Toutiao and contains news in 15 categories, including travel, education, finance and military. Download: [tnews_public.zip](https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip)
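The demos read the THUCNews sample files as tab-separated lines with the label in column 0 and the text in column 1; this layout is inferred from `names=('labels', 'text')`, `sep='\t', use_col=1` and the `lazy_*_column` settings used in the examples, not from a published format spec. Assuming that layout, the sketch below parses such lines into the same `(label, text)` tuples the in-memory examples use, checks the class distribution, and makes a reproducible hold-out split. `parse_lines` and `split_pairs` are hypothetical helpers, not the library's `load_data`.

```python
import random
from collections import Counter


def parse_lines(lines, sep="\t"):
    """Turn 'label<TAB>text' lines into (label, text) tuples, skipping malformed lines."""
    pairs = []
    for line in lines:
        parts = line.rstrip("\n").split(sep, 1)  # split once, in case the text itself contains a tab
        if len(parts) == 2 and parts[0] and parts[1]:
            pairs.append((parts[0], parts[1]))
    return pairs


def split_pairs(pairs, test_ratio=0.1, seed=42):
    """Shuffle with a fixed seed and hold out test_ratio of the data for evaluation."""
    shuffled = pairs[:]
    random.Random(seed).shuffle(shuffled)
    n_test = int(len(shuffled) * test_ratio)
    return shuffled[n_test:], shuffled[:n_test]


lines = ["education\t福建春季公务员考试报名18日截止\n", "sports\t米兰客场8战不败\n",
         "line without a tab\n", "realty\t顺义北京苏活88平米起精装房在售\n",
         "sports\t法网孟菲尔斯苦战进16强\n"]
pairs = parse_lines(lines)
train, test = split_pairs(pairs, test_ratio=0.25)
print(Counter(label for label, _ in pairs))
print(len(train), len(test))
```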
### Evaluation Result
Evaluated on the THUCNews Chinese 10-class text dataset (6 MB); results on the test set:
Model|acc|Notes
--|--|--
LR|0.8803|Logistic Regression
TextCNN|0.8809|classic CNN text classification (Kim, 2014)
TextRNN_Att|0.9022|BiLSTM+Attention
FastText|0.9177|bow+bigram+trigram, surprisingly good results
DPCNN|0.9125|Deep Pyramid CNN
Transformer|0.8991|relatively weak results
BERT-base|**0.9483**|bert + fc
ERNIE|0.9461|slightly worse than BERT
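The "bow+bigram+trigram" note for FastText refers to augmenting bag-of-words features with word n-grams. As a rough illustration of that feature set (not the model code, which typically also hashes n-grams into a fixed number of embedding buckets), the sketch below generates the n-grams for one pre-segmented sentence and maps them to bucket ids with the hashing trick. The segmentation, the bucket count and both helper names are hypothetical.

```python
import zlib


def word_ngrams(tokens, orders=(1, 2, 3)):
    """Return all contiguous word n-grams of the given orders (order 1 = bag of words)."""
    grams = []
    for n in orders:
        for i in range(len(tokens) - n + 1):
            grams.append(" ".join(tokens[i:i + n]))
    return grams


def bucket_ids(grams, num_buckets=1000):
    """Hashing trick: map each n-gram to one of num_buckets ids; collisions are tolerated.
    zlib.crc32 is used because Python's built-in str hash() is randomized per process."""
    return [zlib.crc32(g.encode("utf-8")) % num_buckets for g in grams]


tokens = ["米兰", "客场", "8战", "不败"]  # hypothetical segmentation of one headline
grams = word_ngrams(tokens)
print(grams)
print(bucket_ids(grams))
```

A 4-token sentence yields 4 unigrams, 3 bigrams and 2 trigrams, so the n-gram features let the model see local word order that a pure bag of words discards.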
Evaluated on the TNEWS Chinese short-news classification dataset; results on the dev set:
Model|acc|Notes
--|--|--
BERT-base|**0.5660**|this project's implementation
BERT-base|0.5609|CLUE Benchmark Leaderboard result [CLUEbenchmark](https://github.com/CLUEbenchmark/CLUE)
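- All results above are classification accuracy.
- The THUCNews results can be reproduced with the model demos under `examples`, using the `examples/thucnews_train_10w.txt` data.
- The TNEWS results can be reproduced by downloading the TNEWS dataset and running `examples/bert_classification_tnews_demo.py`.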
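Accuracy here is the fraction of samples whose predicted label equals the gold label. A minimal sketch of that computation, which can be used to check a score by hand from `predict` output; the label lists are made up and `accuracy` is a hypothetical helper:

```python
def accuracy(gold, predicted):
    """Fraction of positions where the predicted label equals the gold label."""
    if len(gold) != len(predicted):
        raise ValueError("gold and predicted must have the same length")
    if not gold:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(g == p for g, p in zip(gold, predicted))
    return correct / len(gold)


gold = ["education", "sports", "realty", "sports"]
pred = ["education", "sports", "education", "sports"]
print(accuracy(gold, pred))  # 0.75
```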
### Command-line Usage
Command-line scripts are provided for the classification models. File tree:
```bash
pytextclassifier
├── bert_classifier.py
├── fasttext_classifier.py
├── classic_classifier.py
├── textcnn_classifier.py
└── textrnn_classifier.py
```
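Each file corresponds to one model. The models are fully independent, can be run directly, and are easy to modify; parameters such as `--data_path` can be changed through `argparse`.
To train the fasttext model directly from the terminal (`-h` lists the available options):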
```bash
python -m pytextclassifier.fasttext_classifier -h
```
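The scripts expose their options through `argparse`. As a hedged illustration of that pattern only: apart from `--data_path`, which the text above mentions, the flags and defaults below are hypothetical, so check `python -m pytextclassifier.fasttext_classifier -h` for the real options.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="Train a text classifier (illustrative sketch).")
    parser.add_argument("--data_path", type=str, default="thucnews_train_1w.txt",
                        help="tab-separated training file")
    parser.add_argument("--output_dir", type=str, default="models/fasttext",
                        help="where to save the trained model (hypothetical flag)")
    parser.add_argument("--num_epochs", type=int, default=3,
                        help="number of training epochs (hypothetical flag)")
    return parser


# Parse an explicit argument list here; a real script would call parse_args() with no arguments.
args = build_parser().parse_args(["--data_path", "my_data.txt", "--num_epochs", "5"])
print(args.data_path, args.num_epochs, args.output_dir)
```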
## Text Cluster
Text clustering, for example [examples/cluster_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/cluster_demo.py)
```python
import sys

sys.path.append('..')
from sklearn.feature_extraction.text import TfidfVectorizer
from pytextclassifier.textcluster import TextCluster

if __name__ == '__main__':
    m = TextCluster(output_dir='models/cluster-toy', n_clusters=2)
    print(m)
    data = [
        'Student debt to cost Britain billions within decades',
        'Chinese education for TV experiment',
        'Abbott government spends $8 million on higher education',
        'Middle East and Asia boost investment in top level sports',
        'Summit Series look launches HBO Canada sports doc series: Mudhar'
    ]
    m.train(data)
    m.load_model()
    r = m.predict(['Abbott government spends $8 million on higher education media blitz',
                   'Middle East and Asia boost investment in top level sports'])
    print(r)

    # load Chinese training data from the 10k-sample (1w) data file
    tcluster = TextCluster(output_dir='models/cluster', feature=TfidfVectorizer(ngram_range=(1, 2)), n_clusters=10)
    data = tcluster.load_file_data('thucnews_train_1w.txt', sep='\t', use_col=1)
    feature, labels = tcluster.train(data[:5000])
    tcluster.show_clusters(feature, labels, 'models/cluster/cluster_train_seg_samples.png')
    r = tcluster.predict(data[:30])
    print(r)
```
output:
```
TextCluster instance (MiniBatchKMeans(n_clusters=2, n_init=10), <pytextclassifier.utils.tokenizer.Tokenizer object at 0x7f80bd4682b0>, TfidfVectorizer(ngram_range=(1, 2)))
[1 1 1 1 1 1 1 1 1 1 1 8 1 1 1 1 1 1 1 1 1 1 9 1 1 8 1 1 9 1]
```
clustering plot image:
![cluster_image](https://github.com/shibing624/pytextclassifier/blob/master/docs/cluster_train_seg_samples.png)
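The printed instance shows that `TextCluster` uses scikit-learn's `MiniBatchKMeans`, which updates cluster centers from small random batches instead of the full dataset: each sampled point is assigned to its nearest center, and that center moves toward the point with a per-center step size of 1/count. The toy pure-Python sketch below shows that update rule on 2-D points. It swaps in a deterministic farthest-first initialization in place of scikit-learn's k-means++, omits its convergence checks and center reassignment, and uses made-up data; all function names are hypothetical.

```python
import random


def squared_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def farthest_first_init(points, k):
    """Deterministic init (stand-in for k-means++): start at the first point,
    then repeatedly add the point farthest from all chosen centers."""
    centers = [list(points[0])]
    while len(centers) < k:
        far = max(points, key=lambda p: min(squared_dist(p, c) for c in centers))
        centers.append(list(far))
    return centers


def minibatch_kmeans(points, k, batch_size=4, n_iter=50, seed=0):
    """Toy mini-batch k-means with a per-center learning rate of 1 / (points assigned so far)."""
    rng = random.Random(seed)
    centers = farthest_first_init(points, k)
    counts = [0] * k
    for _ in range(n_iter):
        batch = [rng.choice(points) for _ in range(batch_size)]
        # 1) assign each batch point to its nearest center, using the centers from before this batch
        nearest = [min(range(k), key=lambda c: squared_dist(p, centers[c])) for p in batch]
        # 2) move each assigned center toward its point; with eta = 1/count the center becomes
        #    the running mean of all points assigned to it (the initial position is replaced first)
        for p, c in zip(batch, nearest):
            counts[c] += 1
            eta = 1.0 / counts[c]
            centers[c] = [(1 - eta) * cx + eta * px for cx, px in zip(centers[c], p)]
    return centers


def assign(points, centers):
    """Label each point with the index of its nearest center."""
    return [min(range(len(centers)), key=lambda c: squared_dist(p, centers[c])) for p in points]


data = [(0.0, 0.1), (0.2, 0.0), (0.1, 0.2), (5.0, 5.1), (5.2, 4.9), (4.9, 5.0)]
centers = minibatch_kmeans(data, k=2)
print(assign(data, centers))
```

For text, the points would be TF-IDF vectors such as those produced by the `TfidfVectorizer` in the example above; the update rule stays the same.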
## Contact
- Issues (suggestions): [![GitHub issues](https://img.shields.io/github/issues/shibing624/pytextclassifier.svg)](https://github.com/shibing624/pytextclassifier/issues)
- Email: xuming: xuming624@qq.com
- WeChat: add WeChat ID *xuming624* to join the Python-NLP discussion group, with the note *Name-Company-NLP*
<img src="docs/wechat.jpeg" width="200" />
## Citation
If you use pytextclassifier in your research, please cite it as follows:
APA:
```latex
Xu, M. Pytextclassifier: Text classifier toolkit for NLP (Version 1.2.0) [Computer software]. https://github.com/shibing624/pytextclassifier
```
BibTeX:
```latex
@misc{Pytextclassifier,
title={Pytextclassifier: Text classifier toolkit for NLP},
author={Xu Ming},
year={2022},
howpublished={\url{https://github.com/shibing624/pytextclassifier}},
}
```
## License
Licensed under [The Apache License 2.0](LICENSE); free for commercial use. Please include a link to **pytextclassifier** and the license in your product documentation.
## Contribute
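The code is still rough. Improvements are welcome as contributions back to this project; before submitting, please note the following two points:
- Add the corresponding unit tests under `tests`
- Run all unit tests with `python setup.py test` and make sure they all pass
Then submit a PR.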
num_epochs=2,\n args={\"no_cache\": True, \"lazy_loading\": True, \"lazy_text_column\": 1, \"lazy_labels_column\": 0, })\n data_file = 'thucnews_train_1w.txt'\n # \u5982\u679c\u8bad\u7ec3\u6570\u636e\u8d85\u8fc7\u767e\u4e07\u6761\uff0c\u5efa\u8bae\u4f7f\u7528lazy_loading\u6a21\u5f0f\uff0c\u51cf\u5c11\u5185\u5b58\u5360\u7528\n m.train(data_file, test_size=0, names=('labels', 'text'))\n m.load_model()\n predict_label, predict_proba = m.predict(\n ['\u987a\u4e49\u5317\u4eac\u82cf\u6d3b88\u5e73\u7c73\u8d77\u7cbe\u88c5\u623f\u5728\u552e',\n '\u7f8eEB-5\u9879\u76ee\u201c15\u65e5\u5feb\u901f\u79fb\u6c11\u201d\u5c06\u63a8\u8fdf',\n '\u6052\u751fAH\u6ea2\u6307\u6536\u5e73 A\u80a1\u5bf9H\u80a1\u6298\u4ef71.95%'])\n print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')\n```\nPS\uff1a\u5982\u679c\u8bad\u7ec3\u6570\u636e\u8d85\u8fc7\u767e\u4e07\u6761\uff0c\u5efa\u8bae\u4f7f\u7528lazy_loading\u6a21\u5f0f\uff0c\u51cf\u5c11\u5185\u5b58\u5360\u7528\n\n#### \u591a\u6807\u7b7e\u5206\u7c7b\u6a21\u578b\n\u5206\u7c7b\u53ef\u4ee5\u5206\u4e3a\u591a\u5206\u7c7b\u548c\u591a\u6807\u7b7e\u5206\u7c7b\u3002\u591a\u5206\u7c7b\u7684\u6807\u7b7e\u662f\u6392\u4ed6\u7684\uff0c\u800c\u591a\u6807\u7b7e\u5206\u7c7b\u7684\u6240\u6709\u6807\u7b7e\u662f\u4e0d\u6392\u4ed6\u7684\u3002\n\n\u591a\u6807\u7b7e\u5206\u7c7b\u6bd4\u8f83\u76f4\u89c2\u7684\u7406\u89e3\u662f\uff0c\u4e00\u4e2a\u6837\u672c\u53ef\u4ee5\u540c\u65f6\u62e5\u6709\u51e0\u4e2a\u7c7b\u522b\u6807\u7b7e\uff0c\n\u6bd4\u5982\u4e00\u9996\u6b4c\u7684\u6807\u7b7e\u53ef\u4ee5\u662f\u6d41\u884c\u3001\u8f7b\u5feb\uff0c\u4e00\u90e8\u7535\u5f71\u7684\u6807\u7b7e\u53ef\u4ee5\u662f\u52a8\u4f5c\u3001\u559c\u5267\u3001\u641e\u7b11\u7b49\uff0c\u8fd9\u90fd\u662f\u591a\u6807\u7b7e\u5206\u7c7b\u7684\u60c5\u51b5\u3002\n\n\u8bad\u7ec3\u548c\u9884\u6d4b`BERT`\u591a\u6807\u7b7e\u5206\u7c7b\u6a21\u578b\uff0c\u793a\u4f8b[examples/bert_multilabel_classification_zh_demo.py.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_mult
ilabel_classification_zh_demo.py)\n\n```python\nimport sys\nimport pandas as pd\n\nsys.path.append('..')\nfrom pytextclassifier import BertClassifier\n\n\ndef load_jd_data(file_path):\n \"\"\"\n Load jd data from file.\n @param file_path: \n format: content,\u5176\u4ed6,\u4e92\u8054\u4e92\u901a,\u4ea7\u54c1\u529f\u8017,\u6ed1\u8f6e\u63d0\u624b,\u58f0\u97f3,APP\u64cd\u63a7\u6027,\u547c\u5438\u706f,\u5916\u89c2,\u5e95\u5ea7,\u5236\u70ed\u8303\u56f4,\u9065\u63a7\u5668\u7535\u6c60,\u5473\u9053,\u5236\u70ed\u6548\u679c,\u8863\u7269\u70d8\u5e72,\u4f53\u79ef\u5927\u5c0f\n @return: \n \"\"\"\n data = []\n with open(file_path, 'r', encoding='utf-8') as f:\n for line in f:\n line = line.strip()\n if line.startswith('#'):\n continue\n if not line:\n continue\n terms = line.split(',')\n if len(terms) != 16:\n continue\n val = [int(i) for i in terms[1:]]\n data.append([terms[0], val])\n return data\n\n\nif __name__ == '__main__':\n # model_type: support 'bert', 'albert', 'roberta', 'xlnet'\n # model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...\n m = BertClassifier(output_dir='models/multilabel-bert-zh-model', num_classes=15,\n model_type='bert', model_name='bert-base-chinese', num_epochs=2, multi_label=True)\n # Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. 
The `labels` column should contain multi-hot encoded lists.\n train_data = [\n [\"\u4e00\u4e2a\u5c0f\u65f6\u623f\u95f4\u4ecd\u7136\u6ca1\u6696\u548c\", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],\n [\"\u8017\u7535\u60c5\u51b5\uff1a\u8fd9\u4e2a\u6ca1\u6709\u6ce8\u610f\", [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n ]\n data = load_jd_data('multilabel_jd_comments.csv')\n train_data.extend(data)\n print(train_data[:5])\n train_df = pd.DataFrame(train_data, columns=[\"text\", \"labels\"])\n\n print(train_df.head())\n m.train(train_df)\n print(m)\n # Evaluate the model\n acc_score = m.evaluate_model(train_df[:20])\n print(f'acc_score: {acc_score}')\n\n # load trained best model from model_dir\n m.load_model()\n predict_label, predict_proba = m.predict(['\u4e00\u4e2a\u5c0f\u65f6\u623f\u95f4\u4ecd\u7136\u6ca1\u6696\u548c', '\u8017\u7535\u60c5\u51b5\uff1a\u8fd9\u4e2a\u6ca1\u6709\u6ce8\u610f'])\n print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')\n```\n\n#### \u591a\u5c42\u7ea7\u5206\u7c7b\u6a21\u578b\n**\u591a\u5c42\u7ea7\u6807\u7b7e\u5206\u7c7b\u4efb\u52a1**\uff0c\u5982\u884c\u4e1a\u5206\u7c7b\uff08\u4e00\u7ea7\u884c\u4e1a\u4e0b\u5206\u4e8c\u7ea7\u5b50\u884c\u4e1a\uff0c\u518d\u5206\u4e09\u7ea7\uff09\u3001\u4ea7\u54c1\u5206\u7c7b\uff0c\u53ef\u4ee5\u4f7f\u7528\u591a\u6807\u7b7e\u5206\u7c7b\u6a21\u578b\uff0c\u5c06\u591a\u5c42\u7ea7\u6807\u7b7e\u8f6c\u6362\u4e3a\u591a\u6807\u7b7e\u5f62\u5f0f\uff0c\n\u793a\u4f8b[examples/bert_hierarchical_classification_zh_demo.py.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/bert_hierarchical_classification_zh_demo.py)\n\n\n#### ONNX\u63a8\u7406\u52a0\u901f\n\n\u652f\u6301\u5c06\u8bad\u7ec3\u597d\u7684\u6a21\u578b\u5bfc\u51fa\u4e3aONNX\u683c\u5f0f\uff0c\u4ee5\u4fbf\u63a8\u7406\u52a0\u901f\uff0c\u6216\u8005\u5728\u5176\u4ed6\u73af\u5883\u5982C++\u90e8\u7f72\u6a21\u578b\u8c03\u7528\u3002\n\n- 
On GPU, exporting the model to ONNX and running inference with the ONNX model can speed up inference by more than 10x; this requires the `onnxruntime-gpu` library: `pip install onnxruntime-gpu`\n- On CPU, exporting the model to ONNX and running inference with the ONNX model can speed up inference by more than 6x; this requires the `onnxruntime` library: `pip install onnxruntime`\n\nExample: [examples/onnx_predict_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/onnx_predict_demo.py)\n\n```python\nimport os\nimport shutil\nimport sys\nimport time\n\nimport torch\n\nsys.path.append('..')\nfrom pytextclassifier import BertClassifier\n\nm = BertClassifier(output_dir='models/bert-chinese-v1', num_classes=2,\n                   model_type='bert', model_name='bert-base-chinese', num_epochs=1)\ndata = [\n    ('education', '\u540d\u5e08\u6307\u5bfc\u6258\u798f\u8bed\u6cd5\u6280\u5de7\uff1a\u540d\u8bcd\u7684\u590d\u6570\u5f62\u5f0f'),\n    ('education', '\u4e2d\u56fd\u9ad8\u8003\u6210\u7ee9\u6d77\u5916\u8ba4\u53ef \u662f\u201c\u72fc\u6765\u4e86\u201d\u5417\uff1f'),\n    ('education', '\u516c\u52a1\u5458\u8003\u8651\u8d8a\u6765\u8d8a\u5403\u9999\uff0c\u8fd9\u662f\u600e\u4e48\u56de\u4e8b\uff1f'),\n    ('education', '\u516c\u52a1\u5458\u8003\u8651\u8d8a\u6765\u8d8a\u5403\u9999\uff0c\u8fd9\u662f\u600e\u4e48\u56de\u4e8b1\uff1f'),\n    ('education', '\u516c\u52a1\u5458\u8003\u8651\u8d8a\u6765\u8d8a\u5403\u9999\uff0c\u8fd9\u662f\u600e\u4e48\u56de\u4e8b2\uff1f'),\n    ('education', '\u516c\u52a1\u5458\u8003\u8651\u8d8a\u6765\u8d8a\u5403\u9999\uff0c\u8fd9\u662f\u600e\u4e48\u56de\u4e8b3\uff1f'),\n    ('education', '\u516c\u52a1\u5458\u8003\u8651\u8d8a\u6765\u8d8a\u5403\u9999\uff0c\u8fd9\u662f\u600e\u4e48\u56de\u4e8b4\uff1f'),\n    ('sports', '\u56fe\u6587\uff1a\u6cd5\u7f51\u5b5f\u83f2\u5c14\u65af\u82e6\u6218\u8fdb16\u5f3a \u5b5f\u83f2\u5c14\u65af\u6012\u543c'),\n    
('sports', '\u56db\u5ddd\u4e39\u68f1\u4e3e\u884c\u5168\u56fd\u957f\u8ddd\u767b\u5c71\u6311\u6218\u8d5b \u8fd1\u4e07\u4eba\u53c2\u4e0e'),\n    ('sports', '\u7c73\u5170\u5ba2\u573a8\u6218\u4e0d\u8d25\u56fd\u7c7310\u5e74\u8fde\u80dc1'),\n    ('sports', '\u7c73\u5170\u5ba2\u573a8\u6218\u4e0d\u8d25\u56fd\u7c7310\u5e74\u8fde\u80dc2'),\n    ('sports', '\u7c73\u5170\u5ba2\u573a8\u6218\u4e0d\u8d25\u56fd\u7c7310\u5e74\u8fde\u80dc3'),\n    ('sports', '\u7c73\u5170\u5ba2\u573a8\u6218\u4e0d\u8d25\u56fd\u7c7310\u5e74\u8fde\u80dc4'),\n    ('sports', '\u7c73\u5170\u5ba2\u573a8\u6218\u4e0d\u8d25\u56fd\u7c7310\u5e74\u8fde\u80dc5'),\n]\nm.train(data * 10)\nm.load_model()\n\nsamples = ['\u540d\u5e08\u6307\u5bfc\u6258\u798f\u8bed\u6cd5\u6280\u5de7',\n           '\u7c73\u5170\u5ba2\u573a8\u6218\u4e0d\u8d25',\n           '\u6052\u751fAH\u6ea2\u6307\u6536\u5e73 A\u80a1\u5bf9H\u80a1\u6298\u4ef71.95%'] * 100\n\nstart_time = time.time()\npredict_label_bert, predict_proba_bert = m.predict(samples)\nprint(f'predict_label_bert size: {len(predict_label_bert)}')\nend_time = time.time()\nelapsed_time_bert = end_time - start_time\nprint(f'Standard BERT model prediction time: {elapsed_time_bert} seconds')\n\n# convert to ONNX, then load the ONNX model for faster prediction\nsave_onnx_dir = 'models/bert-chinese-v1/onnx'\nm.model.convert_to_onnx(save_onnx_dir)\n# copy label_vocab.json to save_onnx_dir\nif os.path.exists(m.label_vocab_path):\n    shutil.copy(m.label_vocab_path, save_onnx_dir)\n\n# Manually delete the model and clear CUDA cache\ndel m\ntorch.cuda.empty_cache()\n\nm = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='bert', model_name=save_onnx_dir,\n                   args={\"onnx\": True})\nm.load_model()\nstart_time = time.time()\npredict_label_bert, predict_proba_bert = m.predict(samples)\nprint(f'predict_label_bert size: {len(predict_label_bert)}')\nend_time = time.time()\nelapsed_time_onnx = end_time - start_time\nprint(f'ONNX model prediction time: {elapsed_time_onnx} seconds')\n```\n\n## Evaluation\n\n### Dataset\n\n1. 
THUCNews Chinese text dataset (1.56GB): official [download page](http://thuctc.thunlp.org/). A 10-class subset of 100,000 THUCNews samples (6MB) is available at [examples/thucnews_train_10w.txt](https://github.com/shibing624/pytextclassifier/blob/master/examples/thucnews_train_10w.txt).\n2. TNEWS, Toutiao Chinese news (short text) classification (Short Text Classification for News): this dataset (5.1MB) comes from the news section of Toutiao and covers 15 news categories, including travel, education, finance and military. Download: [tnews_public.zip](https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip)\n\n### Evaluation Result\nResults on the test set of the THUCNews 10-class Chinese text dataset (6MB):\n\nModel|acc|Notes\n--|--|--\nLR|0.8803|Logistic Regression\nTextCNN|0.8809|Kim 2014, the classic CNN for text classification\nTextRNN_Att|0.9022|BiLSTM+Attention\nFastText|0.9177|bow+bigram+trigram, surprisingly good\nDPCNN|0.9125|Deep Pyramid CNN\nTransformer|0.8991|weaker results\nBERT-base|**0.9483**|bert + fc\nERNIE|0.9461|slightly worse than BERT\n\nResults on the dev set of TNEWS, the Chinese news short-text classification dataset:\n\nModel|acc|Notes\n--|--|--\nBERT-base|**0.5660**|this project's implementation\nBERT-base|0.5609|CLUE Benchmark Leaderboard result [CLUEbenchmark](https://github.com/CLUEbenchmark/CLUE)\n\n- 
All results above are classification accuracy.\n- The THUCNews results can be reproduced with the model demos under `examples`, using the `examples/thucnews_train_10w.txt` data.\n- The TNEWS results can be reproduced by downloading the TNEWS dataset and running `examples/bert_classification_tnews_demo.py`.\n\n### Command-line Usage\n\nCommand-line scripts are provided for the classification models. File tree:\n```bash\npytextclassifier\n\u251c\u2500\u2500 bert_classifier.py\n\u251c\u2500\u2500 fasttext_classifier.py\n\u251c\u2500\u2500 classic_classifier.py\n\u251c\u2500\u2500 textcnn_classifier.py\n\u2514\u2500\u2500 textrnn_classifier.py\n```\n\nEach file implements one model. The models are fully independent, can be run directly and are easy to modify; parameters such as `--data_path` can be changed via `argparse`.\n\nTo train the fasttext model directly from the terminal (`-h` lists the available options):\n```bash\npython -m pytextclassifier.fasttext_classifier -h\n```\n\n## Text Cluster\n\nText clustering example: [examples/cluster_demo.py](https://github.com/shibing624/pytextclassifier/blob/master/examples/cluster_demo.py)\n\n```python\nimport sys\n\nsys.path.append('..')\nfrom pytextclassifier.textcluster import TextCluster\n\nif __name__ == '__main__':\n    m = TextCluster(output_dir='models/cluster-toy', n_clusters=2)\n    print(m)\n    data = [\n        'Student debt to cost Britain billions within decades',\n        'Chinese education for TV experiment',\n        'Abbott government spends $8 million on higher education',\n        'Middle East and Asia boost investment in top level sports',\n        'Summit Series look launches HBO Canada sports doc series: Mudhar'\n    ]\n    m.train(data)\n    
m.load_model()\n    r = m.predict(['Abbott government spends $8 million on higher education media blitz',\n                   'Middle East and Asia boost investment in top level sports'])\n    print(r)\n\n    # load Chinese training data from the 10k-sample data file\n    from sklearn.feature_extraction.text import TfidfVectorizer\n\n    tcluster = TextCluster(output_dir='models/cluster', feature=TfidfVectorizer(ngram_range=(1, 2)), n_clusters=10)\n    data = tcluster.load_file_data('thucnews_train_1w.txt', sep='\\t', use_col=1)\n    feature, labels = tcluster.train(data[:5000])\n    tcluster.show_clusters(feature, labels, 'models/cluster/cluster_train_seg_samples.png')\n    r = tcluster.predict(data[:30])\n    print(r)\n```\n\noutput:\n\n```\nTextCluster instance (MiniBatchKMeans(n_clusters=2, n_init=10), <pytextclassifier.utils.tokenizer.Tokenizer object at 0x7f80bd4682b0>, TfidfVectorizer(ngram_range=(1, 2)))\n[1 1 1 1 1 1 1 1 1 1 1 8 1 1 1 1 1 1 1 1 1 1 9 1 1 8 1 1 9 1]\n```\nclustering plot image:\n\n![cluster_image](https://github.com/shibing624/pytextclassifier/blob/master/docs/cluster_train_seg_samples.png)\n\n## Contact\n\n- Issues (suggestions): [![GitHub issues](https://img.shields.io/github/issues/shibing624/pytextclassifier.svg)](https://github.com/shibing624/pytextclassifier/issues)\n- Email: xuming: xuming624@qq.com\n- WeChat: add me (*WeChat ID: xuming624*) to join the Python-NLP discussion group, with the note *name-company-NLP*\n<img src=\"docs/wechat.jpeg\" width=\"200\" />\n\n## Citation\n\nIf you use pytextclassifier in your research, please cite it as follows:\n\nAPA:\n```latex\nXu, M. Pytextclassifier: Text classifier toolkit for NLP (Version 1.2.0) [Computer software]. 
https://github.com/shibing624/pytextclassifier\n```\n\nBibTeX:\n```latex\n@misc{Pytextclassifier,\n  title={Pytextclassifier: Text classifier toolkit for NLP},\n  author={Xu Ming},\n  year={2022},\n  howpublished={\\url{https://github.com/shibing624/pytextclassifier}},\n}\n```\n\n## License\n\nThe license is [The Apache License 2.0](LICENSE), which allows free commercial use. Please include a link to **pytextclassifier** and its license in your product documentation.\n\n## Contribute\nThe project code is still rough. Improvements are welcome; before submitting, please note the following two points:\n\n - Add corresponding unit tests in `tests`\n - Run all unit tests with `python setup.py test` and make sure they all pass\n\nThen submit a PR.\n",
"bugtrack_url": null,
"license": "Apache 2.0",
"summary": "Text Classifier, Text Classification",
"version": "1.4.0",
"project_urls": {
"Homepage": "https://github.com/shibing624/pytextclassifier"
},
"split_keywords": [
"pytextclassifier",
" textclassifier",
" classifier",
" textclassification"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "423caa93299704c9a7c17e96f6bb060b4ceaff56932f931d26610a5c9aef4316",
"md5": "bf07d08fd7b1893eb972ccecfacbfd78",
"sha256": "25951e71b993be1c784cced4de638c2add19868d2b8faa272d63ce8fe8425bb7"
},
"downloads": -1,
"filename": "pytextclassifier-1.4.0.tar.gz",
"has_sig": false,
"md5_digest": "bf07d08fd7b1893eb972ccecfacbfd78",
"packagetype": "sdist",
"python_version": "source",
"requires_python": null,
"size": 398594,
"upload_time": "2024-07-31T14:52:29",
"upload_time_iso_8601": "2024-07-31T14:52:29.711733Z",
"url": "https://files.pythonhosted.org/packages/42/3c/aa93299704c9a7c17e96f6bb060b4ceaff56932f931d26610a5c9aef4316/pytextclassifier-1.4.0.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-07-31 14:52:29",
"github": true,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"github_user": "shibing624",
"github_project": "pytextclassifier",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"requirements": [
{
"name": "loguru",
"specs": []
},
{
"name": "jieba",
"specs": []
},
{
"name": "scikit-learn",
"specs": []
},
{
"name": "pandas",
"specs": []
},
{
"name": "numpy",
"specs": []
},
{
"name": "transformers",
"specs": []
}
],
"lcname": "pytextclassifier"
}