微调LayoutLM v3进行票据数据的处理和内容识别
时间:2022-10-27 18:30:00
文档理解是文档处理和提取中最重要的步骤。这是一个从非结构化或半结构化文档中提取信息并将其转换为结构化形式的过程。提取后的结构化表示,它可以支持信息检索、汇总、分类等各种下游任务。理解文档的方法有很多不同,但它们都有一个共同的目标:创建文档内容的结构化表示,以便进一步处理。
发票、收款或合同等半结构化文件,Microsoft的Layoutlm模型工作得很好。
在这篇文章中,我们将是微软的最新Layoutlm V3.微调其性能和性能Layoutlm V比较2模型。
LayoutLM v3
LayoutLM v与前两个版本相比,多模态的主要优势是transformer 结构,它以统一的方式嵌入文本和图像。文档图像不依赖CNN处理,但将图像补丁块表示为线性投影,然后线性嵌入与文本标记对齐,如下图所示。该方法的主要优点是减少所需的参数和整体计算量。
论文作者说:LayoutLMv3不仅在以文本为中心的任务(包括表格理解、票据理解和文档视觉问题回答)中实现了最先进的性能,而且在以图像为中心的任务(如文档图像分类和文档布局分析)中实现了最先进的性能。
微调LayoutLM v3
我们将使用相同的220个带注释的发票数据集进行微调layoutLM v3模型。为了标记,我用了UBIAI支持文本注释工具的文本注释工具OCR解析,原生PDF/图像注释,可以使用LayoutLM导出模型兼容的格式,可以节省后处理工作。
从UBIAI在导出注释文件后,我们将使用谷歌colab进行模型训练和推理。最后提供源代码地址,这里简要介绍工作流程
第一步是打开colab,安装相应的库layoutLMv不同的是,我们没有使用它detectron 2包微调实体提取的模型。但是,需要使用布局测试(不在本文讨论范围内)detectorn 2包:
from google.colab import drive drive.mount('/content/drive') !pip install -q git https://github.com/huggingface/transformers.git !pip install -q git https://github.com/huggingface/datasets.git "dill<0.3.5" seqeval
接下来,使用preprocess.py脚本来处理UBIAI导出的ZIP文件:
! rm -r layoutlmv3FineTuning ! git clone -b main https://github.com/UBIAI/layoutlmv3FineTuning.git #!/bin/bash IOB_DATA_PATH = "/content/drive/MyDrive/LayoutLM_data/Invoice_Project_mkWSi4Z.zip" ! cd /content/ ! rm -r data! mkdir data ! cp "$IOB_DATA_PATH" data/dataset.zip ! cd data && unzip -q dataset && rm dataset.zip ! cd ..
运行预处理脚本:
#!/bin/bash TEST_SIZE = 0.33 DATA_OUTPUT_PATH = "/content/" ! python3 layoutlmv3FineTuning/preprocess.py --valid_size $TEST_SIZE --output_path $DATA_OUTPUT_PATH
加载处理后的数据集:
from datasets import load_metric from transformers import TrainingArguments, Trainer from transformers import LayoutLMv3ForTokenClassification,AutoProcessor from transformers.data.data_collator import default_data_collator import torch from datasets import load_from_disk train_dataset = load_from_disk(f'/content/train_split') eval_dataset = load_from_disk(f'/content/eval_split') label_list = train_dataset.features["labels"].feature.names num_labels = len(label_list) label2id, id2label = dict(), dict() for i, label in enumerate(label_list): label2id[label] = i id2label[i] = label
评价指标的定义:
metric = load_metric("seqeval") import numpy as np return_entity_level_metrics = False def compute_metrics(p): predictions, labels = p predictions = np.argmax(predictions, axis=2) true_predictions = [ [label_list[p] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels) ] true_labels = [ [label_list[l] for (p, l) in zip(prediction, label) if l != -100] for prediction, label in zip(predictions, labels) ] results = metric.compute(predictions=true_predictions, references=true_labels,zero_division='0') if return_entity_level_metrics: # Unpack nested dictionaries final_results = {} for key, value in results.items(): if isinstance(value, dict): for n, v in value.items(): final_results[f"{key}_{n}"] = v else: final_results[key] = value return final_results else: return { "precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], "accuracy": results["overall_accuracy"], }
训练和评估模型:
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", id2label=id2label, label2id=label2id) processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) NUM_TRAIN_EPOCHS = 50 PER_DEVICE_TRAIN_BATCH_SIZE = 1 PER_DEVICE_EVAL_BATCH_SIZE = 1 LEARNING_RATE = 4e-5 training_args = TrainingArguments(output_dir="test", # max_steps=1500, num_train_epochs=NUM_TRAIN_EPOCHS, logging_strategy="epoch", save_total_limit=1, per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
learning_rate=LEARNING_RATE,
evaluation_strategy="epoch",
save_strategy="epoch",
# eval_steps=100,
load_best_model_at_end=True,
metric_for_best_model="f1")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=processor,
data_collator=default_data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
trainer.evaluate()
训练完成后,对测试数据集进行评估。以下为评价后的模型得分:
{'epoch': 50.0,
'eval_accuracy': 0.9521988527724665,
'eval_f1': 0.6913439635535308,
'eval_loss': 0.41490793228149414,
'eval_precision': 0.6362683438155137,
'eval_recall': 0.756857855361596,
'eval_runtime': 9.7501,
'eval_samples_per_second': 9.846,
'eval_steps_per_second': 9.846}
该模型f1得分为0.69,召回率为0.75,准确率为0.63。
让我们在不属于训练数据集的新发票上运行模型。
使用LayoutLM v3进行预测
为了进行预测,我们将使用Tesseract对发票进行OCR,并将信息输入到训练好的模型中进行预测。为了简化这一过程,我创建了一个自定义脚本,其中只包含几行代码,允许接收OCR输出并使用模型运行预测。
第一步,让我们导入一些重要的库并加载模型:
from google.colab import drive
drive.mount('/content/drive')
!pip install -q git+https://github.com/huggingface/transformers.git
! sudo apt install tesseract-ocr
! sudo apt install libtesseract-dev
! pip install pytesseract
! git clone https://github.com/salmenhsairi/layoutlmv3FineTuning.git
import os
import torch
import warnings
from PIL import Image
warnings.filterwarnings('ignore')
os.makedirs('/content/images',exist_ok=True)
for image in os.listdir():
try:
img = Image.open(f'{os.curdir}/{image}')
os.system(f'mv "{image}" "images/{image}"')
except:
pass
model_path = "/content/drive/MyDrive/LayoutLM_data/layoutlmv3.pth" # path to Layoutlmv3 model
imag_path = "/content/images" # images folder
if model_path.endswith('.pth'):
layoutlmv3_model = torch.load(model_path)
model_path = '/content/pre_trained_layoutlmv3'
layoutlmv3_model.save_pretrained(model_path)
使用模型进行预测
我们使用了220张带标注的发票进行训练,该模型能够正确预测卖方名称、日期、发票编号和总价(TTC)!
如果仔细观察,就会发现把笔记本电脑总价当作发票总价的做法是错误的(上图)。这并不奇怪,我们可以用更多的训练数据来解决这个问题。
比较LayoutLM v2和LayoutLM v3
除了计算量更少之外,layoutLM V3是否比它的v2版本提供了性能提升?为了回答这个问题,我们比较了相同发票的两个模型输出。下面相同数据下layoutLM v2输出:
v3模型能够正确地检测到大多数的项目,而v2不能检测invoice_ID、发票number_ID和Total_ID
v2型号错误地将Total price $ 1445 .00标为MONTANT_HT(法语中是税前总价),而v3正确地预测了总价。
两个模型都错误地将笔记本电脑的价格标为Total。
基于这个例子,layoutLM V3显示了更好的整体性能,但我们需要在更大的数据集上进行测试。
总结
本文中展示了如何在发票数据提取的特定用例上微调layoutLM V3。然后将其性能与layoutLM V2进行了比较,发现它的性能略有提高,但仍需要在更大的数据集上验证。
基于性能和计算收益,我强烈建议使用新的layoutLM v3。
本文的一些有用的资料:
https://avoid.overfit.cn/post/be399d8f17f542929155b8b2481ecaaa
作者:Walid Amamou