# Uploaded by antalvdb ("Upload 3 files", commit e72d5c9; 2.76 kB).
from transformers import pipeline, BartForConditionalGeneration, AutoTokenizer
from evaluate import load
import re
# Local fine-tuned BART checkpoint; model weights and tokenizer are read from
# the same directory.
model_dir = '/home/antalb/software/spelling/bart-base-spelling-nl-9m-3'
model = BartForConditionalGeneration.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Sequence-to-sequence generation pipeline: misspelled text in, corrected text out.
fix_spelling = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

# Evaluation metrics from the `evaluate` library, loaded in this order:
# character error rate, word error rate, BLEU, METEOR.
cer, wer, bleu, meteor = [load(metric_name) for metric_name in ("cer", "wer", "bleu", "meteor")]
# Input (misspelled) and reference (corrected) text files.
file1name, file2name = (
    'opentaal-annotaties.txt.errors',
    'opentaal-annotaties.txt.corrections',
)

# Accumulated model outputs and gold corrections that are fed to the metrics.
predictions, references = [], []

# Line-pair counter used for progress reporting.
counter = 0

# Character whitelist: anything NOT matched by the class below is deleted by
# cleanup(). Kept: ASCII letters, lowercase ë/ï, ö/ä/ü in both cases, digits,
# , . ! ? ’ ' $ % €, parentheses, hyphen and space.
# NOTE(review): é/è/ê, uppercase Ë/Ï and newlines are not whitelisted and are
# therefore stripped -- confirm this is intended for the Dutch data.
# re.MULTILINE has no effect here (the pattern has no ^/$ anchors).
clean_chars = re.compile(r'[^A-Za-zëïöäüÖÄÜ,.!?’\'$%€0-9\(\)\- ]', re.MULTILINE)
def cleanup(text):
    """Return *text* with every character outside the ``clean_chars`` whitelist removed.

    Newlines are not whitelisted, so line breaks (including the trailing one
    of a line read from a file) are removed as well.
    """
    return clean_chars.sub('', text)
def _report_scores(predictions, references):
    """Print CER, WER, BLEU and METEOR over all (prediction, reference) pairs so far.

    Scores are recomputed over the whole accumulated corpus on every call.
    Nothing is computed when no pairs have been collected yet: an empty corpus
    has no defined score, and computing CER/WER on empty lists fails.
    """
    if not predictions:
        print('No non-empty line pairs to score yet.')
        return
    for label, metric in (('CER', cer), ('WER', wer), ('BLEU', bleu), ('METEOR', meteor)):
        score = metric.compute(predictions=predictions, references=references)
        print(label + ' - ' + str(score))


# Run the corrector over the error file and compare against the corrections
# file, pairing the two files line by line.
# NOTE: the files are read as UTF-8 because the cleanup whitelist includes
# non-ASCII characters (ë, €, ’); the platform default encoding may differ.
# zip() silently stops at the shorter file, so both files are expected to
# have the same number of lines.
with open(file1name, "r", encoding="utf-8") as file1, open(
    file2name, "r", encoding="utf-8"
) as file2:
    for counter, (line1, line2) in enumerate(zip(file1, file2)):
        source = cleanup(line1)
        # Actual spelling-correction evaluation:
        prediction = fix_spelling(source, max_length=2048)[0]['generated_text']
        # For lower-bound testing on the uncorrected errors, use instead:
        # prediction = source
        print(source)
        print(prediction)
        reference = cleanup(line2)
        print(reference)
        # NOTE(review): pairs where either side is empty are skipped (presumably
        # so the metrics never see an empty string). This also drops cases
        # where the model returned nothing, which may flatter the scores.
        if prediction and reference:
            predictions.append(prediction)
            references.append(reference)
        if counter % 100 == 0:
            print(counter)
            _report_scores(predictions, references)

# Final scores over the whole corpus.
_report_scores(predictions, references)