import argparse
import os
from difflib import SequenceMatcher

import Levenshtein
import numpy as np
from tqdm import tqdm

from helpers import write_lines, read_parallel_lines, encode_verb_form, \
    apply_reverse_transformation, SEQ_DELIMETERS, START_TOKEN


def perfect_align(t, T, insertions_allowed=0,
                  cost_function=Levenshtein.distance):
    # dp[i, j, k] is a minimal cost of matching first `i` tokens of `t` with
    # first `j` tokens of `T`, after making `k` insertions after last match of
    # token from `t`. In other words t[:i] aligned with T[:j].

    # Initialize with INFINITY (unknown)
    shape = (len(t) + 1, len(T) + 1, insertions_allowed + 1)
    dp = np.ones(shape, dtype=int) * int(1e9)
    come_from = np.ones(shape, dtype=int) * int(1e9)
    come_from_ins = np.ones(shape, dtype=int) * int(1e9)

    dp[0, 0, 0] = 0  # The only known starting point. Nothing matched to nothing.
    for i in range(len(t) + 1):  # Go inclusive
        for j in range(len(T) + 1):  # Go inclusive
            for q in range(insertions_allowed + 1):  # Go inclusive
                if i < len(t):
                    # Given matched sequence of t[:i] and T[:j], match token
                    # t[i] with following tokens T[j:k].
                    for k in range(j, len(T) + 1):
                        transform = \
                            apply_transformation(t[i], ' '.join(T[j:k]))
                        if transform:
                            cost = 0
                        else:
                            cost = cost_function(t[i], ' '.join(T[j:k]))
                        current = dp[i, j, q] + cost

                        if dp[i + 1, k, 0] > current:
                            dp[i + 1, k, 0] = current
                            come_from[i + 1, k, 0] = j
                            come_from_ins[i + 1, k, 0] = q
                if q < insertions_allowed:
                    # Given matched sequence of t[:i] and T[:j], create
                    # insertion with following tokens T[j:k].
                    for k in range(j, len(T) + 1):
                        cost = len(' '.join(T[j:k]))
                        current = dp[i, j, q] + cost

                        if dp[i, k, q + 1] > current:
                            dp[i, k, q + 1] = current
                            come_from[i, k, q + 1] = j
                            come_from_ins[i, k, q + 1] = q

    # Solution is in the dp[len(t), len(T), *]. Backtracking from there.
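    # Backtracking walks the come_from / come_from_ins pointers from the final
    # state back to (0, 0, 0). Each step emits either an INSERT (a span of T
    # produced without consuming a token of t) or a REPLACE_<token> (one token
    # of t matched to a, possibly empty, span of T).
    # Illustrative sketch with hypothetical inputs: aligning
    # t = ['a', 'cat'] with T = ['the', 'cat'] and insertions_allowed=0 yields
    # [['REPLACE_a', ['the'], (0, 1)], ['REPLACE_cat', ['cat'], (1, 2)]].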
    alignment = []
    i = len(t)
    j = len(T)
    q = dp[i, j, :].argmin()
    while i > 0 or q > 0:
        is_insert = (come_from_ins[i, j, q] != q) and (q != 0)
        j, k, q = come_from[i, j, q], j, come_from_ins[i, j, q]
        if not is_insert:
            i -= 1

        if is_insert:
            alignment.append(['INSERT', T[j:k], (i, i)])
        else:
            alignment.append([f'REPLACE_{t[i]}', T[j:k], (i, i + 1)])

    assert j == 0

    return dp[len(t), len(T)].min(), list(reversed(alignment))


def _split(token):
    if not token:
        return []
    parts = token.split()
    return parts or [token]


def apply_merge_transformation(source_tokens, target_words, shift_idx):
    edits = []
    if len(source_tokens) > 1 and len(target_words) == 1:
        # check merge
        transform = check_merge(source_tokens, target_words)
        if transform:
            for i in range(len(source_tokens) - 1):
                edits.append([(shift_idx + i, shift_idx + i + 1), transform])
            return edits

    if len(source_tokens) == len(target_words) == 2:
        # check swap
        transform = check_swap(source_tokens, target_words)
        if transform:
            edits.append([(shift_idx, shift_idx + 1), transform])
    return edits


def is_sent_ok(sent, delimeters=SEQ_DELIMETERS):
    for del_val in delimeters.values():
        if del_val in sent and del_val != delimeters["tokens"]:
            return False
    return True


def check_casetype(source_token, target_token):
    if source_token.lower() != target_token.lower():
        return None
    if source_token.lower() == target_token:
        return "$TRANSFORM_CASE_LOWER"
    elif source_token.capitalize() == target_token:
        return "$TRANSFORM_CASE_CAPITAL"
    elif source_token.upper() == target_token:
        return "$TRANSFORM_CASE_UPPER"
    elif source_token[1:].capitalize() == target_token[1:] \
            and source_token[0] == target_token[0]:
        return "$TRANSFORM_CASE_CAPITAL_1"
    elif source_token[:-1].upper() == target_token[:-1] \
            and source_token[-1] == target_token[-1]:
        return "$TRANSFORM_CASE_UPPER_-1"
    else:
        return None


def check_equal(source_token, target_token):
    if source_token == target_token:
        return "$KEEP"
    else:
        return None


def check_split(source_token, target_tokens):
    if source_token.split("-") == target_tokens:
        return "$TRANSFORM_SPLIT_HYPHEN"
    else:
        return None


def check_merge(source_tokens, target_tokens):
    if "".join(source_tokens) == "".join(target_tokens):
        return "$MERGE_SPACE"
    elif "-".join(source_tokens) == "-".join(target_tokens):
        return "$MERGE_HYPHEN"
    else:
        return None


def check_swap(source_tokens, target_tokens):
    if source_tokens == [x for x in reversed(target_tokens)]:
        return "$MERGE_SWAP"
    else:
        return None


def check_plural(source_token, target_token):
    if source_token.endswith("s") and source_token[:-1] == target_token:
        return "$TRANSFORM_AGREEMENT_SINGULAR"
    elif target_token.endswith("s") and source_token == target_token[:-1]:
        return "$TRANSFORM_AGREEMENT_PLURAL"
    else:
        return None


def check_verb(source_token, target_token):
    encoding = encode_verb_form(source_token, target_token)
    if encoding:
        return f"$TRANSFORM_VERB_{encoding}"
    else:
        return None


def apply_transformation(source_token, target_token):
    target_tokens = target_token.split()
    if len(target_tokens) > 1:
        # check split
        transform = check_split(source_token, target_tokens)
        if transform:
            return transform
    checks = [check_equal, check_casetype, check_verb, check_plural]
    for check in checks:
        transform = check(source_token, target_token)
        if transform:
            return transform
    return None
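# Illustrative examples for apply_transformation (hypothetical tokens; exact
# outputs depend on the transforms implemented in helpers.py, in particular
# encode_verb_form):
#   apply_transformation("hello", "Hello")            -> "$TRANSFORM_CASE_CAPITAL"
#   apply_transformation("dog", "dogs")               -> "$TRANSFORM_AGREEMENT_PLURAL"
#   apply_transformation("well-known", "well known")  -> "$TRANSFORM_SPLIT_HYPHEN"
#   apply_transformation("cat", "cat")                -> "$KEEP"
#   apply_transformation("cat", "dog")                -> None (callers then fall
#                                                       back to Levenshtein cost)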
def align_sequences(source_sent, target_sent):
    # check if sent is OK
    if not is_sent_ok(source_sent) or not is_sent_ok(target_sent):
        return None
    source_tokens = source_sent.split()
    target_tokens = target_sent.split()
    matcher = SequenceMatcher(None, source_tokens, target_tokens)
    diffs = list(matcher.get_opcodes())
    all_edits = []
    for diff in diffs:
        tag, i1, i2, j1, j2 = diff
        source_part = _split(" ".join(source_tokens[i1:i2]))
        target_part = _split(" ".join(target_tokens[j1:j2]))
        if tag == 'equal':
            continue
        elif tag == 'delete':
            # delete all words separately
            for j in range(i2 - i1):
                edit = [(i1 + j, i1 + j + 1), '$DELETE']
                all_edits.append(edit)
        elif tag == 'insert':
            # append to the previous word
            for target_token in target_part:
                edit = ((i1 - 1, i1), f"$APPEND_{target_token}")
                all_edits.append(edit)
        else:
            # check merge first of all
            edits = apply_merge_transformation(source_part, target_part,
                                               shift_idx=i1)
            if edits:
                all_edits.extend(edits)
                continue

            # normalize alignments if needed (make them singletons)
            _, alignments = perfect_align(source_part, target_part,
                                          insertions_allowed=0)
            for alignment in alignments:
                new_shift = alignment[2][0]
                edits = convert_alignments_into_edits(alignment,
                                                      shift_idx=i1 + new_shift)
                all_edits.extend(edits)

    # get labels
    labels = convert_edits_into_labels(source_tokens, all_edits)
    # match tags to source tokens
    sent_with_tags = add_labels_to_the_tokens(source_tokens, labels)
    return sent_with_tags


def convert_edits_into_labels(source_tokens, all_edits):
    # make sure that edits are flat
    flat_edits = []
    for edit in all_edits:
        (start, end), edit_operations = edit
        if isinstance(edit_operations, list):
            for operation in edit_operations:
                new_edit = [(start, end), operation]
                flat_edits.append(new_edit)
        elif isinstance(edit_operations, str):
            flat_edits.append(edit)
        else:
            raise Exception("Unknown operation type")
    all_edits = flat_edits[:]
    labels = []
    total_labels = len(source_tokens) + 1
    if not all_edits:
        labels = [["$KEEP"] for x in range(total_labels)]
    else:
        for i in range(total_labels):
            edit_operations = [x[1] for x in all_edits
                               if x[0][0] == i - 1 and x[0][1] == i]
            if not edit_operations:
                labels.append(["$KEEP"])
            else:
                labels.append(edit_operations)
    return labels


def convert_alignments_into_edits(alignment, shift_idx):
    edits = []
    action, target_tokens, new_idx = alignment
    source_token = action.replace("REPLACE_", "")

    # check if delete
    if not target_tokens:
        edit = [(shift_idx, 1 + shift_idx), "$DELETE"]
        return [edit]

    # check splits
    for i in range(1, len(target_tokens)):
        target_token = " ".join(target_tokens[:i + 1])
        transform = apply_transformation(source_token, target_token)
        if transform:
            edit = [(shift_idx, shift_idx + 1), transform]
            edits.append(edit)
            target_tokens = target_tokens[i + 1:]
            for target in target_tokens:
                edits.append([(shift_idx, shift_idx + 1), f"$APPEND_{target}"])
            return edits

    transform_costs = []
    transforms = []
    for target_token in target_tokens:
        transform = apply_transformation(source_token, target_token)
        if transform:
            cost = 0
            transforms.append(transform)
        else:
            cost = Levenshtein.distance(source_token, target_token)
            transforms.append(None)
        transform_costs.append(cost)
    min_cost_idx = transform_costs.index(min(transform_costs))
    # append to the previous word
    for i in range(0, min_cost_idx):
        target = target_tokens[i]
        edit = [(shift_idx - 1, shift_idx), f"$APPEND_{target}"]
        edits.append(edit)
    # replace/transform target word
    transform = transforms[min_cost_idx]
    target = transform if transform is not None \
        else f"$REPLACE_{target_tokens[min_cost_idx]}"
    edit = [(shift_idx, 1 + shift_idx), target]
    edits.append(edit)
    # append to this word
    for i in range(min_cost_idx + 1, len(target_tokens)):
        target = target_tokens[i]
        edit = [(shift_idx, 1 + shift_idx), f"$APPEND_{target}"]
        edits.append(edit)
    return edits
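# Illustrative example for convert_edits_into_labels (hypothetical tokens):
#   source_tokens = ["he", "go", "home"]
#   all_edits = [[(1, 2), "$REPLACE_goes"]]
# returns one label list per position, where index 0 belongs to the $START
# token and index i to source token i - 1:
#   [["$KEEP"], ["$KEEP"], ["$REPLACE_goes"], ["$KEEP"]]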
def add_labels_to_the_tokens(source_tokens, labels,
                             delimeters=SEQ_DELIMETERS):
    tokens_with_all_tags = []
    source_tokens_with_start = [START_TOKEN] + source_tokens
    for token, label_list in zip(source_tokens_with_start, labels):
        all_tags = delimeters['operations'].join(label_list)
        comb_record = token + delimeters['labels'] + all_tags
        tokens_with_all_tags.append(comb_record)
    return delimeters['tokens'].join(tokens_with_all_tags)


def convert_data_from_raw_files(source_file, target_file, output_file,
                                chunk_size):
    tagged = []
    source_data, target_data = read_parallel_lines(source_file, target_file)
    print(f"The size of raw dataset is {len(source_data)}")
    cnt_total, cnt_all, cnt_tp = 0, 0, 0
    for source_sent, target_sent in tqdm(zip(source_data, target_data)):
        try:
            aligned_sent = align_sequences(source_sent, target_sent)
        except Exception:
            aligned_sent = align_sequences(source_sent, target_sent)
        if aligned_sent is None:
            # skip pairs that contain delimiter strings (see is_sent_ok)
            continue
        if source_sent != target_sent:
            cnt_tp += 1
        alignments = [aligned_sent]
        cnt_all += len(alignments)
        try:
            check_sent = convert_tagged_line(aligned_sent)
        except Exception:
            # debug mode
            aligned_sent = align_sequences(source_sent, target_sent)
            check_sent = convert_tagged_line(aligned_sent)

        if "".join(check_sent.split()) != "".join(target_sent.split()):
            # do it again for debugging
            aligned_sent = align_sequences(source_sent, target_sent)
            check_sent = convert_tagged_line(aligned_sent)
            print(f"Incorrect pair: \n{target_sent}\n{check_sent}")
            continue
        if alignments:
            cnt_total += len(alignments)
            tagged.extend(alignments)
        if len(tagged) > chunk_size:
            write_lines(output_file, tagged, mode='a')
            tagged = []

    print(f"Overall extracted {cnt_total}. "
          f"Original TP {cnt_tp}."
          f" Original TN {cnt_all - cnt_tp}")
    if tagged:
        write_lines(output_file, tagged, 'a')


def convert_labels_into_edits(labels):
    all_edits = []
    for i, label_list in enumerate(labels):
        if label_list == ["$KEEP"]:
            continue
        else:
            edit = [(i - 1, i), label_list]
            all_edits.append(edit)
    return all_edits


def get_target_sent_by_levels(source_tokens, labels):
    relevant_edits = convert_labels_into_edits(labels)
    target_tokens = source_tokens[:]
    leveled_target_tokens = {}
    if not relevant_edits:
        target_sentence = " ".join(target_tokens)
        return leveled_target_tokens, target_sentence
    max_level = max([len(x[1]) for x in relevant_edits])
    for level in range(max_level):
        rest_edits = []
        shift_idx = 0
        for edits in relevant_edits:
            (start, end), label_list = edits
            label = label_list[0]
            target_pos = start + shift_idx
            source_token = target_tokens[target_pos] \
                if target_pos >= 0 else START_TOKEN
            if label == "$DELETE":
                del target_tokens[target_pos]
                shift_idx -= 1
            elif label.startswith("$APPEND_"):
                word = label.replace("$APPEND_", "")
                target_tokens[target_pos + 1: target_pos + 1] = [word]
                shift_idx += 1
            elif label.startswith("$REPLACE_"):
                word = label.replace("$REPLACE_", "")
                target_tokens[target_pos] = word
            elif label.startswith("$TRANSFORM"):
                word = apply_reverse_transformation(source_token, label)
                if word is None:
                    word = source_token
                target_tokens[target_pos] = word
            elif label.startswith("$MERGE_"):
                # apply merge only on the last stage
                if level == (max_level - 1):
                    target_tokens[target_pos + 1: target_pos + 1] = [label]
                    shift_idx += 1
                else:
                    rest_edit = [(start + shift_idx, end + shift_idx), [label]]
                    rest_edits.append(rest_edit)
            rest_labels = label_list[1:]
            if rest_labels:
                rest_edit = [(start + shift_idx, end + shift_idx), rest_labels]
                rest_edits.append(rest_edit)

        leveled_tokens = target_tokens[:]
        # update next step
        relevant_edits = rest_edits[:]
        if level == (max_level - 1):
            leveled_tokens = replace_merge_transforms(leveled_tokens)
        leveled_labels = convert_edits_into_labels(leveled_tokens,
                                                   relevant_edits)
        leveled_target_tokens[level + 1] = {"tokens": leveled_tokens,
                                            "labels": leveled_labels}

    target_sentence = " ".join(leveled_target_tokens[max_level]["tokens"])
    return leveled_target_tokens, target_sentence
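# Illustrative example for get_target_sent_by_levels (hypothetical tokens and
# labels): with
#   source_tokens = ["he", "go"]
#   labels = [["$KEEP"], ["$KEEP"], ["$REPLACE_goes", "$APPEND_home"]]
# level 1 applies "$REPLACE_goes" and level 2 applies "$APPEND_home", so the
# function returns the final sentence "he goes home" together with the
# per-level token/label snapshots.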
def replace_merge_transforms(tokens):
    if all(not x.startswith("$MERGE_") for x in tokens):
        return tokens
    target_tokens = tokens[:]
    # use a range, not a tuple, so "i in allowed_range" really means
    # 1 <= i <= len(tokens) - 2, i.e. both swap neighbours exist
    allowed_range = range(1, len(tokens) - 1)
    for i in range(len(tokens)):
        target_token = tokens[i]
        if target_token.startswith("$MERGE"):
            if target_token.startswith("$MERGE_SWAP") and i in allowed_range:
                target_tokens[i - 1] = tokens[i + 1]
                target_tokens[i + 1] = tokens[i - 1]
                target_tokens[i: i + 1] = []
    target_line = " ".join(target_tokens)
    target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
    target_line = target_line.replace(" $MERGE_SPACE ", "")
    return target_line.split()


def convert_tagged_line(line, delimeters=SEQ_DELIMETERS):
    label_del = delimeters['labels']
    source_tokens = [x.split(label_del)[0]
                     for x in line.split(delimeters['tokens'])][1:]
    labels = [x.split(label_del)[1].split(delimeters['operations'])
              for x in line.split(delimeters['tokens'])]
    assert len(source_tokens) + 1 == len(labels)
    levels_dict, target_line = get_target_sent_by_levels(source_tokens, labels)
    return target_line


def main(args):
    convert_data_from_raw_files(args.source, args.target, args.output_file,
                                args.chunk_size)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source',
                        help='Path to the source file',
                        required=True)
    parser.add_argument('-t', '--target',
                        help='Path to the target file',
                        required=True)
    parser.add_argument('-o', '--output_file',
                        help='Path to the output file',
                        required=True)
    parser.add_argument('--chunk_size',
                        type=int,
                        help='Dump each chunk size.',
                        default=1000000)
    args = parser.parse_args()
    main(args)
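# Example invocation (script and file names are placeholders):
#   python preprocess_data.py -s source.txt -t target.txt -o output.tagged
# Each output line pairs every source token (plus the leading $START token)
# with its edit labels, as built by add_labels_to_the_tokens above.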