"""Group bracket-annotated mentions in model output into clusters.

Reads a CSV with ``input`` and ``model_output`` columns, parses every
``[text: N]`` span in ``model_output``, groups the span texts by the integer
tag ``N`` and writes a CSV with an extra ``model_output_clusters`` column
holding ``str()`` of that ``{N: [texts]}`` mapping.
"""
import csv
import os
import re

import pandas as pd

# A cluster tag inside a span, e.g. the ": 12" in "[the dog: 12]".
_CLUSTER_TAG = re.compile(r": ([0-9]{1,3})")
# The same tag, but only when it ends the string (used for nested spans).
_TRAILING_TAG = re.compile(r": [0-9]{1,3}\Z")

headers = ["input", "model_output", "model_output_clusters"]


def extract_paren(annotation):
    """Return every ``[...]`` span in *annotation*, in opening-bracket order.

    A span nested inside another is returned on its own, and is also folded
    into its parent's text with its brackets and trailing ``": N"`` tag
    removed, e.g. ``"[the [dog: 2]'s owner: 1]"`` yields
    ``"[the dog's owner: 1]|0"`` and ``"[dog: 2]|1"``.

    If a span carries a ``": N"`` tag, ``"|K"`` is appended, where ``K`` is
    the number of spaces before its opening bracket minus the number of
    ``": N"`` tags before it -- i.e. its whitespace-delimited word position
    with the tags' own spaces discounted. Spans whose bracket is never closed
    are dropped.

    Args:
        annotation: annotated text.

    Returns:
        list[str]: the extracted spans (possibly empty).
    """
    ents = []
    for start, char in enumerate(annotation):
        if char != "[":
            continue
        ent = "["
        depth = 0
        for inner in annotation[start + 1:]:
            if inner == "[":
                depth += 1
            elif inner == "]":
                if depth > 0:
                    depth -= 1
                    # Remove only the nested span's own tag. The old
                    # ent[:-3] assumed a 1-digit id and also ate real text
                    # when the nested span had no tag.
                    ent = _TRAILING_TAG.sub("", ent)
                else:
                    ent += "]"
                    if _CLUSTER_TAG.search(ent):
                        prefix = annotation[:start]
                        # Each ": N" tag contributes one space that is not a
                        # word boundary, so discount it.
                        word_index = prefix.count(" ") - len(_CLUSTER_TAG.findall(prefix))
                        ent += "|" + str(word_index)
                    ents.append(ent)
                    break
            else:
                ent += inner
    return ents


def create_clusters(ents):
    """Group spans produced by :func:`extract_paren` by their cluster tag.

    Each span's brackets and its ``": N"`` tag text are removed; any ``"|K"``
    suffix is kept. Spans without a tag are reported on stdout and skipped.

    Args:
        ents: iterable of span strings such as ``"[John: 1]|0"``.

    Returns:
        dict[int, list[str]]: cluster id -> cleaned span texts, in input order.
    """
    clusters = {}
    for e in ents:
        tag = _CLUSTER_TAG.search(e)
        if tag:
            clean_e = e.replace("[", "").replace("]", "").replace(tag.group(), "")
            # setdefault keeps a plain dict so str(clusters) is unchanged.
            clusters.setdefault(int(tag.group(1)), []).append(clean_e)
        else:
            print("OH NO:", e)
            print()
    return clusters


def main(in_path="results.csv", out_path="cluster_results.csv"):
    """Read *in_path*, add a clusters column, and write *out_path*.

    Rows whose ``model_output`` is not a string (e.g. NaN for an empty
    cell) are skipped, as before.
    """
    df = pd.read_csv(in_path)
    rows = []
    for text, annotation in zip(df["input"], df["model_output"]):
        if not isinstance(annotation, str):
            continue
        ann_ents = extract_paren(annotation)
        ann_clusters = create_clusters(ann_ents) if ann_ents else {}
        rows.append([text, annotation, str(ann_clusters)])

    # newline="" is required by the csv module to avoid blank rows on
    # Windows; `with` guarantees the file is closed even on error.
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(headers)
        writer.writerows(rows)


if __name__ == "__main__":
    main()