"""Evaluate a BART spelling-correction model on paired error/correction files.

The script reads two line-aligned text files: one with misspelled sentences
and one with their reference corrections.  Each misspelled line is cleaned,
passed through the ``text2text-generation`` pipeline, and the generated text
is collected together with the cleaned reference.  CER, WER, BLEU and METEOR
are printed every 100 input lines and once more after the last line.
"""
import re

from evaluate import load
from transformers import AutoTokenizer, BartForConditionalGeneration, pipeline

# Local checkpoint used for both the model weights and the tokenizer.
MODEL_PATH = '/home/antalb/software/spelling/bart-base-spelling-nl-9m-3'

model = BartForConditionalGeneration.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
fix_spelling = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

cer = load("cer")
wer = load("wer")
bleu = load("bleu")
meteor = load("meteor")

# Label/metric pairs in the order in which the scores are printed.
METRICS = (('CER', cer), ('WER', wer), ('BLEU', bleu), ('METEOR', meteor))

file1name = 'opentaal-annotaties.txt.errors'
file2name = 'opentaal-annotaties.txt.corrections'

predictions = []
references = []
counter = 0

# Whitelist of characters kept by cleanup(); everything else is deleted.
# Digits are part of the whitelist (``0-9``), so numbers survive cleaning.
# NOTE(review): accented letters outside this set (e.g. "é") are stripped as
# well -- confirm that this is intended for the evaluation data.
clean_chars = re.compile(r'[^A-Za-zëïöäüÖÄÜ,.!?’\'$%€0-9\(\)\- ]', re.MULTILINE)


def cleanup(text):
    """Return *text* with every character outside the whitelist removed.

    The whitelist does not contain the newline character, so the trailing
    line break of a line read from a file is removed as well.
    """
    return clean_chars.sub('', text)


def report_scores(predictions, references):
    """Print every metric in METRICS for the given prediction/reference lists.

    Nothing is computed while the lists are still empty: the error-rate
    metrics are ratios over the reference text and are not meaningful (and
    may fail) when there is no input at all.
    """
    if not predictions:
        print('No non-empty prediction/reference pairs collected; '
              'skipping metrics.')
        return
    for label, metric in METRICS:
        score = metric.compute(predictions=predictions, references=references)
        print(f'{label} - {score}')


# The data contains non-ASCII characters, so the encoding is stated explicitly
# instead of relying on the platform default.
with open(file1name, "r", encoding="utf-8") as file1, \
        open(file2name, "r", encoding="utf-8") as file2:
    # zip() stops at the end of the shorter file; surplus lines in the other
    # file are ignored.
    for counter, (line1, line2) in enumerate(zip(file1, file2)):
        line1 = cleanup(line1)
        line2 = cleanup(line2)

        # For actual spelling correction evaluation:
        intermediate = fix_spelling(line1, max_length=2048)
        line = intermediate[0]['generated_text']
        # For lower-bound testing on the uncorrected errors, use instead:
        # line = line1

        print(line1)
        print(line)
        print(line2)

        # Only pairs where both the prediction and the reference are
        # non-empty take part in the scoring.
        if line and line2:
            predictions.append(line)
            references.append(line2)

        # Intermediate report every 100 input lines (including the first).
        if counter % 100 == 0:
            print(counter)
            report_scores(predictions, references)

# Final report over all collected pairs.
report_scores(predictions, references)