File size: 2,756 Bytes
e72d5c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from transformers import pipeline, BartForConditionalGeneration, AutoTokenizer
from evaluate import load
import re

# Location of the fine-tuned BART spelling-correction checkpoint. It is used for
# both the model weights and the tokenizer so the two cannot drift apart.
MODEL_PATH = '/home/antalb/software/spelling/bart-base-spelling-nl-9m-3'

model = BartForConditionalGeneration.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Text-to-text pipeline: input sentence in, corrected sentence out
# (read back via result[0]['generated_text'] in the evaluation loop).
fix_spelling = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

# Metrics loaded through the `evaluate` library.
cer = load("cer")
wer = load("wer")
bleu = load("bleu")
meteor = load("meteor")

# Parallel input files: line i of the errors file is paired with line i of
# the corrections file.
file1name = 'opentaal-annotaties.txt.errors'
file2name = 'opentaal-annotaties.txt.corrections'

# Model outputs and gold corrections accumulated for scoring.
predictions = []
references = []

# Number of line pairs processed so far; drives the periodic progress report.
counter = 0

# Matches every character NOT in the allowed set: ASCII letters, vowels with a
# diaeresis/umlaut (Ë/Ï are allowed alongside ë/ï, matching the ö/Ö, ä/Ä, ü/Ü
# pairs), digits, space and a few punctuation/currency symbols. Newlines and
# tabs are not in the set, so they are matched (and later removed) too.
# NOTE(review): accented vowels such as é/è (e.g. "één") are still stripped;
# confirm this is intended for Dutch text.
clean_chars = re.compile(r'[^A-Za-zëïöäüËÏÖÄÜ,.!?’\'$%€0-9\(\)\- ]')
def cleanup(text):
    """Strip from *text* every character matched by the module-level ``clean_chars``.

    Digits are kept. Line breaks are not in the allowed set, so a trailing
    newline read from a file is removed as well.
    """
    return re.sub(clean_chars, '', text)

def _report_scores():
    """Print CER, WER, BLEU and METEOR over the pairs collected so far.

    Nothing is computed while no pairs have been collected. Scoring empty
    lists is meaningless and may raise inside the metric implementations,
    for example when the very first pair is skipped.
    """
    if not predictions:
        print('(no non-empty prediction/reference pairs to score yet)')
        return
    for label, metric in (
        ('CER    - ', cer),
        ('WER    - ', wer),
        ('BLEU   - ', bleu),
        ('METEOR - ', meteor),
    ):
        score = metric.compute(predictions=predictions, references=references)
        print(label + str(score))


# NOTE(review): the data files are assumed to be UTF-8 (they contain Dutch
# accented text); the explicit encoding avoids depending on the platform locale.
with open(file1name, "r", encoding="utf-8") as file1, \
        open(file2name, "r", encoding="utf-8") as file2:
    for line1, line2 in zip(file1, file2):

        line1 = cleanup(line1)

        # For actual spelling-correction evaluation: run the model.
        intermediate = fix_spelling(line1, max_length=2048)
        line = intermediate[0]['generated_text']

        # For lower-bound testing on the errors (score the uncorrected input):
        # line = line1

        print(line1)
        print(line)

        line2 = cleanup(line2)
        print(line2)

        # NOTE(review): the model output `line` is not passed through cleanup()
        # while the reference `line2` is; confirm this asymmetry is intended.
        # Pairs where either side is empty are excluded from scoring.
        if line and line2:
            predictions.append(line)
            references.append(line2)

        # Progress report every 100 input lines, starting with the first one.
        if counter % 100 == 0:
            print(counter)
            _report_scores()

        counter += 1

# Final scores over all collected pairs.
_report_scores()