bert_pretty is a text encoder and result decoder for BERT: it builds model inputs (token ids, masks, segment ids) from raw text and decodes model outputs for NER and text-classification tasks.
```py
# -*- coding:utf-8 -*-
'''
    BERT input encoding and result decoding
    https://github.com/ssbuild/bert_pretty.git
'''
import numpy as np
# FullTokenizer is the official BERT tokenizer; you may substitute your own tokenization.
from bert_pretty import FullTokenizer, \
    text_feature, \
    text_feature_char_level, \
    text_feature_word_level, \
    text_feature_char_level_input_ids_mask, \
    text_feature_word_level_input_ids_mask, \
    text_feature_char_level_input_ids_segment, \
    text_feature_word_level_input_ids_segment, \
    seqs_padding, rematch
from bert_pretty.ner import load_label_bioes,load_label_bio,load_labels as ner_load_labels
from bert_pretty.ner import ner_crf_decoding, \
    ner_pointer_decoding, \
    ner_pointer_decoding_with_mapping, \
    ner_pointer_double_decoding, ner_pointer_double_decoding_with_mapping
from bert_pretty.cls import cls_softmax_decoding,cls_sigmoid_decoding,load_labels as cls_load_labels
tokenizer = FullTokenizer(vocab_file=r'F:\pretrain\chinese_L-12_H-768_A-12\vocab.txt',do_lower_case=True)
text_list = ["你是谁123aa\ta嘂a","嘂adasd"]
def test():
    maxlen = 512
    do_lower_case = tokenizer.basic_tokenizer.do_lower_case
    # Tokenize, add special tokens, and truncate to maxlen.
    inputs = [['[CLS]'] + tokenizer.tokenize(text)[:maxlen - 2] + ['[SEP]'] for text in text_list]
    # rematch maps each token back to character offsets in the original text.
    mapping = [rematch(text, tokens, do_lower_case) for text, tokens in zip(text_list, inputs)]
    inputs = [tokenizer.convert_tokens_to_ids(input) for input in inputs]
    input_mask = [[1] * len(input) for input in inputs]
    input_segment = [[0] * len(input) for input in inputs]
    # Pad all sequences in the batch to a common length.
    input_ids = seqs_padding(inputs)
    input_mask = seqs_padding(input_mask)
    input_segment = seqs_padding(input_segment)

    input_ids = np.asarray(input_ids, dtype=np.int32)
    input_mask = np.asarray(input_mask, dtype=np.int32)
    input_segment = np.asarray(input_segment, dtype=np.int32)

    print('input_ids\n', input_ids)
    print('mapping\n', mapping)
    print('input_mask\n', input_mask)
    print('input_segment\n', input_segment)
    print('\n\n')
def test_charlevel():
    do_lower_case = tokenizer.basic_tokenizer.do_lower_case
    maxlen = 512
    # Lower-case the raw text before tokenizing when the tokenizer is uncased.
    if do_lower_case:
        inputs = [['[CLS]'] + tokenizer.tokenize(text.lower())[:maxlen - 2] + ['[SEP]'] for text in text_list]
    else:
        inputs = [['[CLS]'] + tokenizer.tokenize(text)[:maxlen - 2] + ['[SEP]'] for text in text_list]
    inputs = [tokenizer.convert_tokens_to_ids(input) for input in inputs]
    input_mask = [[1] * len(input) for input in inputs]
    input_segment = [[0] * len(input) for input in inputs]
    input_ids = seqs_padding(inputs)
    input_mask = seqs_padding(input_mask)
    input_segment = seqs_padding(input_segment)

    input_ids = np.asarray(input_ids, dtype=np.int32)
    input_mask = np.asarray(input_mask, dtype=np.int32)
    input_segment = np.asarray(input_segment, dtype=np.int32)

    print('input_ids\n', input_ids)
    print('input_mask\n', input_mask)
    print('input_segment\n', input_segment)
    print('\n\n')
# labels = ['标签1', '标签2']
# print(cls_load_labels(labels))
#
# print(load_label_bio(labels))
'''
    ner_crf_decoding(batch_text, id2label, batch_logits, trans=None, batch_mapping=None, with_dict=True)

    NER CRF decoding: parses raw CRF logits, or a sequence of tag ids that has already been decoded.

    batch_text     list of input texts
    id2label       labels, as a list or dict
    batch_logits   BERT prediction output, shape (batch, seq_len, num_tags) for raw logits
                   or (batch, seq_len) for already-decoded tag ids
    trans          optional 2-D CRF transition matrix used during decoding
    batch_mapping  token-to-character mapping sequence
'''
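
# --- Hedged usage sketch (not part of the original README): calling ner_crf_decoding
# --- with randomly generated logits in place of real model output. The BIO tag set
# --- below is hypothetical; only the argument shapes follow the doc comment above.
def demo_ner_crf_decoding():
    id2label = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER']   # hypothetical BIO tag set
    num_tags = len(id2label)
    seq_len = max(len(t) for t in text_list) + 2            # rough allowance for [CLS]/[SEP]
    np.random.seed(42)
    batch_logits = np.random.rand(len(text_list), seq_len, num_tags)  # (batch, seq_len, num_tags)
    # trans=None means no CRF transition matrix is applied; batch_mapping is optional.
    result = ner_crf_decoding(text_list, id2label, batch_logits, trans=None, with_dict=True)
    print(result)
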
'''
    ner_pointer_decoding(batch_text, id2label, batch_logits, threshold=1e-8, coordinates_minus=False, with_dict=True)

    NER pointer decoding.

    batch_text     list of input texts
    id2label       labels, as a list or dict
    batch_logits   shape (batch, num_labels, seq_len, seq_len)
    threshold      score threshold
    coordinates_minus
'''
'''
    ner_pointer_decoding_with_mapping(batch_text, id2label, batch_logits, batch_mapping, threshold=1e-8, coordinates_minus=False, with_dict=True)

    NER pointer decoding with a token-to-character mapping.

    batch_text     list of input texts
    id2label       labels, as a list or dict
    batch_logits   shape (batch, num_labels, seq_len, seq_len)
    batch_mapping  token-to-character mapping sequence (e.g. from rematch)
    threshold      score threshold
    coordinates_minus
'''
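
# --- Hedged usage sketch (not part of the original README): ner_pointer_decoding with
# --- random scores. The entity label names are hypothetical; shapes follow the doc above.
# --- A fairly high threshold keeps the random-score demo output small.
def demo_ner_pointer_decoding():
    id2label = ['PER', 'LOC']                               # hypothetical entity labels
    seq_len = max(len(t) for t in text_list) + 2            # rough allowance for [CLS]/[SEP]
    np.random.seed(42)
    batch_logits = np.random.rand(len(text_list), len(id2label), seq_len, seq_len)
    result = ner_pointer_decoding(text_list, id2label, batch_logits, threshold=0.99)
    print(result)
    # ner_pointer_decoding_with_mapping takes the same arguments plus batch_mapping,
    # i.e. the per-text token-to-character mapping produced by rematch() in test().
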
'''
    cls_softmax_decoding(batch_text, id2label, batch_logits, threshold=None)

    Classification decoding over softmax outputs.

    batch_text     list of input texts
    id2label       labels, as a list or dict
    batch_logits   shape (batch, num_classes)
    threshold      score threshold
'''
'''
    cls_sigmoid_decoding(batch_text, id2label, batch_logits, threshold=0.5)

    Classification decoding over sigmoid outputs.

    batch_text     list of input texts
    id2label       labels, as a list or dict
    batch_logits   shape (batch, num_classes)
    threshold      score threshold
'''
def test_cls_decode():
    num_label = 3
    np.random.seed(123)
    batch_logits = np.random.rand(2, num_label)
    result = cls_softmax_decoding(text_list, ['标签1', '标签2', '标签3'], batch_logits, threshold=None)
    print(result)

    batch_logits = np.random.rand(2, num_label)
    print(batch_logits)
    result = cls_sigmoid_decoding(text_list, ['标签1', '标签2', '标签3'], batch_logits, threshold=0.5)
    print(result)
if __name__ == '__main__':
    test()
    test_charlevel()
    test_cls_decode()
```
Raw data
```json
{
    "_id": null,
    "home_page": "https://github.com/ssbuild/bert_pretty",
    "name": "bert-pretty",
    "maintainer": "",
    "docs_url": null,
    "requires_python": ">=3, <4",
    "maintainer_email": "",
    "keywords": "bert_pretty,bert_pretty,bert text pretty,bert decording",
    "author": "ssbuild",
    "author_email": "9727464@qq.com",
    "download_url": "",
    "platform": "win32_AMD64",
"description": "bert_pretty is a text encoder and result decoder\n\n```py\n# -*- coding:utf-8 -*-\n'''\n bert input_instance encode and result decode\n https://github.com/ssbuild/bert_pretty.git\n'''\nimport numpy as np\n#FullTokenizer is official and you can use your tokenization .\nfrom bert_pretty import FullTokenizer,\\\n text_feature, \\\n text_feature_char_level,\\\n text_feature_word_level,\\\n text_feature_char_level_input_ids_mask, \\\n text_feature_word_level_input_ids_mask, \\\n text_feature_char_level_input_ids_segment, \\\n text_feature_word_level_input_ids_segment, \\\n seqs_padding,rematch\n\n\nfrom bert_pretty.ner import load_label_bioes,load_label_bio,load_labels as ner_load_labels\nfrom bert_pretty.ner import ner_crf_decoding,\\\n ner_pointer_decoding,\\\n ner_pointer_decoding_with_mapping,\\\n ner_pointer_double_decoding,ner_pointer_double_decoding_with_mapping\n\nfrom bert_pretty.cls import cls_softmax_decoding,cls_sigmoid_decoding,load_labels as cls_load_labels\n\n\ntokenizer = FullTokenizer(vocab_file=r'F:\\pretrain\\chinese_L-12_H-768_A-12\\vocab.txt',do_lower_case=True)\ntext_list = [\"\u4f60\u662f\u8c01123aa\\ta\u5602a\",\"\u5602adasd\"]\n\n\n\ndef test():\n maxlen = 512\n do_lower_case = tokenizer.basic_tokenizer.do_lower_case\n inputs = [['[CLS]'] + tokenizer.tokenize(text)[:maxlen - 2] + ['[SEP]'] for text in text_list]\n mapping = [rematch(text, tokens, do_lower_case) for text, tokens in zip(text_list, inputs)]\n inputs = [tokenizer.convert_tokens_to_ids(input) for input in inputs]\n input_mask = [[1] * len(input) for input in inputs]\n input_segment = [[0] * len(input) for input in inputs]\n input_ids = seqs_padding(inputs)\n input_mask = seqs_padding(input_mask)\n input_segment = seqs_padding(input_segment)\n\n input_ids = np.asarray(input_ids, dtype=np.int32)\n input_mask = np.asarray(input_mask, dtype=np.int32)\n input_segment = np.asarray(input_segment, dtype=np.int32)\n\n print('input_ids\\n', input_ids)\n print('mapping\\n',mapping)\n print('input_mask\\n',input_mask)\n print('input_segment\\n',input_segment)\n print('\\n\\n')\n\n\n\ndef test_charlevel():\n do_lower_case = tokenizer.basic_tokenizer.do_lower_case\n maxlen = 512\n if do_lower_case:\n inputs = [['[CLS]'] + tokenizer.tokenize(text.lower())[:maxlen - 2] + ['[SEP]'] for text in text_list]\n else:\n inputs = [['[CLS]'] + tokenizer.tokenize(text)[:maxlen - 2] + ['[SEP]'] for text in text_list]\n inputs = [tokenizer.convert_tokens_to_ids(input) for input in inputs]\n input_mask = [[1] * len(input) for input in inputs]\n input_segment = [[0] * len(input) for input in inputs]\n input_ids = seqs_padding(inputs)\n input_mask = seqs_padding(input_mask)\n input_segment = seqs_padding(input_segment)\n\n input_ids = np.asarray(input_ids, dtype=np.int32)\n input_mask = np.asarray(input_mask, dtype=np.int32)\n input_segment = np.asarray(input_segment, dtype=np.int32)\n\n print('input_ids\\n', input_ids)\n print('input_mask\\n',input_mask)\n print('input_segment\\n',input_segment)\n print('\\n\\n')\n\n# labels = ['\u6807\u7b7e1','\u6807\u7b7e2']\n# print(cls.load_labels(labels))\n#\n# print(ner.load_label_bio(labels))\n\n\n'''\n # def ner_crf_decoding(batch_text, id2label, batch_logits, trans=None,batch_mapping=None,with_dict=True):\n ner crf decode \u89e3\u6790crf\u5e8f\u5217 or \u89e3\u6790 \u5df2\u7ecf\u89e3\u6790\u8fc7\u7684crf\u5e8f\u5217\n\n batch_text input_instance list , \n id2label \u6807\u7b7e list or dict\n batch_logits \u4e3abert \u9884\u6d4b\u7ed3\u679c logits_all (batch,seq_len,num_tags) or 
(batch,seq_len)\n trans \u662f\u5426\u542f\u7528trans\u9884\u6d4b , 2D \n batch_mapping \u6620\u5c04\u5e8f\u5217\n'''\n\n'''\n def ner_pointer_decoding(batch_text, id2label, batch_logits, threshold=1e-8,coordinates_minus=False,with_dict=True)\n\n batch_text text list , \n id2label \u6807\u7b7e list or dict\n batch_logits (batch,num_labels,seq_len,seq_len)\n threshold \u9608\u503c\n coordinates_minus\n'''\n\n'''\n def ner_pointer_decoding_with_mapping(batch_text, id2label, batch_logits, batch_mapping,threshold=1e-8,coordinates_minus=False,with_dict=True)\n\n batch_text text list , \n id2label \u6807\u7b7e list or dict\n batch_logits (batch,num_labels,seq_len,seq_len)\n threshold \u9608\u503c\n coordinates_minus\n'''\n\n\n'''\n cls_softmax_decoding(batch_text, id2label, batch_logits,threshold=None)\n batch_text \u6587\u672clist , \n id2label \u6807\u7b7e list or dict\n batch_logits (batch,num_classes)\n threshold \u9608\u503c\n'''\n\n'''\n cls_sigmoid_decoding(batch_text, id2label, batch_logits,threshold=0.5)\n\n batch_text \u6587\u672clist , \n id2label \u6807\u7b7e list or dict\n batch_logits (batch,num_classes)\n threshold \u9608\u503c\n'''\n\n\ndef test_cls_decode():\n num_label =3\n np.random.seed(123)\n batch_logits = np.random.rand(2,num_label)\n result = cls_softmax_decoding(text_list,['\u6807\u7b7e1','\u6807\u7b7e2','\u6807\u7b7e3'],batch_logits,threshold=None)\n print(result)\n\n\n batch_logits = np.random.rand(2,num_label)\n print(batch_logits)\n result = cls_sigmoid_decoding(text_list,['\u6807\u7b7e1','\u6807\u7b7e2','\u6807\u7b7e3'],batch_logits,threshold=0.5)\n print(result)\n\n\n\n\n\nif __name__ == '__main__':\n test()\n test_charlevel()\n test_cls_decode()\n\n\n\n\n\n\n\n```\n\n\n",
"bugtrack_url": null,
"license": "Apache 2.0",
"summary": "bert_pretty is a text encoder and result decoder",
"version": "0.1.0.post0",
"split_keywords": [
"bert_pretty",
"bert_pretty",
"bert text pretty",
"bert decording"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "c19c9acd5b3443ae4079de0c83bdfe79ee04e2fa9b82634f9092ff9fb85092e6",
"md5": "9008d9f1f3813bf27c6edbb88c574466",
"sha256": "52ff286e28f17f487c3486cced6c6a8db49a901f289370a18ee50cda52aa4b00"
},
"downloads": -1,
"filename": "bert_pretty-0.1.0.post0-py3-none-any.whl",
"has_sig": false,
"md5_digest": "9008d9f1f3813bf27c6edbb88c574466",
"packagetype": "bdist_wheel",
"python_version": "py3",
"requires_python": ">=3, <4",
"size": 31499,
"upload_time": "2023-02-02T08:27:41",
"upload_time_iso_8601": "2023-02-02T08:27:41.145902Z",
"url": "https://files.pythonhosted.org/packages/c1/9c/9acd5b3443ae4079de0c83bdfe79ee04e2fa9b82634f9092ff9fb85092e6/bert_pretty-0.1.0.post0-py3-none-any.whl",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2023-02-02 08:27:41",
"github": true,
"gitlab": false,
"bitbucket": false,
"github_user": "ssbuild",
"github_project": "bert_pretty",
"travis_ci": false,
"coveralls": false,
"github_actions": false,
"lcname": "bert-pretty"
}