hugpv committed on
Commit
8e5930e
1 Parent(s): 0e09020

initial commit via hf

Browse files
Files changed (27) hide show
  1. .gitignore +2 -0
  2. algo_cfgs_all.json +51 -0
  3. analysis_funcs.py +355 -0
  4. app.py +1453 -0
  5. classic_correction_algos.py +546 -0
  6. eyekit_measures.py +178 -0
  7. loss_functions.py +179 -0
  8. models.py +897 -0
  9. models/BERT_20240104-223349_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00430.ckpt +3 -0
  10. models/BERT_20240104-233803_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00719.ckpt +3 -0
  11. models/BERT_20240107-152040_loop_restrict_sim_data_to_4000_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00515.ckpt +3 -0
  12. models/BERT_20240108-000344_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00706.ckpt +3 -0
  13. models/BERT_20240108-011230_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00560.ckpt +3 -0
  14. models/BERT_20240109-090419_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00518.ckpt +3 -0
  15. models/BERT_20240122-183729_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00523.ckpt +3 -0
  16. models/BERT_20240122-194041_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00462.ckpt +3 -0
  17. models/BERT_fin_exp_20240104-223349.yaml +100 -0
  18. models/BERT_fin_exp_20240104-233803.yaml +100 -0
  19. models/BERT_fin_exp_20240107-152040.yaml +100 -0
  20. models/BERT_fin_exp_20240108-000344.yaml +100 -0
  21. models/BERT_fin_exp_20240108-011230.yaml +100 -0
  22. models/BERT_fin_exp_20240109-090419.yaml +100 -0
  23. models/BERT_fin_exp_20240122-183729.yaml +102 -0
  24. models/BERT_fin_exp_20240122-194041.yaml +102 -0
  25. requirements.txt +25 -0
  26. run_in_notebook.ipynb +0 -0
  27. utils.py +2016 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ .gitignore
algo_cfgs_all.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compare": {
3
+ "x_thresh": 512,
4
+ "n_nearest_lines": 3
5
+ },
6
+ "attach": {},
7
+ "segment": {},
8
+ "split": {},
9
+ "stretch": {
10
+ "stretch_bounds": [
11
+ 0.9,
12
+ 1.1
13
+ ],
14
+ "offset_bounds": [
15
+ -50,
16
+ 50
17
+ ]
18
+ },
19
+ "slice": {
20
+ "x_thresh": 192,
21
+ "y_thresh": 32,
22
+ "w_thresh": 32,
23
+ "n_thresh": 90
24
+
25
+ },
26
+ "warp": {},
27
+ "chain": {
28
+ "x_thresh": 192,
29
+ "y_thresh": 55
30
+ },
31
+ "regress": {
32
+ "slope_bounds": [
33
+ -0.1,
34
+ 0.1
35
+ ],
36
+ "offset_bounds": [
37
+ -50,
38
+ 50
39
+ ],
40
+ "std_bounds": [
41
+ 1,
42
+ 20
43
+ ]
44
+ },
45
+ "cluster": {},
46
+ "merge": {
47
+ "y_thresh": 32,
48
+ "gradient_thresh": 0.1,
49
+ "error_thresh": 20
50
+ }
51
+ }
analysis_funcs.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Partially taken and adapted from: https://github.com/jwcarr/eyekit/blob/1db1913411327b108b87e097a00278b6e50d0751/eyekit/measure.py
3
+ Functions for calculating common reading measures, such as gaze duration or
4
+ initial landing position.
5
+ """
6
+
7
+ import pandas as pd
8
+
9
+
10
def fix_in_ia(fix_x, fix_y, ia_x_min, ia_x_max, ia_y_min, ia_y_max):
    """Return True if fixation point (fix_x, fix_y) lies inside the interest
    area's bounding box; bounds are inclusive on all four edges.

    Fix: replaced the verbose `if cond: return True else: return False`
    construction with a direct boolean expression.
    """
    return ia_x_min <= fix_x <= ia_x_max and ia_y_min <= fix_y <= ia_y_max
17
+
18
+
19
def fix_in_ia_default(fixation, ia_row, prefix):
    """Check whether *fixation* (needs .x/.y) falls inside the interest area
    described by *ia_row*, whose bound columns are named f"{prefix}_xmin" etc.
    """
    bounds = (
        ia_row[f"{prefix}_xmin"],
        ia_row[f"{prefix}_xmax"],
        ia_row[f"{prefix}_ymin"],
        ia_row[f"{prefix}_ymax"],
    )
    return fix_in_ia(fixation.x, fixation.y, *bounds)
28
+
29
+
30
def number_of_fixations_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the number of
    fixations on that interest area (one row per interest area).
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for ia_idx, ia in ia_df.iterrows():
        # booleans sum to the number of in-area fixations
        n_fix = sum(
            fix_in_ia(
                fx.x,
                fx.y,
                ia[f"{prefix}_xmin"],
                ia[f"{prefix}_xmax"],
                ia[f"{prefix}_ymin"],
                ia[f"{prefix}_ymax"],
            )
            for _, fx in dffix.iterrows()
        )
        records.append(
            {
                f"{prefix}_index": ia_idx,
                prefix: ia[f"{prefix}"],
                "number_of_fixations": n_fix,
            }
        )
    return pd.DataFrame(records)
57
+
58
+
59
def initial_fixation_duration_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the duration of the
    first fixation landing on each interest area (0 if none landed).
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for ia_idx, ia in ia_df.iterrows():
        # first in-area fixation's duration, defaulting to 0 when none match
        first_duration = next(
            (fx.duration for _, fx in dffix.iterrows() if fix_in_ia_default(fx, ia, prefix)),
            0,
        )
        records.append(
            {
                f"{prefix}_index": ia_idx,
                prefix: ia[f"{prefix}"],
                "initial_fixation_duration": first_duration,
            }
        )
    return pd.DataFrame(records)
82
+
83
+
84
def first_of_many_duration_own(trial, dffix, prefix="word"):
    """For each interest area that received two or more fixations, report the
    duration of the first of them; areas with fewer than two fixations are
    omitted. Returns an empty DataFrame when no area qualifies.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for ia_idx, ia in ia_df.iterrows():
        in_area_durations = [
            fx.duration for _, fx in dffix.iterrows() if fix_in_ia_default(fx, ia, prefix)
        ]
        if len(in_area_durations) > 1:
            records.append(
                {
                    f"{prefix}_index": ia_idx,
                    prefix: ia[f"{prefix}"],
                    "first_of_many_duration": in_area_durations[0],
                }
            )
    return pd.DataFrame(records) if records else pd.DataFrame()
104
+
105
+
106
def total_fixation_duration_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the sum duration of
    all fixations on that interest area.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for ia_idx, ia in ia_df.iterrows():
        total = sum(
            fx.duration for _, fx in dffix.iterrows() if fix_in_ia_default(fx, ia, prefix)
        )
        records.append(
            {
                f"{prefix}_index": ia_idx,
                prefix: ia[f"{prefix}"],
                "total_fixation_duration": total,
            }
        )
    return pd.DataFrame(records)
126
+
127
+
128
def gaze_duration_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the gaze duration on
    that interest area: the sum duration of all fixations inside the area
    until the area is exited for the first time after being entered.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for ia_idx, ia in ia_df.iterrows():
        total = 0
        entered = False
        for _, fx in dffix.iterrows():
            if fix_in_ia_default(fx, ia, prefix):
                total += fx.duration
                entered = True
            elif entered:
                # first fixation outside after having been inside ends the pass
                break
        records.append(
            {
                f"{prefix}_index": ia_idx,
                prefix: ia[f"{prefix}"],
                "gaze_duration": total,
            }
        )
    return pd.DataFrame(records)
153
+
154
+
155
def go_past_duration_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the go-past time on
    that interest area. Go-past time is the sum duration of all fixations from
    when the interest area is first entered until when it is first exited to
    the right, including any regressions to the left that occur during that
    time period (and vice versa in the case of right-to-left text).
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    results = []

    for cidx, ia_row in ia_df.iterrows():
        entered = False
        go_past_time = 0

        for idx, fixation in dffix.iterrows():
            if fix_in_ia_default(fixation, ia_row, prefix):
                if not entered:
                    entered = True
                go_past_time += fixation.duration
            elif entered:
                # Only an exit past the area's right edge ends accumulation;
                # fixations to the left (regressions) keep adding to the total.
                # NOTE(review): the docstring mentions right-to-left text, but
                # this comparison only handles left-to-right exits — confirm.
                if ia_row[f"{prefix}_xmax"] < fixation.x:  # Interest area has been exited to the right
                    break
                go_past_time += fixation.duration

        results.append({f"{prefix}_index": cidx, prefix: ia_row[f"{prefix}"], "go_past_duration": go_past_time})

    return pd.DataFrame(results)
183
+
184
+
185
def second_pass_duration_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the second pass
    duration on that interest area for each word.

    A "pass" is a maximal run of consecutive fixations inside the area; only
    the durations of the second pass are accumulated, and scanning stops at
    the first fixation that exits the second pass.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    durations = []

    for cidx, ia_row in ia_df.iterrows():
        current_pass = None  # pass number while inside the area, else None
        next_pass = 1        # number the next entry into the area will get
        pass_duration = 0
        for idx, fixation in dffix.iterrows():
            if fix_in_ia_default(fixation, ia_row, prefix):
                if current_pass is None:  # first fixation in a new pass
                    current_pass = next_pass
                if current_pass == 2:
                    pass_duration += fixation.duration
            elif current_pass == 1:  # first fixation to exit the first pass
                current_pass = None
                next_pass += 1
            elif current_pass == 2:  # first fixation to exit the second pass
                break
        durations.append(
            {
                f"{prefix}_index": cidx,
                prefix: ia_row[f"{prefix}"],
                "second_pass_duration": pass_duration,
            }
        )

    return pd.DataFrame(durations)
217
+
218
+
219
def initial_landing_position_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the initial landing
    position (expressed in character positions) on that interest area.
    Counting is from 1. If the interest area represents right-to-left text,
    the first character is the rightmost one. Returns `None` if no fixation
    landed on the interest area.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    if prefix == "word":
        # character boxes are needed to convert a pixel hit into a 1-based
        # within-word character index
        chars_df = pd.DataFrame(trial[f"chars_list"])
    else:
        chars_df = None
    results = []
    for cidx, ia_row in ia_df.iterrows():
        landing_position = None
        for idx, fixation in dffix.iterrows():
            if fix_in_ia_default(fixation, ia_row, prefix):
                if prefix == "char":
                    # a character area is a single position by definition
                    landing_position = 1
                else:
                    prefix_temp = "char"
                    # assumes each word's characters are fully contained in
                    # the word's bounding box — TODO confirm upstream layout
                    matched_chars_df = chars_df.loc[
                        (chars_df.char_xmin >= ia_row[f"{prefix}_xmin"])
                        & (chars_df.char_xmax <= ia_row[f"{prefix}_xmax"])
                        & (chars_df.char_ymin >= ia_row[f"{prefix}_ymin"])
                        & (chars_df.char_ymax <= ia_row[f"{prefix}_ymax"]),
                        :,
                    ]  # need to find way to count correct letter number
                    for char_idx, (rowidx, char_row) in enumerate(matched_chars_df.iterrows()):
                        if fix_in_ia_default(fixation, char_row, prefix_temp):
                            landing_position = char_idx + 1  # starts at 1
                            break
                    # only the first fixation landing in the area counts
                    break
        results.append(
            {
                f"{prefix}_index": cidx,
                prefix: ia_row[f"{prefix}"],
                "initial_landing_position": landing_position,
            }
        )
    return pd.DataFrame(results)
261
+
262
+
263
def initial_landing_distance_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the initial landing
    distance on that interest area. The initial landing distance is the pixel
    distance between the first fixation to land in an interest area and the
    left edge of that interest area (or, in the case of right-to-left text,
    the right edge). Technically, the distance is measured from the text onset
    without including any padding. Returns `None` if no fixation landed on the
    interest area.

    Fix: removed the dead `if initial_distance is None` guard — the loop
    breaks on the first in-area fixation, so the guard was always true.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    distances = []
    for cidx, ia_row in ia_df.iterrows():
        initial_distance = None
        for idx, fixation in dffix.iterrows():
            if fix_in_ia_default(fixation, ia_row, prefix):
                initial_distance = abs(ia_row[f"{prefix}_xmin"] - fixation.x)
                break
        distances.append(
            {
                f"{prefix}_index": cidx,
                prefix: ia_row[f"{prefix}"],
                "initial_landing_distance": initial_distance,
            }
        )
    return pd.DataFrame(distances)
291
+
292
+
293
def landing_distances_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return a dataframe with the
    list of landing distances (pixels from the area's left edge, rounded to
    2 decimals) for every fixation inside each interest area.
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    records = []
    for ia_idx, ia in ia_df.iterrows():
        in_area_distances = [
            round(abs(ia[f"{prefix}_xmin"] - fx.x), ndigits=2)
            for _, fx in dffix.iterrows()
            if fix_in_ia_default(fx, ia, prefix)
        ]
        records.append(
            {
                f"{prefix}_index": ia_idx,
                prefix: ia[f"{prefix}"],
                "landing_distances": in_area_distances,
            }
        )
    return pd.DataFrame(records)
308
+
309
+
310
def number_of_regressions_in_own(trial, dffix, prefix="word"):
    """
    Given an interest area and fixation sequence, return the number of
    regressions back to that interest area after the interest area was read
    for the first time. In other words, find the first fixation to exit the
    interest area and then count how many times the reader returns to the
    interest area from the right (or from the left in the case of
    right-to-left text).
    """
    ia_df = pd.DataFrame(trial[f"{prefix}s_list"])
    counts = []
    for cidx, ia_row in ia_df.iterrows():
        entered_interest_area = False  # area seen at least once
        first_exit_index = None        # fixation index of the first exit
        count = 0
        prev_fixation = None
        regression_counted = False     # dedupe: count one regression per re-entry run

        for fixidx, (rowidx, fixation) in enumerate(dffix.iterrows()):
            # A regression is a return into the area, after the first exit,
            # arriving from the right (previous fixation further right).
            # prev_fixation is never None here: the condition requires both an
            # earlier entry and an earlier exit, so at least two prior iterations.
            if (
                entered_interest_area
                and first_exit_index is not None
                and fix_in_ia_default(fixation, ia_row, prefix)
                and not regression_counted
            ):
                if prev_fixation.x > fixation.x:
                    count += 1
                    regression_counted = True

            if fix_in_ia_default(fixation, ia_row, prefix):
                entered_interest_area = True
            elif entered_interest_area and first_exit_index is None:
                first_exit_index = fixidx
            else:
                # NOTE(review): regression_counted resets only on out-of-area
                # fixations after the first exit was already recorded — confirm
                # this is the intended per-run dedupe semantics.
                regression_counted = False
            prev_fixation = fixation

        counts.append(
            {
                f"{prefix}_index": cidx,
                prefix: ia_row[f"{prefix}"],
                "number_of_regressions_in": count,
            }
        )

    return pd.DataFrame(counts)
app.py ADDED
@@ -0,0 +1,1453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from PIL import Image
3
+ from io import StringIO
4
+ import streamlit as st
5
+ import pandas as pd
6
+ import numpy as np
7
+ import re
8
+ import time
9
+ import os
10
+
11
+ from matplotlib.font_manager import FontProperties
12
+ from matplotlib.patches import Rectangle
13
+ from matplotlib import pyplot as plt
14
+ import plotly.graph_objects as go
15
+ import plotly.express as px
16
+ import numpy as np
17
+ import pandas as pd
18
+ import pathlib as pl
19
+ import json
20
+ import logging
21
+ import zipfile
22
+ from stqdm import stqdm
23
+ import jellyfish as jf
24
+ import lovely_tensors
25
+ import shutil
26
+ import eyekit_measures as ekm
27
+ import zipfile
28
+
29
+ import utils as ut
30
+
31
+ os.environ["MPLCONFIGDIR"] = os.getcwd() + "/configs/"
32
+
33
+ st.set_page_config("Correction", page_icon=":eye:", layout="wide")
34
+
35
+ AVAILABLE_FONTS = st.session_state["AVAILABLE_FONTS"] = ut.AVAILABLE_FONTS
36
+
37
+ DEFAULT_PLOT_FONT = "DejaVu Sans Mono"
38
+ EXAMPLES_FOLDER = "./testfiles/"
39
+ EXAMPLES_ASC_ZIP_FILENAME = "asc_files.zip"
40
+ OSF_DOWNLAOD_LINK = "https://osf.io/download/us97f/"
41
+ EXAMPLES_FOLDER_PATH = pl.Path(EXAMPLES_FOLDER)
42
+
43
+
44
+ lovely_tensors.monkey_patch()
45
+
46
+
47
def make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
    """Thin wrapper: create the app's working folders via ut.make_folders."""
    return ut.make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots)
49
+
50
+
51
+ TEMP_FOLDER = st.session_state["TEMP_FOLDER"] = ut.TEMP_FOLDER
52
+ gradio_temp_unzipped_folder = st.session_state["gradio_temp_unzipped_folder"] = pl.Path("unzipped")
53
+
54
+ PLOTS_FOLDER = st.session_state["PLOTS_FOLDER"] = pl.Path("plots")
55
+ TEMP_FIGURE_STIMULUS_PATH = PLOTS_FOLDER.joinpath("temp_matplotlib_plot_stimulus.png")
56
+ make_folders(TEMP_FOLDER, gradio_temp_unzipped_folder, PLOTS_FOLDER)
57
+
58
+
59
@st.cache_data
def get_classic_cfg(fname):
    """Cached wrapper: load the classic-algorithm config JSON via ut.get_classic_cfg."""
    return ut.get_classic_cfg(fname)
62
+
63
+
64
+ classic_algos_cfg = st.session_state["classic_algos_cfg"] = get_classic_cfg("algo_cfgs_all.json")
65
+
66
+ DIST_MODELS_FOLDER = st.session_state["DIST_MODELS_FOLDER"] = pl.Path("models")
67
+ COLORS = st.session_state["COLORS"] = px.colors.qualitative.Alphabet
68
+ ALGO_CHOICES = st.session_state["ALGO_CHOICES"] = [
69
+ "warp",
70
+ "regress",
71
+ "compare",
72
+ "attach",
73
+ "segment",
74
+ "split",
75
+ "stretch",
76
+ "chain",
77
+ "slice",
78
+ "cluster",
79
+ "merge",
80
+ "Wisdom_of_Crowds",
81
+ "DIST",
82
+ "DIST-Ensemble",
83
+ "Wisdom_of_Crowds_with_DIST",
84
+ "Wisdom_of_Crowds_with_DIST_Ensemble",
85
+ ]
86
+
87
+
88
+ st.session_state["colnames_custom_csv_fix"] = {
89
+ "x_col_name_fix": "x",
90
+ "y_col_name_fix": "y",
91
+ "x_col_name_fix_stim": "char_x_center",
92
+ "x_start_col_name_fix_stim": "char_xmin",
93
+ "x_end_col_name_fix_stim": "char_xmax",
94
+ "y_col_name_fix_stim": "char_y_center",
95
+ "y_start_col_name_fix_stim": "char_ymin",
96
+ "y_end_col_name_fix_stim": "char_ymax",
97
+ "char_col_name_fix_stim": "char",
98
+ "trial_id_col_name_fix": "trial_id",
99
+ "trial_id_col_name_stim": "trial_id",
100
+ "subject_col_name_fix": "subid",
101
+ "subject_col_name_stim": "subid",
102
+ "line_num_col_name_stim": "assigned_line",
103
+ "time_start_col_name_fix": "start",
104
+ "time_stop_col_name_fix": "stop",
105
+ }
106
+
107
+ if "results" not in st.session_state:
108
+ st.session_state["results"] = {}
109
+
110
+
111
@st.cache_resource
def load_model(model_file, cfg):
    """Cached-resource wrapper: load a DIST model checkpoint via ut.load_model."""
    return ut.load_model(model_file, cfg)
114
+
115
+
116
@st.cache_resource
def find_and_load_model(model_date="20240104-223349"):
    """Cached-resource wrapper: locate and load the model for a given date stamp."""
    return ut.find_and_load_model(model_date)
119
+
120
+
121
def create_logger(name, level="DEBUG", file=None):
    """Create (or re-fetch) a named logger with a console handler and an
    optional file handler, without stacking duplicate handlers on repeated
    calls.

    Args:
        name: logger name passed to logging.getLogger.
        level: logging level name or number (default "DEBUG").
        file: optional path; when given, a FileHandler in "w" mode is added.

    Returns:
        The configured logging.Logger (propagation disabled).

    Fix: the identical Formatter was built twice — hoisted into one shared
    instance; the file handler no longer reuses the console handler's
    variable name.
    """
    formatter = logging.Formatter(
        "%(asctime)s.%(msecs)03d-%(name)s-p%(process)s-{%(pathname)s:%(lineno)d}-%(levelname)s >>> %(message)s",
        "%m-%d %H:%M:%S",
    )
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(level)
    # NOTE(review): FileHandler subclasses StreamHandler, so a pre-existing
    # file handler also satisfies this check and suppresses the console
    # handler — preserved from the original logic.
    if not any(isinstance(handler, logging.StreamHandler) for handler in logger.handlers):
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    if file is not None:
        if not any(isinstance(handler, logging.FileHandler) for handler in logger.handlers):
            file_handler = logging.FileHandler(file, "w")
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
            logger.debug("Logger added")
    return logger
146
+
147
+
148
+ if "logger" not in st.session_state:
149
+ st.session_state["logger"] = create_logger(name="app", level="DEBUG", file="log_for_app.log")
150
+
151
+
152
@st.cache_data
def download_example_ascs(EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH):
    """Cached wrapper: fetch/unpack example .asc files via ut.download_example_ascs.

    NOTE(review): parameter name inherits the "DOWNLAOD" typo from the module
    constant; renaming would break keyword callers.
    """
    return ut.download_example_ascs(EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH)
155
+
156
+
157
# Example eye-tracking recordings, downloaded once (cached) at startup.
EXAMPLE_ASC_FILES = download_example_ascs(
    EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH
)
160
+
161
+
162
def asc_to_trial_ids(asc_file, close_gap_between_words=True):
    """Thin wrapper: split an .asc recording into trials via ut.asc_to_trial_ids."""
    return ut.asc_to_trial_ids(asc_file, close_gap_between_words)
164
+
165
+
166
@st.cache_data
def get_trials_list(asc_file=None, close_gap_between_words=True):
    """Cached wrapper around ut.get_trials_list."""
    return ut.get_trials_list(asc_file, close_gap_between_words)
169
+
170
+
171
@st.cache_data
def prep_data_for_dist(model_cfg, dffix, trial=None):
    """Cached wrapper: prepare fixation data for the DIST model via ut.prep_data_for_dist."""
    return ut.prep_data_for_dist(model_cfg, dffix, trial)
174
+
175
+
176
def save_trial_to_json(trial, savename):
    """Thin wrapper around ut.save_trial_to_json.

    NOTE(review): shadowed by a later redefinition of the same name in this
    module, so this binding is never active once the module finishes loading.
    """
    return ut.save_trial_to_json(trial, savename)
178
+
179
+
180
def export_csv(dffix, trial):
    """Thin wrapper: export corrected fixations as CSV via ut.export_csv."""
    return ut.export_csv(dffix, trial)
182
+
183
+
184
@st.cache_data
def get_DIST_preds(dffix, trial):
    """Cached wrapper: run the DIST model on the fixations via ut.get_DIST_preds."""
    return ut.get_DIST_preds(dffix, trial)
187
+
188
+
189
@st.cache_data
def get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg=None):
    """Cached wrapper around ut.get_EDIST_preds_with_model_check (DIST ensemble)."""
    return ut.get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg)
192
+
193
+
194
def get_all_classic_preds(dffix, trial):
    """Thin wrapper: run every classic correction algorithm via ut.get_all_classic_preds."""
    return ut.get_all_classic_preds(dffix, trial)
196
+
197
+
198
def apply_woc(dffix, trial, corrections, algo_choice):
    """Thin wrapper: apply the Wisdom-of-Crowds combination via ut.apply_woc."""
    return ut.apply_woc(dffix, trial, corrections, algo_choice)
200
+
201
+
202
@st.cache_data
def correct_df(
    dffix,
    algo_choice,
    trial=None,
    for_multi=False,
    ensemble_model_avg=None,
):
    """Cached wrapper around ut.correct_df, which applies the chosen
    correction algorithm(s) to the fixation DataFrame.

    NOTE(review): callers in this file disagree on the return shape —
    process_trial_choice unpacks a 2-tuple while process_trial keeps the raw
    value; confirm against ut.correct_df.
    """
    return ut.correct_df(
        dffix,
        algo_choice,
        trial,
        for_multi,
        ensemble_model_avg,
    )
217
+
218
+
219
@st.cache_data
def get_font_and_font_size_from_trial(trial):
    """Cached wrapper around ut.get_font_and_font_size_from_trial."""
    return ut.get_font_and_font_size_from_trial(trial)
222
+
223
+
224
@st.cache_data
def add_default_font_and_character_props_to_state(trial):
    """Cached wrapper around ut.add_default_font_and_character_props_to_state."""
    return ut.add_default_font_and_character_props_to_state(trial)
227
+
228
+
229
@st.cache_data
def get_plot_props(trial, available_fonts):
    """Cached wrapper around ut.get_plot_props (font, size, dpi, screen resolution)."""
    return ut.get_plot_props(trial, available_fonts)
232
+
233
+
234
def process_trial_choice(trial_id, algo_choice):
    """Resolve the selected trial from session state, derive its fixation
    DataFrame (parsing the .asc file if needed), store eyekit-related font
    metrics in session state, and run the chosen correction algorithm.

    Returns (dffix, trial, dpi, screen_res, font, font_size).
    """
    # Streamlit selectbox may hand back a {"value": ...} dict instead of the id.
    if isinstance(trial_id, dict):
        trial_id = trial_id["value"]
    trials_by_ids = st.session_state["trials_by_ids"]
    trial = trials_by_ids[trial_id]
    if "chars_list" in trial:
        (
            y_diff,
            x_txt_start,
            y_txt_start,
            font_face,
            _,
            line_height,
        ) = add_default_font_and_character_props_to_state(trial)
        font_size = ut.set_font_from_chars_list(trial)

        # Cache layout metrics for eyekit-based measures/plots.
        st.session_state["y_diff_for_eyekit"] = y_diff
        st.session_state["x_txt_start_for_eyekit"] = x_txt_start
        st.session_state["y_txt_start_for_eyekit"] = y_txt_start
        st.session_state["font_face_for_eyekit"] = font_face
        st.session_state["font_size_for_eyekit"] = font_size
        st.session_state["line_height_for_eyekit"] = line_height

    if "dffix" in trial:
        dffix = trial["dffix"]
    else:
        # No parsed fixations yet: derive them from the session's .asc file.
        asc_file = st.session_state["asc_file"]
        trial["plot_file"] = str(PLOTS_FOLDER.joinpath(f"{asc_file.stem}_{trial_id}_2ndInput_chars_channel_sep.png"))
        trial["fname"] = str(asc_file.name).split(".")[0]
        df, dffix, trial = ut.trial_to_dfs(trial, st.session_state["lines"], use_synctime=True)
        st.session_state["logger"].info(f"dffix.columns after trial_to_dfs {dffix.columns}")

    font, font_size, dpi, screen_res = ut.get_plot_props(trial, AVAILABLE_FONTS)
    st.session_state["trial"] = trial
    if "chars_list" in trial:
        chars_df = pd.DataFrame(trial["chars_list"])
        trial["chars_df"] = chars_df.to_dict()
        trial["y_char_unique"] = list(chars_df.char_y_center.sort_values().unique())
    # Correction needs stimulus geometry (characters or words) to snap to.
    if algo_choice is not None and ("chars_list" in trial or "words_list" in trial):
        dffix, _ = correct_df(dffix, algo_choice, trial)
    else:
        st.warning("🚨 Stimulus information needed for fixation correction 🚨")

    return dffix, trial, dpi, screen_res, font, font_size
278
+
279
+
280
@st.cache_data
def process_trial_choice_single_csv(trial, algo_choice, file=None):
    """Cached wrapper around ut.process_trial_choice_single_csv (custom-CSV path)."""
    return ut.process_trial_choice_single_csv(trial, algo_choice, file=file)
283
+
284
+
285
def quick_dffix_save(dffix, savename):
    """Write the fixation DataFrame to CSV at *savename* and log the path."""
    dffix.to_csv(savename)
    st.session_state["logger"].info(f"Saved processed data as {savename}")
288
+
289
+
290
def save_trial_to_json(trial, savename):
    """Serialize a trial dict to JSON at *savename*, excluding the "dffix"
    entry (a DataFrame, not JSON-serializable).

    NOTE(review): this redefines the ut-delegating wrapper of the same name
    earlier in this module; this version is the active binding.

    Fix: the original pop()-ed "dffix" out of the caller's dict, mutating the
    argument as a side effect of saving; serialize a filtered copy instead.
    """
    serializable = {key: value for key, value in trial.items() if key != "dffix"}
    with open(savename, "w", encoding="utf-8") as f:
        json.dump(serializable, f, ensure_ascii=False, indent=4, cls=ut.NumpyEncoder)
295
+
296
+
297
@st.cache_data
def process_trial(trial, asc_file_stem, lines, algo_choice, for_multi=False):
    """Prepare one trial for (batch) processing: attach plot/file metadata and
    plotting properties, parse fixations, cache character geometry, and apply
    the chosen correction algorithm.

    Returns (dffix, trial); dffix is an empty DataFrame when parsing yields
    no fixations.

    Fix: removed the duplicated `trial["y_char_unique"]` assignment.
    """
    trial_id = trial["trial_id"]
    trial["plot_file"] = str(PLOTS_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}_2ndInput_chars_channel_sep.png"))
    trial["fname"] = str(asc_file_stem)
    font, font_size, dpi, screen_res = ut.get_plot_props(trial, AVAILABLE_FONTS)
    trial["font"] = font
    trial["font_size"] = font_size
    trial["dpi"] = dpi
    trial["screen_res"] = screen_res
    df, dffix, trial = ut.trial_to_dfs(trial, lines, use_synctime=True)
    if dffix.empty:
        return pd.DataFrame(), trial

    chars_df = pd.DataFrame(trial["chars_list"])
    trial["chars_df"] = chars_df.to_dict()
    trial["y_char_unique"] = list(chars_df.char_y_center.sort_values().unique())
    if algo_choice is not None:
        # NOTE(review): process_trial_choice unpacks correct_df as a 2-tuple
        # (dffix, _) while the raw return is kept here — confirm which shape
        # ut.correct_df produces for the batch path.
        dffix = correct_df(dffix, algo_choice, trial, for_multi)

    return dffix, trial
320
+
321
+
322
def add_text_to_ax(
    chars_list,
    ax,
    font_to_use="DejaVu Sans Mono",
    fontsize=21,
    prefix="char",
    plot_boxes=True,
    plot_text=True,
    box_annotations=None,
):
    """Thin wrapper: draw stimulus text and/or bounding boxes on a matplotlib
    axis via ut.add_text_to_ax."""
    return ut.add_text_to_ax(
        chars_list,
        ax,
        font_to_use=font_to_use,
        fontsize=fontsize,
        prefix=prefix,
        plot_boxes=plot_boxes,
        plot_text=plot_text,
        box_annotations=box_annotations,
    )
342
+
343
+
344
@st.cache_data
def matplotlib_plot_df(
    dffix,
    trial,
    algo_choice,
    stimulus_prefix="word",
    desired_dpi=300,
    fix_to_plot=None,
    stim_info_to_plot=None,
    box_annotations=None,
):
    """Cached wrapper around ut.matplotlib_plot_df (static stimulus/fixation plot).

    Fix: replaced the mutable default arguments (`[]` and a list literal)
    with None sentinels; the original defaults are reconstructed per call,
    so behavior is unchanged for existing callers.
    """
    fix_to_plot = [] if fix_to_plot is None else fix_to_plot
    stim_info_to_plot = ["Words", "Word boxes"] if stim_info_to_plot is None else stim_info_to_plot
    return ut.matplotlib_plot_df(
        dffix,
        trial,
        algo_choice,
        stimulus_prefix=stimulus_prefix,
        desired_dpi=desired_dpi,
        fix_to_plot=fix_to_plot,
        stim_info_to_plot=stim_info_to_plot,
        box_annotations=box_annotations,
    )
365
+
366
+
367
def sigmoid(x):
    """Logistic function: map x (scalar or array) into (0, 1)."""
    neg_exp = np.exp(-x)
    return 1 / (1 + neg_exp)
369
+
370
+
371
@st.cache_data
def plotly_plot_with_image(
    dffix,
    trial,
    algo_choice,
    to_plot_list=None,
    scale_factor=0.5,
):
    """Cached wrapper around ut.plotly_plot_with_image (interactive overlay plot).

    Fix: replaced the mutable list default argument with a None sentinel;
    the original default list is reconstructed per call, so behavior is
    unchanged for existing callers.
    """
    if to_plot_list is None:
        to_plot_list = ["Uncorrected Fixations", "Words", "corrected fixations", "Word boxes"]
    return ut.plotly_plot_with_image(
        dffix,
        trial,
        algo_choice,
        to_plot_list=to_plot_list,
        scale_factor=scale_factor,
    )
386
+
387
+
388
@st.cache_data
def plot_y_corr(dffix, algo_choice):
    """Cached wrapper around ut.plot_y_corr (plot of corrected y-coordinates)."""
    return ut.plot_y_corr(dffix, algo_choice)
391
+
392
+
393
def plotly_df(
    dffix=None, trial=None, algo_choice=None, to_plot_list=["fixations", "characters", "corrected fixations"], title=""
):
    """Build a plotly figure overlaying raw/corrected fixations and stimulus characters.

    Falls back to session state for ``dffix``/``algo_choice`` when not given.
    ``to_plot_list`` selects the layers to draw.

    NOTE(review): the default ``to_plot_list`` entries ("fixations",
    "characters", ...) do not match any of the capitalized keys actually
    checked below ("Uncorrected Fixations", "Corrected Fixations",
    "Characters", ...), so calling with the defaults yields an empty figure —
    confirm whether callers always pass an explicit list.
    NOTE(review): mutable list default is shared across calls (not mutated
    here, but worth confirming).
    """
    if dffix is None:
        dffix = st.session_state["dffix"]
    if algo_choice is None:
        algo_choice = st.session_state["algo_choice"]

    st.session_state["logger"].info(f"Plotting {to_plot_list}")

    # Index values are used as per-fixation text labels.
    num_datapoints = dffix.index
    # Resolve the character dataframe: from an explicit trial, from the
    # multi-file results keyed by `title`, or from the single-trial state.
    if trial is None:
        if title in st.session_state["results"]:
            chars_df = pd.DataFrame(st.session_state["results"][title]["trial"]["chars_list"])
        else:
            chars_df = pd.DataFrame(st.session_state["trial"]["chars_df"])
    else:
        chars_df = pd.DataFrame(trial["chars_list"]) if "chars_list" in trial else None
    if chars_df is not None:
        font_face, font_size = get_font_and_font_size_from_trial(trial)
        font_size = font_size * 0.65  # guess for scaling
        # Axis ranges from the stimulus extent when characters are known ...
        xmin = chars_df.char_x_center.min()
        xmax = chars_df.char_x_center.max()
        ymin = chars_df.char_y_center.min()
        ymax = chars_df.char_y_center.max()
    else:
        st.warning("No character or word information available to plot")
        # ... otherwise from the fixation extent.
        xmin = dffix.x.min()
        xmax = dffix.x.max()
        ymin = dffix.y.min()
        ymax = dffix.y.max()

    # y-axis range is inverted (top-left origin, screen coordinates).
    layout = dict(
        plot_bgcolor="white",
        autosize=True,
        margin=dict(t=1, l=10, r=10, b=1),
        xaxis=dict(
            title="x-coordinate",
            linecolor="black",
            range=[xmin - 100, xmax + 100],
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        yaxis=dict(
            title="y-coordinate",
            range=[ymax + 100, ymin - 100],
            linecolor="black",
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=0.8),
    )

    fig = go.Figure(layout=layout)

    if "Uncorrected Fixations" in to_plot_list:
        # Marker size encodes fixation duration (shifted/normalized by median).
        duration_scaled = dffix.duration - dffix.duration.min()
        duration = ((duration_scaled + 0.1) / duration_scaled.median()) * 5
        fig.add_trace(
            go.Scatter(
                x=dffix.x,
                y=dffix.y,
                mode="markers+lines+text",
                name="Raw fixations",
                marker=dict(
                    symbol="arrow",
                    size=duration.values,
                    angleref="previous",  # arrows point along the scanpath
                ),
                line_width=1.2,
                text=num_datapoints,
                textposition="middle right",
                textfont=dict(
                    family="sans serif",
                    size=9,
                ),
                hoverinfo="text+x+y",
                opacity=0.6,
            )
        )
    if "Corrected Fixations" in to_plot_list:
        # One trace per correction algorithm; each algorithm's corrected y
        # values live in a `y_<algo>` column.
        if isinstance(algo_choice, list):
            algo_choices = algo_choice
            repeats = range(len(algo_choice))
        else:
            algo_choices = [algo_choice]
            repeats = range(1)
        for algoIdx in repeats:
            algo_choice = algo_choices[algoIdx]
            if f"y_{algo_choice}" in dffix.columns:
                fig.add_trace(
                    go.Scatter(
                        x=dffix.x,
                        y=dffix.loc[:, f"y_{algo_choice}"],
                        mode="markers",
                        name=f"{algo_choice} corrected",
                        marker_color=st.session_state["COLORS"][algoIdx],
                        marker_size=5,
                        hoverinfo="text+x+y",
                        opacity=0.75,
                    )
                )
    if "Characters" in to_plot_list and chars_df is not None:
        # Render the stimulus text at the character centers; tiny markers so
        # only the glyphs are visible.
        fig.add_trace(
            go.Scatter(
                x=chars_df.char_x_center,
                y=chars_df.char_y_center,
                mode="markers+text",
                name="",
                showlegend=False,
                text=chars_df.char,
                textposition="middle center",
                marker=dict(color="black", size=0.1),
                textfont=dict(family=font_face, size=font_size, color="Black"),
            )
        )

    if "Character boxes (slow to plot)" in to_plot_list and chars_df is not None:
        # One rectangle shape per character bounding box (expensive for long texts).
        num = 0
        for k, row in stqdm(chars_df.iterrows(), "Adding boxes"):
            fig.add_shape(
                type="rect",
                x0=row.char_xmin,
                y0=row.char_ymin,
                x1=row.char_xmax,
                y1=row.char_ymax,
                line=dict(color=st.session_state["COLORS"][-1], width=1),
            )
            num += 1
    return fig
525
+
526
+
527
def save_to_zips(folder, pattern, savename):
    """Add every file in *folder* matching glob *pattern* to TEMP_FOLDER/savename.

    The archive is created when absent and appended to otherwise.

    Fix: the original reopened the ZipFile on every iteration and only switched
    the mode to "a" *after* the second iteration (``if idx == 1``), so when the
    archive did not already exist and two or more files matched, the second
    open with mode "w" truncated the archive and discarded the first entry.
    Opening the archive once with the correct mode writes all entries safely.
    """
    target = TEMP_FOLDER.joinpath(savename)
    mode = "a" if os.path.exists(target) else "w"
    with zipfile.ZipFile(target, mode=mode) as archive:
        for f in folder.glob(pattern):
            archive.write(f)
            st.session_state["logger"].info(f"Written {f} to zip {target}")
    st.session_state["logger"].info("Done zipping")
539
+
540
+
541
def process_multiple_asc(asc_files):
    """Run the correction pipeline over every trial in every uploaded .asc file.

    For each trial: correct fixations, store results in session state, save
    csv/json/png artifacts to TEMP_FOLDER, and bundle them per input file into
    a zip. Returns ``(zipfiles_with_results, results_keys)``.
    """
    algo_choice = st.session_state["algo_choice_multi"]
    if algo_choice is not None and "DIST" in algo_choice:
        # NOTE(review): the freshly loaded model/model_cfg are immediately
        # overwritten by the cached session-state copies on the next two
        # lines — the find_and_load_model call appears redundant; confirm.
        model, model_cfg = find_and_load_model(model_date=st.session_state["DIST_MODEL_DATE_WITH_NORM"])
        model = st.session_state["single_DIST_model"]
        model_cfg = st.session_state["single_DIST_model_cfg"]
        st.session_state["logger"].info(f"process_multiple_asc loaded model")
    else:
        model, model_cfg = None, None
    zipfiles_with_results = []
    st.session_state["logger"].info(f"found asc_files {asc_files}")

    for asc_file in stqdm(asc_files, desc="Processing asc files"):
        st.session_state["logger"].info(f"processing asc_file {asc_file}")
        asc_file_stem = pl.Path(asc_file.name).stem
        trials_by_ids, lines = asc_to_trial_ids(asc_file)
        for trial_id, trial in stqdm(trials_by_ids.items(), desc=f"\nProcessing trials in {asc_file_stem}"):
            # Positional True: presumably a "run correction" flag — confirm
            # against process_trial's signature.
            dffix, trial = process_trial(
                trial,
                asc_file_stem,
                lines,
                algo_choice,
                True,
            )

            st.session_state["logger"].debug(f"dffix.columns after process trial {dffix.columns}")
            if dffix.empty:
                st.session_state["logger"].warning(f"Dataframe for {trial_id} is empty, skipping")
                continue
            # Keyed by file stem + trial id so trials from different files
            # cannot collide.
            st.session_state["results"][f"{asc_file_stem}_{trial_id}"] = {
                "trial": trial,
                "dffix": dffix,
            }
            st.session_state["logger"].debug(f"Added {asc_file_stem}_{trial_id} to st.session_state")
            # Persist per-trial artifacts that get zipped below.
            quick_dffix_save(dffix, TEMP_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}.csv"))
            save_trial_to_json(trial, TEMP_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}.json"))
            ut.plot_fixations_and_text(
                dffix,
                trial,
                save=True,
                savelocation=TEMP_FOLDER.joinpath(f"{asc_file_stem}_{trial_id}.png"),
                algo_choice=algo_choice,
                turn_axis_on=False,
            )
        # Rebuild the per-file zip from scratch on every run.
        if os.path.exists(TEMP_FOLDER.joinpath(f"{asc_file_stem}.zip")):
            os.remove(TEMP_FOLDER.joinpath(f"{asc_file_stem}.zip"))
        save_to_zips(TEMP_FOLDER, f"{asc_file_stem}*.csv", f"{asc_file_stem}.zip")
        save_to_zips(TEMP_FOLDER, f"{asc_file_stem}*.json", f"{asc_file_stem}.zip")
        save_to_zips(TEMP_FOLDER, f"{asc_file_stem}*.png", f"{asc_file_stem}.zip")
        zipfiles_with_results += [str(x) for x in TEMP_FOLDER.glob(f"{asc_file_stem}*.zip")]
    results_keys = list(st.session_state["results"].keys())
    st.session_state["logger"].debug(f"results_keys are {results_keys}")
    st.session_state["trial_choices_multi"] = results_keys
    st.session_state["zipfiles_with_results"] = zipfiles_with_results
    return (zipfiles_with_results, results_keys)
596
+
597
+
598
@st.cache_data
def get_trials_and_lines_from_asc_files(asc_files):
    """Collect .asc files from uploads (including zip/tar archives), parse them
    into trials, stash the parsed data in session state, and kick off
    ``process_multiple_asc``.

    Archives are extracted into a per-file subfolder of
    ``gradio_temp_unzipped_folder`` and any .asc files found there are queued.
    """
    list_of_trial_lists = []
    list_of_lines = []
    total_num_trials = 0

    asc_files_to_do = []
    for filename_full in asc_files:
        # Uploaded file objects carry a .name; plain paths/strings are used as-is.
        if hasattr(filename_full, "name") and not isinstance(filename_full, pl.Path):
            file = filename_full.name
            st.session_state["logger"].info(f"Filename is {file}, filename_full is {filename_full}")
        else:
            file = filename_full
        if not isinstance(file, str):
            file_stem = pl.Path(file.name).stem
        else:
            file_stem = pl.Path(file).stem
        savefolder = gradio_temp_unzipped_folder.joinpath(file_stem)
        st.session_state["logger"].info(f"Operating on file {file}")
        if ".zip" in file:
            with zipfile.ZipFile(filename_full, "r") as z:
                z.extractall(str(savefolder))
        elif ".tar" in file:
            shutil.unpack_archive(file, savefolder, "tar")
        elif ".asc" in file:
            asc_files_to_do.append(filename_full)
        else:
            st.session_state["logger"].warning(f"Unsopported file format found in files")
        # NOTE(review): this glob runs for plain .asc uploads too (savefolder
        # then never existed/was never extracted into) — harmless if the glob
        # of a non-existent folder is empty, but confirm.
        newfiles = [str(x) for x in savefolder.glob(f"*.asc")]
        asc_files_to_do += newfiles
    st.session_state["logger"].info(f"asc_files_to_do is {asc_files_to_do}")

    for asc_file in asc_files_to_do:
        trials_by_ids, lines = asc_to_trial_ids(asc_file)
        total_num_trials += len(trials_by_ids)
        list_of_trial_lists.append(trials_by_ids)
        list_of_lines.append(lines)
    st.session_state["list_of_trial_lists"] = list_of_trial_lists
    st.session_state["list_of_lines"] = list_of_lines
    # NOTE(review): processing uses the session-state file list, not the
    # asc_files argument — confirm both always agree.
    process_multiple_asc(st.session_state["multi_asc_filelist"])
638
+
639
+
640
def process_trial_choice_and_update_df_multi():
    """Copy the currently selected multi-file trial's fixations and trial dict
    into the ``dffix_multi`` / ``trial_multi`` session-state slots, dropping the
    raw start/end time columns when present."""
    selected_id = st.session_state["trial_id_multi"]
    result_entry = st.session_state["results"][selected_id]
    fixations = result_entry["dffix"]
    if "start_time" in fixations.columns:
        fixations = fixations.drop(axis=1, labels=["start_time", "end_time"])
    st.session_state["dffix_multi"] = fixations
    st.session_state["trial_multi"] = result_entry["trial"]
647
+
648
+
649
@st.cache_data
def convert_df(df):
    """Serialize *df* to UTF-8 encoded CSV bytes without the index column."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
652
+
653
+
654
def make_trial_from_stimulus_df(
    stim_plot_df,
    filename,
    trial_id,
):
    """Build a trial dict (character boxes, word boxes, line geometry) from a
    stimulus dataframe whose column names are taken from session state.

    Words are segmented while iterating characters: a word ends at a space or
    when the x coordinate jumps backwards (line wrap).
    """
    chars_list = []
    words_list = []
    word_start_idx = 0  # index into chars_list where the current word began
    for idx, row in stim_plot_df.reset_index().iterrows():
        # Normalize one stimulus row into the canonical char_* schema.
        char_dict = dict(
            char_xmin=row[st.session_state["x_start_col_name_fix_stim"]],
            char_xmax=row[st.session_state["x_end_col_name_fix_stim"]],
            char_ymin=row[st.session_state["y_start_col_name_fix_stim"]],
            char_ymax=row[st.session_state["y_end_col_name_fix_stim"]],
            char_x_center=row[st.session_state["x_col_name_fix_stim"]],
            char_y_center=row[st.session_state["y_col_name_fix_stim"]],
            char=row[st.session_state["char_col_name_fix_stim"]],
            assigned_line=int(row[st.session_state["line_num_col_name_stim"]]),
        )
        chars_list.append(char_dict)

        # Word boundary: a space, or the new char starts left of the previous
        # one (wrap to the next line).
        if len(chars_list) > 1 and (
            char_dict["char"] == " "
            or (len(chars_list) > 2 and (chars_list[-1]["char_xmin"] < chars_list[-2]["char_xmin"]))
        ):
            # The word spans chars_list[word_start_idx : -1] (the current char
            # is the delimiter / first char of the next word).
            word_dict = dict(
                word_xmin=chars_list[word_start_idx]["char_xmin"],
                word_xmax=chars_list[-2]["char_xmax"],
                word_ymin=chars_list[word_start_idx]["char_ymin"],
                word_ymax=chars_list[word_start_idx]["char_ymax"],
                word_x_center=(chars_list[-2]["char_xmax"] - chars_list[word_start_idx]["char_xmin"]) / 2
                + chars_list[word_start_idx]["char_xmin"],
                word_y_center=(chars_list[word_start_idx]["char_ymax"] - chars_list[word_start_idx]["char_ymin"]) / 2
                + chars_list[word_start_idx]["char_ymin"],
                word="".join([chars_list[idx]["char"] for idx in range(word_start_idx, len(chars_list) - 1)]),
            )

            # Next word starts at this char (line wrap) or after it (space).
            if char_dict["char"] != " ":
                word_start_idx = idx
            else:
                word_start_idx = idx + 1
            words_list.append(word_dict)

    # Per-character line geometry used to derive grid spacing.
    line_heights = [x["char_ymax"] - x["char_ymin"] for x in chars_list]
    line_xcoords_all = [x["char_x_center"] for x in chars_list]
    line_xcoords_no_pad = np.unique(line_xcoords_all)

    line_ycoords_all = [x["char_y_center"] for x in chars_list]
    line_ycoords_no_pad = np.unique(line_ycoords_all)

    trial = dict(
        filename=filename,
        y_midline=[float(x) for x in list(stim_plot_df[st.session_state["y_col_name_fix_stim"]].unique())],
        num_char_lines=len(stim_plot_df[st.session_state["y_col_name_fix_stim"]].unique()),
        y_diff=[
            float(x) for x in list(np.unique(np.diff(stim_plot_df[st.session_state["y_start_col_name_fix_stim"]])))
        ],
        trial_id=trial_id,
        chars_list=chars_list,
        words_list=words_list,
        trial_is="paragraph",
        text="".join([x["char"] for x in chars_list]),
    )

    trial["x_char_unique"] = [float(x) for x in list(line_xcoords_no_pad)]
    trial["y_char_unique"] = list(map(float, list(line_ycoords_no_pad)))
    # Scalar grid spacings overwrite the list-valued y_diff set above.
    x_diff, y_diff = ut.calc_xdiff_ydiff(
        line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False
    )
    trial["x_diff"] = float(x_diff)
    trial["y_diff"] = float(y_diff)
    trial["num_char_lines"] = len(line_ycoords_no_pad)
    trial["line_heights"] = list(map(float, line_heights))
    # NOTE(review): chars_list is already set in the dict literal above; this
    # reassignment is redundant.
    trial["chars_list"] = chars_list

    return trial
730
+
731
+
732
@st.cache_data
def get_fixations_file_trials_list(fixations_df, stimulus):
    """Normalize custom-csv fixation and stimulus data into per-trial dicts.

    Renames user-configured column names (from session state) to the canonical
    schema, splits the fixations by trial id, and pairs each group with either
    a per-trial stimulus dataframe or a shared stimulus dict.

    Returns ``(trials_by_ids, trial_keys)``.

    NOTE(review): ``trial_keys`` is only assigned on the path where a trial-id
    column exists — the fallback branch that assigns "trial_0" never sets it,
    so the final ``return`` would raise NameError there; confirm and fix.
    """
    if isinstance(stimulus, pd.DataFrame):
        # Zero-base the line numbers.
        stimulus[st.session_state["line_num_col_name_stim"]] -= stimulus[
            st.session_state["line_num_col_name_stim"]
        ].min()
        stimulus.rename(
            {
                st.session_state["x_col_name_fix_stim"]: "char_x_center",
                st.session_state["x_start_col_name_fix_stim"]: "char_xmin",
                st.session_state["x_end_col_name_fix_stim"]: "char_xmax",
                st.session_state["y_col_name_fix_stim"]: "char_y_center",
                st.session_state["y_start_col_name_fix_stim"]: "char_ymin",
                st.session_state["y_end_col_name_fix_stim"]: "char_ymax",
                st.session_state["char_col_name_fix_stim"]: "char",
                st.session_state["trial_id_col_name_stim"]: "trial_id",
            },
            axis=1,
            inplace=True,
        )

    fixations_df.rename(
        mapper={
            st.session_state["x_col_name_fix"]: "x",
            st.session_state["y_col_name_fix"]: "y",
            st.session_state["time_start_col_name_fix"]: "corrected_start_time",
            st.session_state["time_stop_col_name_fix"]: "corrected_end_time",
            st.session_state["trial_id_col_name_fix"]: "trial_id",
        },
        axis=1,
        inplace=True,
    )

    fixations_df["duration"] = fixations_df.corrected_end_time - fixations_df.corrected_start_time
    if "trial_id" in stimulus:
        fixations_df["trial_id"] = stimulus["trial_id"]
    if "trial_id" in fixations_df:
        if st.session_state["has_multiple_subject"]:
            # Disambiguate trials across subjects: "<subject>_<trial>".
            fixations_df["trial_id"] = [
                f"{id}_{num}"
                for id, num in zip(
                    fixations_df[st.session_state["subject_col_name_fix"]],
                    fixations_df[st.session_state["trial_id_col_name_fix"]],
                )
            ]
            # NOTE(review): this indexes by the *original* trial-id column
            # name, which was renamed to "trial_id" above — works only when
            # they are equal; confirm.
            trial_keys = list(fixations_df[st.session_state["trial_id_col_name_fix"]].unique())
            st.session_state["logger"].info(f"Found keys {trial_keys} for {st.session_state['single_csv_file'].name}")
    else:
        st.session_state["logger"].warning(f"trial id column not found assigning trial id trial_0.")
        st.warning(f"trial id column not found assigning trial id trial_0.")
        fixations_df["trial_id"] = "trial_0"
    st.session_state["fixations_df"] = fixations_df
    trials_by_ids = {}

    for trial_id, subdf in fixations_df.groupby("trial_id"):
        if isinstance(stimulus, pd.DataFrame):
            # Per-trial stimulus: build a fresh trial dict from its rows.
            stim_df = stimulus[stimulus.trial_id == trial_id]

            stim_df = stim_df.dropna(axis=0, how="any")
            subdf = subdf.dropna(axis=0, how="any")
            subdf = subdf.reset_index(drop=True)
            stim_df = stim_df.reset_index(drop=True)
            assert not stim_df.empty, "stimulus df is empty"
            trial = make_trial_from_stimulus_df(
                stim_df,
                st.session_state["single_csv_file_stim"].name,
                trial_id,
            )
        else:
            # Shared stimulus dict (json upload): reused/mutated for every
            # trial — NOTE(review): all trials then share one dict object.
            trial = stimulus
        trial["dffix"] = subdf
        trial["fname"] = f"{trial_id}"
        trial["plot_file"] = str(
            st.session_state["PLOTS_FOLDER"].joinpath(f"{trial_id}_2ndInput_chars_channel_sep.png")
        )
        trials_by_ids[trial_id] = trial

    return trials_by_ids, trial_keys
810
+
811
+
812
def try_reading_csv(file):
    """Parse an uploaded file as comma-separated, falling back to tab-separated.

    Returns the DataFrame when it parses into more than one column, otherwise
    None. Parse failures are logged rather than raised, so callers must handle
    a None result.

    Fixes: deprecated ``logger.warn`` replaced with ``logger.warning``; the
    bytes are decoded once instead of twice; the unused ``colname_mapping``
    local was removed.
    """
    decoded = file.getvalue().decode("utf-8")
    try:
        df = pd.read_csv(StringIO(decoded))
        st.session_state["logger"].info(f"\n{df.head()}")
        # A single-column parse means the delimiter guess was wrong.
        assert len(df.columns) > 1
        return df
    except Exception as e:
        st.session_state["logger"].warning(e)
    try:
        df = pd.read_csv(StringIO(decoded), delimiter="\t")
        assert len(df.columns) > 1
        return df
    except Exception as e:
        st.session_state["logger"].warning(e)
        return None
831
+
832
+
833
@st.cache_data
def guess_col_names_fix(file=None):
    """Read the uploaded fixation file into a DataFrame, guessing the delimiter.

    Falls back to ``st.session_state["single_csv_file"]`` when *file* is None.
    Returns the parsed DataFrame, or None when no file is available or parsing
    fails.

    Fix: the original dereferenced ``df.shape`` without checking that
    ``try_reading_csv`` succeeded, raising AttributeError on a None result;
    a guard was added.
    """
    if file is None:
        file = st.session_state["single_csv_file"]
    if file is None:
        return None

    # Inspect the header line to guess the delimiter.
    first_line = next(iter(StringIO(file.getvalue().decode("utf-8"))))
    res = re.findall(r"[^()0-9-]+", first_line)
    for delim in [",", "\t", ";"]:
        first_line = first_line.split(delim)
        if len(first_line) > 2:
            break
        else:
            first_line = first_line[0]
    # Levenshtein scores of expected vs. actual header names.
    # NOTE(review): the scoring result is computed but discarded — confirm
    # whether automatic column matching was meant to use it.
    scores_lists = {}
    for k, v in st.session_state["colnames_custom_csv_fix"].items():
        scores_lists[v] = []
        for word in first_line:
            scores_lists[v].append(jf.levenshtein_distance(v, word))
    scores_df = pd.DataFrame(scores_lists)
    scores_df.idxmin(axis=0)
    df = try_reading_csv(file)
    if df is not None and df.shape[1] > 1:
        return df
    return None
860
+
861
+
862
@st.cache_data
def guess_col_names_stim(file=None):
    """Parse the uploaded stimulus file: a .json trial dict or a delimited table.

    Falls back to ``st.session_state["single_csv_file_stim"]`` when *file* is
    None. Returns the trial dict (json), the DataFrame (csv/tsv), or None.

    Fix: the original dereferenced ``df.shape`` without checking that
    ``try_reading_csv`` succeeded, raising AttributeError on a None result;
    a guard was added.
    """
    if file is None:
        file = st.session_state["single_csv_file_stim"]
    if file is None:
        return None
    if ".json" in file.name:
        json_string = file.getvalue().decode("utf-8")
        trial = json.loads(json_string)
        return trial
    df = try_reading_csv(file)
    if df is not None and df.shape[1] > 1:
        return df
    return None
879
+
880
+
881
@st.cache_resource
def set_up_models(dist_models_folder):
    """Cached resource wrapper: load the DIST models via ``ut.set_up_models``."""
    models_out = ut.set_up_models(dist_models_folder)
    return models_out
884
+
885
@st.cache_data
def get_eyekit_measures(_txt, _seq, get_char_measures=False):
    """Cached pass-through to ``ekm.get_eyekit_measures`` (leading underscores
    keep the unhashable eyekit objects out of Streamlit's cache key)."""
    measures = ekm.get_eyekit_measures(_txt, _seq, get_char_measures=get_char_measures)
    return measures
888
+
889
+
890
@st.cache_data
def get_all_measures(trial, dffix, prefix, use_corrected_fixations=True, correction_algo="warp"):
    """Cached pass-through to ``ut.get_all_measures`` for reading-measure computation."""
    measures = ut.get_all_measures(
        trial,
        dffix,
        prefix,
        use_corrected_fixations=use_corrected_fixations,
        correction_algo=correction_algo,
    )
    return measures
893
+
894
+
895
# Guard: the shared session state (e.g. ALGO_CHOICES) must already be
# initialized by the entry page before this script runs.
assert "ALGO_CHOICES" in st.session_state, f"st.session_state not initialized\n{list(st.session_state.keys())}"

# Load the DIST correction models (cached via st.cache_resource) and merge
# them into the session state.
set_up_models_out = set_up_models(DIST_MODELS_FOLDER)
st.session_state.update(set_up_models_out)


# Page header and navigation tabs.
st.title("Fixation data vertical alignment")
st.header("👀 Read asc file or files and plot fixations 👀")
st.markdown("[Contact Us](mailto:[email protected])")
st.markdown("[Read about DIST model](https://arxiv.org/abs/2311.06095)")

single_file_tab, multi_file_tab = st.tabs(["Single File 📁", "Multiple Files 📁 📁"])

single_file_tab_asc_tab, single_file_tab_csv_tab = single_file_tab.tabs([".asc files", "custom files"])

single_file_tab_asc_tab.subheader(
    "Upload an .asc file and select a trial. Then select a correction algorithm and plot/download the results"
)
915
def change_which_file_is_used_and_clear_results():
    """Form callback: clear any previously corrected trial from session state,
    then point the single-file asc pipeline at either the selected example
    file or the user's uploaded file."""
    for stale_key in ("dffix", "trial"):
        if stale_key in st.session_state:
            del st.session_state[stale_key]
    use_example = (
        st.session_state["single_file_tab_asc_tab_example_use_example_or_uploaded_file_choice"] == "Example File"
    )
    if use_example:
        st.session_state["single_asc_file_asc"] = st.session_state["single_file_tab_asc_tab_example_file_choice"]
    else:
        st.session_state["single_asc_file_asc"] = st.session_state["single_asc_uploaded_file"]
924
+
925
+
926
# File-selection form for the single .asc workflow: upload widget, word-box
# option, optional example-file picker, and a submit button whose callback
# (change_which_file_is_used_and_clear_results) commits the choice.
with single_file_tab_asc_tab.form("single_file_tab_asc_tab_load_example_form"):
    single_asc_file_asc_uploaded = st.file_uploader(
        "Select .asc File", accept_multiple_files=False, key="single_asc_uploaded_file", type=["asc"]
    )
    close_gap_between_words_single_asc = st.checkbox(
        label="Should spaces between words be included in word bounding box?",
        value=False,
        key="close_gap_between_words_single_asc",
    )

    # Example-file options are only shown when the bundled examples exist.
    # NOTE(review): the submit callback reads the radio's session key, which
    # is never created when the examples are missing — confirm that path.
    if os.path.isfile(EXAMPLE_ASC_FILES[0]):
        example_file_choice = st.selectbox(
            "Select example file", options=EXAMPLE_ASC_FILES, key="single_file_tab_asc_tab_example_file_choice"
        )
        use_example_or_uploaded_file_choice = st.radio(
            "Should the uploaded file be used or the selected example file?",
            index=1,
            options=["Uploaded File", "Example File"],
            key="single_file_tab_asc_tab_example_use_example_or_uploaded_file_choice",
        )

    upload_file_button = st.form_submit_button(
        label="Load selected data.", on_click=change_which_file_is_used_and_clear_results
    )
950
+
951
+
952
+ if "single_asc_file_asc" in st.session_state and st.session_state["single_asc_file_asc"] is not None:
953
+ trial_choices_single_asc, trials_by_ids, lines, asc_file = get_trials_list(
954
+ st.session_state["single_asc_file_asc"], close_gap_between_words=close_gap_between_words_single_asc
955
+ )
956
+ st.session_state["trials_by_ids"] = trials_by_ids
957
+ st.session_state["trial_choices"] = trial_choices_single_asc
958
+ st.session_state["lines"] = lines
959
+ st.session_state["asc_file"] = asc_file
960
+ if trial_choices_single_asc:
961
+ with single_file_tab_asc_tab.form(key="single_file_tab_asc_tab_trial_select_form"):
962
+ col_a1, col_a2 = st.columns((1, 2))
963
+ with col_a1:
964
+ trial_choice = st.selectbox(
965
+ "Which trial should be corrected?",
966
+ trial_choices_single_asc,
967
+ key="trial_id",
968
+ index=0,
969
+ )
970
+ with col_a2:
971
+ st.multiselect(
972
+ "Choose correction algorithm",
973
+ ALGO_CHOICES,
974
+ key="algo_choice",
975
+ default=[ALGO_CHOICES[0]],
976
+ )
977
+ process_trial_btn = st.form_submit_button("Load and correct trial")
978
+
979
+ if process_trial_btn:
980
+ single_file_tab_asc_tab.write(f'You selected: {st.session_state["trial_id"]}')
981
+ dffix, trial, dpi, screen_res, font, font_size = process_trial_choice(
982
+ trial_choice, st.session_state["algo_choice"]
983
+ )
984
+
985
+ st.session_state["dffix"] = dffix
986
+ st.session_state["trial"] = trial
987
+ st.session_state["dpi"] = dpi
988
+ st.session_state["screen_res"] = screen_res
989
+ st.session_state["font"] = font
990
+ st.session_state["font_size"] = font_size
991
+
992
+ export_csv(dffix, trial)
993
+
994
+ if "dffix" in st.session_state and "trial" in st.session_state:
995
+ df_expander_single = single_file_tab_asc_tab.expander("Show Dataframe", False)
996
+ plot_expander_single = single_file_tab_asc_tab.expander("Show Plots", True)
997
+ df_expander_single.dataframe(st.session_state["dffix"])
998
+
999
+ csv = convert_df(st.session_state["dffix"])
1000
+
1001
+ df_expander_single.download_button(
1002
+ "Download fixation dataframe",
1003
+ csv,
1004
+ f'{st.session_state["trial_id"]}.csv',
1005
+ "text/csv",
1006
+ key="download-csv-single",
1007
+ )
1008
+
1009
+ plotting_checkboxes_single = plot_expander_single.multiselect(
1010
+ "Select what gets plotted",
1011
+ ["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
1012
+ key="plotting_checkboxes_single",
1013
+ default=["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
1014
+ )
1015
+ scale_factor_single_asc = plot_expander_single.number_input(
1016
+ label="Scale factor for stimulus image", min_value=0.01, max_value=3.0, value=0.5, step=0.1
1017
+ )
1018
+ plot_expander_single.plotly_chart(
1019
+ plotly_plot_with_image(
1020
+ st.session_state["dffix"],
1021
+ st.session_state["trial"],
1022
+ to_plot_list=plotting_checkboxes_single,
1023
+ algo_choice=st.session_state["algo_choice"],
1024
+ scale_factor=scale_factor_single_asc,
1025
+ ),
1026
+ use_container_width=False,
1027
+ )
1028
+ plot_expander_single.plotly_chart(
1029
+ plot_y_corr(st.session_state["dffix"], st.session_state["algo_choice"]), use_container_width=True
1030
+ )
1031
+
1032
+ if "chars_list" in st.session_state["trial"]:
1033
+ analysis_expander_single_asc = single_file_tab_asc_tab.expander("Show Analysis results", True)
1034
+ use_corrected_fixations_tickbox = analysis_expander_single_asc.checkbox(
1035
+ "Use corrected",
1036
+ True,
1037
+ "use_corrected_fixations_tickbox",
1038
+ help="Whether to use the corrected or uncorrected fixations for the analysis.",
1039
+ )
1040
+ eyekit_tab, own_analysis_tab = analysis_expander_single_asc.tabs(
1041
+ ["Analysis using eyekit", "Analysis without eyekit"]
1042
+ )
1043
+ with eyekit_tab:
1044
+ st.markdown("Analysis powered by [eyekit](https://jwcarr.github.io/eyekit/)")
1045
+ st.markdown(
1046
+ "Please adjust parameters below to align fixations with stimulus using the sliders.Eyekit analysis is based on this alignment."
1047
+ )
1048
+ a_c1, a_c2, a_c3, a_c4, a_c5, a_c6 = st.columns(6)
1049
+ if "Consolas" in AVAILABLE_FONTS:
1050
+ font_index = AVAILABLE_FONTS.index("Consolas")
1051
+ elif "Courier New" in AVAILABLE_FONTS:
1052
+ font_index = AVAILABLE_FONTS.index("Courier New")
1053
+ elif "DejaVu Sans Mono" in AVAILABLE_FONTS:
1054
+ font_index = AVAILABLE_FONTS.index("DejaVu Sans Mono")
1055
+ else:
1056
+ font_index = 0
1057
+ font_face = a_c1.selectbox(
1058
+ label="Select Font",
1059
+ options=AVAILABLE_FONTS,
1060
+ index=font_index,
1061
+ key="font_face_for_eyekit_single_asc",
1062
+ )
1063
+ algo_choice_single_asc_eyekit = a_c1.selectbox(
1064
+ "Algorithm", st.session_state["algo_choice"], index=0, key="algo_choice_single_asc_eyekit"
1065
+ )
1066
+ sliders_on_tickbox = a_c6.checkbox(
1067
+ "Sliders", True, "single_asc_eyekit_sliders_checkbox", help="Turns sliders on and off"
1068
+ )
1069
+
1070
+ if "font_size_for_eyekit" not in st.session_state:
1071
+ (
1072
+ y_diff,
1073
+ x_txt_start,
1074
+ y_txt_start,
1075
+ _,
1076
+ _,
1077
+ line_height,
1078
+ ) = add_default_font_and_character_props_to_state(st.session_state["trial"])
1079
+ font_size = ut.set_font_from_chars_list(st.session_state["trial"])
1080
+ st.session_state["y_diff_for_eyekit"] = y_diff
1081
+ st.session_state["x_txt_start_for_eyekit"] = x_txt_start
1082
+ st.session_state["y_txt_start_for_eyekit"] = y_txt_start
1083
+ st.session_state["font_face_for_eyekit"] = font_face
1084
+ st.session_state["font_size_for_eyekit"] = font_size
1085
+ st.session_state["line_height_for_eyekit"] = line_height
1086
+ if sliders_on_tickbox:
1087
+ font_size = a_c2.select_slider(
1088
+ "Font Size",
1089
+ np.arange(5, 36, 0.25),
1090
+ st.session_state["font_size_for_eyekit"],
1091
+ key="font_size_for_eyekit_single_asc",
1092
+ )
1093
+ x_txt_start = a_c3.select_slider(
1094
+ "x",
1095
+ np.arange(300, 601, 1),
1096
+ round(st.session_state["x_txt_start_for_eyekit"]),
1097
+ key="x_txt_start_for_eyekit_single_asc",
1098
+ help="x coordinate of first character",
1099
+ )
1100
+ y_txt_start = a_c4.select_slider(
1101
+ "y",
1102
+ np.arange(100, 501, 1),
1103
+ round(st.session_state["y_txt_start_for_eyekit"]),
1104
+ key="y_txt_start_for_eyekit_single_asc",
1105
+ help="y coordinate of first character",
1106
+ )
1107
+ line_height = a_c5.select_slider(
1108
+ "Line height",
1109
+ np.arange(0, 151, 1),
1110
+ round(st.session_state["line_height_for_eyekit"]),
1111
+ key="line_height_for_eyekit_single_asc",
1112
+ )
1113
+ else:
1114
+ font_size = a_c2.number_input(
1115
+ "Font Size",
1116
+ None,
1117
+ None,
1118
+ round(st.session_state["font_size_for_eyekit"], ndigits=0),
1119
+ key="font_size_for_eyekit_single_asc",
1120
+ )
1121
+ x_txt_start = a_c3.number_input(
1122
+ "x",
1123
+ None,
1124
+ None,
1125
+ round(st.session_state["x_txt_start_for_eyekit"]),
1126
+ key="x_txt_start_for_eyekit_single_asc",
1127
+ help="x coordinate of first character",
1128
+ )
1129
+ y_txt_start = a_c4.number_input(
1130
+ "y",
1131
+ None,
1132
+ None,
1133
+ round(st.session_state["y_txt_start_for_eyekit"]),
1134
+ key="y_txt_start_for_eyekit_single_asc",
1135
+ help="y coordinate of first character",
1136
+ )
1137
+ line_height = a_c5.number_input(
1138
+ "Line height",
1139
+ None,
1140
+ None,
1141
+ round(st.session_state["line_height_for_eyekit"]),
1142
+ key="line_height_for_eyekit_single_asc",
1143
+ )
1144
+
1145
+ fixation_sequence, textblock, screen_size = ekm.get_fix_seq_and_text_block(
1146
+ st.session_state["dffix"],
1147
+ st.session_state["trial"],
1148
+ x_txt_start=st.session_state["x_txt_start_for_eyekit_single_asc"],
1149
+ y_txt_start=st.session_state["y_txt_start_for_eyekit_single_asc"],
1150
+ font_face=st.session_state["font_face_for_eyekit_single_asc"],
1151
+ font_size=st.session_state["font_size_for_eyekit_single_asc"],
1152
+ line_height=line_height,
1153
+ use_corrected_fixations=st.session_state["use_corrected_fixations_tickbox"],
1154
+ correction_algo=st.session_state["algo_choice_single_asc_eyekit"],
1155
+ )
1156
+ eyekitplot_img = ekm.eyekit_plot(textblock, fixation_sequence, screen_size)
1157
+ st.image(eyekitplot_img, "Fixations and stimulus as used for anaylsis")
1158
+
1159
+ with open(
1160
+ f'results/fixation_sequence_eyekit_{st.session_state["trial"]["trial_id"]}.json', "r"
1161
+ ) as f:
1162
+ fixation_sequence_json = json.load(f)
1163
+ fixation_sequence_json_str = json.dumps(fixation_sequence_json)
1164
+
1165
+ st.download_button(
1166
+ "Download fixations in eyekits format",
1167
+ fixation_sequence_json_str,
1168
+ f'fixation_sequence_eyekit_{st.session_state["trial"]["trial_id"]}.json',
1169
+ "json",
1170
+ key="download_eyekit_fix_json_single_asc",
1171
+ )
1172
+
1173
+ with open(f'results/textblock_eyekit_{st.session_state["trial"]["trial_id"]}.json', "r") as f:
1174
+ textblock_json = json.load(f)
1175
+ textblock_json_str = json.dumps(textblock_json)
1176
+
1177
+ st.download_button(
1178
+ "Download stimulus in eyekits format",
1179
+ textblock_json_str,
1180
+ f'textblock_eyekit_{st.session_state["trial"]["trial_id"]}.json',
1181
+ "json",
1182
+ key="download_eyekit_text_json_single_asc",
1183
+ )
1184
+
1185
+ word_measures_df, character_measures_df = get_eyekit_measures(
1186
+ textblock, fixation_sequence, get_char_measures=False
1187
+ )
1188
+
1189
+ st.dataframe(word_measures_df, use_container_width=True, hide_index=True)
1190
+ word_measures_df_csv = convert_df(word_measures_df)
1191
+
1192
+ word_measures_df_download_btn = st.download_button(
1193
+ "Download word measures data",
1194
+ word_measures_df_csv,
1195
+ f'{st.session_state["trial"]["trial_id"]}_word_measures_df.csv',
1196
+ "text/csv",
1197
+ key="word_measures_df_download_btn",
1198
+ )
1199
+ measure_words = st.selectbox(
1200
+ "Select measure to visualize", list(ekm.MEASURES_DICT.keys()), key="measure_words"
1201
+ )
1202
+ st.image(ekm.plot_with_measure(textblock, fixation_sequence, screen_size, measure_words))
1203
+ with own_analysis_tab:
1204
+ st.markdown(
1205
+ "This analysis method does not require manual alignment and works when the automated stimulus coordinates are correct."
1206
+ )
1207
+ own_word_measures = get_all_measures(
1208
+ st.session_state["trial"],
1209
+ st.session_state["dffix"],
1210
+ prefix="word",
1211
+ use_corrected_fixations=st.session_state["use_corrected_fixations_tickbox"],
1212
+ correction_algo=st.session_state["algo_choice_single_asc_eyekit"],
1213
+ )
1214
+ st.dataframe(own_word_measures, use_container_width=True, hide_index=True)
1215
+ own_word_measures_csv = convert_df(own_word_measures)
1216
+
1217
+ word_measures_df_download_btn = st.download_button(
1218
+ "Download word measures data",
1219
+ own_word_measures_csv,
1220
+ f'{st.session_state["trial"]["trial_id"]}_own_word_measures_df.csv',
1221
+ "text/csv",
1222
+ key="own_word_measures_df_download_btn",
1223
+ )
1224
+ fix_to_plot = (
1225
+ ["Corrected Fixations"]
1226
+ if st.session_state["use_corrected_fixations_tickbox"]
1227
+ else ["Uncorrected Fixations"]
1228
+ )
1229
+ own_word_measures_fig, desired_width_in_pixels, desired_height_in_pixels = matplotlib_plot_df(
1230
+ st.session_state["dffix"],
1231
+ st.session_state["trial"],
1232
+ st.session_state["algo_choice"],
1233
+ stimulus_prefix="word",
1234
+ box_annotations=own_word_measures[measure_words],
1235
+ fix_to_plot=fix_to_plot,
1236
+ )
1237
+ st.pyplot(own_word_measures_fig)
1238
+ else:
1239
+ single_file_tab_asc_tab.warning("🚨 Stimulus information needed for analysis 🚨")
1240
+
1241
+ single_file_tab_csv_tab.markdown(
1242
+ "#### Upload one .csv file for the fixations and one .json or .csv file for the stimulus information and select a trial. Then select a correction algorithm and plot/download the results"
1243
+ )
1244
+
1245
+ with single_file_tab_csv_tab.expander("Upload and preview data", expanded=True):
1246
+ csv_upl_col1, csv_upl_col2 = st.columns(2)
1247
+ single_csv_file = csv_upl_col1.file_uploader(
1248
+ "Select .csv file containing the fixation data",
1249
+ accept_multiple_files=False,
1250
+ key="single_csv_file",
1251
+ type={"csv", "txt", "dat"},
1252
+ )
1253
+ single_csv_stim_file = csv_upl_col2.file_uploader(
1254
+ "Select .csv or .json file containing the stimulus data",
1255
+ accept_multiple_files=False,
1256
+ key="single_csv_file_stim",
1257
+ type={"json", "csv", "txt", "dat"},
1258
+ )
1259
+
1260
+ if single_csv_file:
1261
+ st.session_state["dffix_single_csv"] = guess_col_names_fix(single_csv_file)
1262
+ if st.session_state["dffix_single_csv"] is not None:
1263
+ csv_upl_col1.dataframe(
1264
+ st.session_state["dffix_single_csv"], use_container_width=True, hide_index=True, height=200
1265
+ )
1266
+ if single_csv_stim_file:
1267
+ st.session_state["stimdf_single_csv"] = guess_col_names_stim(single_csv_stim_file)
1268
+ if ".json" in single_csv_stim_file.name:
1269
+ st.session_state["colnames_stim"] = st.session_state["stimdf_single_csv"].keys()
1270
+ else:
1271
+ st.session_state["colnames_stim"] = st.session_state["stimdf_single_csv"].columns
1272
+ if st.session_state["stimdf_single_csv"] is not None:
1273
+ if ".json" in single_csv_stim_file.name:
1274
+ csv_upl_col2.json(st.session_state["stimdf_single_csv"])
1275
+ else:
1276
+ csv_upl_col2.dataframe(
1277
+ st.session_state["stimdf_single_csv"], use_container_width=True, hide_index=True, height=200
1278
+ )
1279
+
1280
+ if single_csv_file and single_csv_stim_file:
1281
+ with single_file_tab_csv_tab.expander("Column names for csv files", expanded=True):
1282
+ with st.form("Column names for csv files"):
1283
+ st.markdown("### Please set column/key names for csv/json files")
1284
+ st.markdown("#### Fixation file column names:")
1285
+ c1, c2, c3 = st.columns(3)
1286
+ x_col_name_fix = c1.text_input("x coordinate", key="x_col_name_fix", value="x")
1287
+ y_col_name_fix = c2.text_input("y coordinate", key="y_col_name_fix", value="y")
1288
+ subject_col_name_fix = c1.text_input("subject id", key="subject_col_name_fix", value="sub_id")
1289
+ trial_id_col_name_fix = c3.text_input("trial id", key="trial_id_col_name_fix", value="trial_id")
1290
+ time_start_col_name_fix = c2.text_input(
1291
+ "fixation start time", key="time_start_col_name_fix", value="corrected_start_time"
1292
+ )
1293
+ time_stop_col_name_fix = c3.text_input(
1294
+ "fixation end time", key="time_stop_col_name_fix", value="corrected_end_time"
1295
+ )
1296
+ st.markdown("#### Stimulus file column/key names:")
1297
+ c1, c2, c3 = st.columns(3)
1298
+ x_col_name_fix_stim = c1.text_input("x coordinate", key="x_col_name_fix_stim", value="char_x_center")
1299
+ y_col_name_fix_stim = c2.text_input("y coordinate", key="y_col_name_fix_stim", value="char_y_center")
1300
+ x_start_col_name_fix_stim = c3.text_input(
1301
+ "x min of interest areas", key="x_start_col_name_fix_stim", value="char_xmin"
1302
+ )
1303
+ x_end_col_name_fix_stim = c1.text_input(
1304
+ "x max of interest areas", key="x_end_col_name_fix_stim", value="char_xmax"
1305
+ )
1306
+ y_start_col_name_fix_stim = c2.text_input(
1307
+ "y min of interest areas", key="y_start_col_name_fix_stim", value="char_ymin"
1308
+ )
1309
+ y_end_col_name_fix_stim = c3.text_input(
1310
+ "x max of interest areas", key="y_end_col_name_fix_stim", value="char_ymax"
1311
+ )
1312
+ char_col_name_fix_stim = c1.text_input(
1313
+ "content of interest area", key="char_col_name_fix_stim", value="char"
1314
+ )
1315
+ line_num_col_name_stim = c3.text_input(
1316
+ "line number for interest areas", key="line_num_col_name_stim", value="assigned_line"
1317
+ )
1318
+ subject_col_name_stim = c1.text_input("subject id", key="subject_col_name_stim", value="sub_id")
1319
+ trial_id_col_name_stim = c2.text_input("trial id", key="trial_id_col_name_stim", value="trial_id")
1320
+ has_multiple_subject = c2.checkbox("multiple subject in file", key="has_multiple_subject")
1321
+ form_submitted = st.form_submit_button("Confirm column/key names")
1322
+
1323
+
1324
+ if single_csv_file and single_csv_stim_file:
1325
+ process_custom_csvs_button = single_file_tab_csv_tab.button(
1326
+ "Load data from files",
1327
+ )
1328
+ if process_custom_csvs_button or "trial_keys_single_csv" in st.session_state:
1329
+ trials_by_ids, trial_keys = get_fixations_file_trials_list(
1330
+ st.session_state["dffix_single_csv"], st.session_state["stimdf_single_csv"]
1331
+ )
1332
+
1333
+ st.session_state["trials_by_ids_single_csv"] = trials_by_ids
1334
+ st.session_state["trial_keys_single_csv"] = trial_keys
1335
+ with single_file_tab_csv_tab.form(key="trial_selection_algo_selection_form_single_csv"):
1336
+ col_a1, col_a2 = st.columns((1, 2))
1337
+ with col_a1:
1338
+ trial_choice = st.selectbox(
1339
+ "Which trial should be corrected?",
1340
+ st.session_state["trial_keys_single_csv"],
1341
+ key="trial_id_selected_custom_csv",
1342
+ index=0,
1343
+ )
1344
+ with col_a2:
1345
+ algo_choice_single_csv = st.multiselect(
1346
+ "Choose correction algorithm",
1347
+ ALGO_CHOICES,
1348
+ key="algo_choice_single_csv",
1349
+ default=[ALGO_CHOICES[0]],
1350
+ )
1351
+ process_trial_btn = st.form_submit_button("Correct trial")
1352
+ if "trial_id_selected_custom_csv" in st.session_state and "algo_choice_single_csv" in st.session_state:
1353
+ trial = st.session_state["trials_by_ids_single_csv"][trial_choice]
1354
+ dffix, trial, dpi, screen_res, font, font_size = process_trial_choice_single_csv(
1355
+ trial, algo_choice_single_csv
1356
+ )
1357
+ st.session_state["trial_single_csv"] = trial
1358
+ csv = convert_df(dffix)
1359
+
1360
+ single_file_tab_csv_tab.download_button(
1361
+ "Download corrected fixation data",
1362
+ csv,
1363
+ f'{trial["trial_id"]}.csv',
1364
+ "text/csv",
1365
+ key="download-csv-custom-csv",
1366
+ )
1367
+ with single_file_tab_csv_tab.expander("Show corrected fixation data", expanded=True):
1368
+ st.dataframe(dffix, use_container_width=True, hide_index=True, height=200)
1369
+ with single_file_tab_csv_tab.expander("Show fixation plots", expanded=True):
1370
+ plotting_checkboxes_single_single_csv = st.multiselect(
1371
+ "Select what gets plotted",
1372
+ ["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
1373
+ key="plotting_checkboxes_single_single_csv",
1374
+ default=["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
1375
+ )
1376
+
1377
+ st.plotly_chart(
1378
+ plotly_plot_with_image(
1379
+ dffix,
1380
+ trial,
1381
+ to_plot_list=plotting_checkboxes_single_single_csv,
1382
+ algo_choice=algo_choice_single_csv,
1383
+ ),
1384
+ use_container_width=True,
1385
+ )
1386
+ st.plotly_chart(plot_y_corr(dffix, algo_choice_single_csv), use_container_width=True)
1387
+
1388
+
1389
+ multi_file_tab.subheader("Upload multiple .asc files. Then select a correction algorithm and download the results.")
1390
+
1391
+ with multi_file_tab.form("Upload files to be processed and select algorithm"):
1392
+ multifile_col, multi_algo_col = st.columns((1, 1))
1393
+
1394
+ with multifile_col:
1395
+ st.file_uploader(
1396
+ "Upload .asc Files", accept_multiple_files=True, key="multi_asc_filelist", type=["asc", "tar", "zip"]
1397
+ )
1398
+ with multi_algo_col:
1399
+ st.multiselect(
1400
+ "Choose correction algorithms",
1401
+ ALGO_CHOICES,
1402
+ key="algo_choice_multi",
1403
+ default=[ALGO_CHOICES[0]],
1404
+ )
1405
+ process_trial_btn_multi = st.form_submit_button("Load and correct asc files")
1406
+ if process_trial_btn_multi:
1407
+ get_trials_and_lines_from_asc_files(st.session_state["multi_asc_filelist"])
1408
+
1409
+ if "zipfiles_with_results" in st.session_state:
1410
+ multi_res_col1, multi_res_col2 = multi_file_tab.columns(2)
1411
+
1412
+ chosen_zip = multi_res_col1.selectbox("Choose results to download", st.session_state["zipfiles_with_results"])
1413
+ st.session_state["logger"].info(f"Download button for {chosen_zip}")
1414
+ st.session_state["logger"].info(st.session_state["zipfiles_with_results"])
1415
+ zipnamestem = pl.Path(chosen_zip).stem
1416
+ with open(chosen_zip, "rb") as f:
1417
+ multi_res_col2.download_button(f"Download {zipnamestem}", f, file_name=f"results_{zipnamestem}.zip")
1418
+
1419
+
1420
+ if "trial_choices_multi" in st.session_state:
1421
+ multi_plotting_options_col1, multi_plotting_options_col2 = multi_file_tab.columns(2)
1422
+
1423
+ trial_choice_multi = multi_plotting_options_col1.selectbox(
1424
+ "Which trial should be plotted?",
1425
+ st.session_state["trial_choices_multi"],
1426
+ key="trial_id_multi",
1427
+ placeholder="Select trial to display and plot",
1428
+ on_change=process_trial_choice_and_update_df_multi,
1429
+ )
1430
+
1431
+ plotting_checkboxes_multi = multi_plotting_options_col2.multiselect(
1432
+ "Select what gets plotted",
1433
+ ["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
1434
+ default=["Uncorrected Fixations", "Corrected Fixations", "Words", "Word boxes"],
1435
+ )
1436
+
1437
+ if trial_choice_multi and "dffix_multi" in st.session_state:
1438
+ df_expander_multi = multi_file_tab.expander("Show Dataframe", False)
1439
+ plot_expander_multi = multi_file_tab.expander("Show Plots", True)
1440
+
1441
+ df_expander_multi.dataframe(st.session_state["dffix_multi"])
1442
+ dffix_multi = st.session_state["dffix_multi"]
1443
+ trial_multi = st.session_state["trial_multi"]
1444
+
1445
+ plot_expander_multi.plotly_chart(
1446
+ plotly_plot_with_image(
1447
+ dffix_multi, trial_multi, st.session_state["algo_choice_multi"], to_plot_list=plotting_checkboxes_multi
1448
+ ),
1449
+ use_container_width=True,
1450
+ )
1451
+ plot_expander_multi.plotly_chart(
1452
+ plot_y_corr(dffix_multi, st.session_state["algo_choice_multi"]), use_container_width=True
1453
+ )
classic_correction_algos.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mostly adapted from https://github.com/jwcarr/eyekit/blob/350d055eecaa1581b03db5a847424825ffbb10f6/eyekit/_snap.py
3
+ """
4
+
5
+ import numpy as np
6
+ from sklearn.cluster import KMeans
7
+
8
+
9
+ def apply_classic_algo(
10
+ dffix,
11
+ trial,
12
+ algo="slice",
13
+ algo_params=dict(x_thresh=192, y_thresh=32, w_thresh=32, n_thresh=90),
14
+ ):
15
+ fixation_array = dffix.loc[:, ["x", "y"]].values
16
+ y_diff = trial["y_diff"]
17
+ if "y_char_unique" in trial:
18
+ midlines = trial["y_char_unique"]
19
+ else:
20
+ midlines = trial["y_midline"]
21
+ if len(midlines) == 1:
22
+ corrected_fix_y_vals = np.ones((fixation_array.shape[0])) * midlines[0]
23
+ elif fixation_array.shape[0] <= 2:
24
+ corrected_fix_y_vals = np.ones((fixation_array.shape[0])) * midlines[0]
25
+
26
+ else:
27
+ if algo == "slice":
28
+ corrected_fix_y_vals = slice(fixation_array, midlines, line_height=y_diff, **algo_params)
29
+ elif algo == "warp":
30
+ word_center_list = [(word["word_x_center"], word["word_y_center"]) for word in trial["words_list"]]
31
+ corrected_fix_y_vals = warp(fixation_array, word_center_list)
32
+ elif algo == "chain":
33
+ corrected_fix_y_vals = chain(fixation_array, midlines, **algo_params)
34
+ elif algo == "cluster":
35
+ corrected_fix_y_vals = cluster(fixation_array, midlines)
36
+ elif algo == "merge":
37
+ corrected_fix_y_vals = merge(fixation_array, midlines, **algo_params)
38
+ elif algo == "regress":
39
+ corrected_fix_y_vals = regress(fixation_array, midlines, **algo_params)
40
+ elif algo == "segment":
41
+ corrected_fix_y_vals = segment(fixation_array, midlines, **algo_params)
42
+ elif algo == "split":
43
+ corrected_fix_y_vals = split(fixation_array, midlines, **algo_params)
44
+ elif algo == "stretch":
45
+ corrected_fix_y_vals = stretch(fixation_array, midlines, **algo_params)
46
+ elif algo == "attach":
47
+ corrected_fix_y_vals = attach(fixation_array, midlines)
48
+ elif algo == "compare":
49
+ word_center_list = [(word["word_x_center"], word["word_y_center"]) for word in trial["words_list"]]
50
+ n_nearest_lines = min(algo_params["n_nearest_lines"], len(midlines) - 1)
51
+ algo_params["n_nearest_lines"] = n_nearest_lines
52
+ corrected_fix_y_vals = compare(fixation_array, np.array(word_center_list), **algo_params)
53
+ else:
54
+ raise NotImplementedError(f"{algo} not implemented")
55
+
56
+ corrected_line_nums = [trial["y_char_unique"].index(y) for y in corrected_fix_y_vals]
57
+ dffix[f"y_{algo}"] = corrected_fix_y_vals
58
+ dffix[f"line_num_{algo}"] = corrected_line_nums
59
+ return dffix
60
+
61
+
62
def slice(fixation_XY, midlines, line_height: float, x_thresh=192, y_thresh=32, w_thresh=32, n_thresh=90):
    """
    Adapted from Eyekit(https://github.com/jwcarr/eyekit/blob/350d055eecaa1581b03db5a847424825ffbb10f6/eyekit/_snap.py)
    implementation

    Form a set of runs and then reduce the set to *m* by repeatedly merging
    those that appear to be on the same line. Merged sequences are then
    assigned to text lines in positional order. Default params:
    `x_thresh=192`, `y_thresh=32`, `w_thresh=32`, `n_thresh=90`. Requires
    NumPy. Original method by [Glandorf & Schroeder (2021)](https://doi.org/10.1016/j.procs.2021.09.069).

    ``fixation_XY`` is an (n, 2) array-like of fixation coordinates,
    ``midlines`` the candidate line y positions, and ``line_height`` the
    vertical distance between adjacent text lines (used to place "phantom"
    proto lines when no run merges into an adjacent slot). Returns the
    corrected y column, one snapped value per fixation.
    """
    fixation_XY = np.array(fixation_XY, dtype=float)
    line_Y = np.array(midlines, dtype=float)
    proto_lines, phantom_proto_lines = {}, {}
    # 1. Segment runs: a run ends wherever consecutive fixations jump more
    # than x_thresh horizontally or y_thresh vertically.
    dist_X = abs(np.diff(fixation_XY[:, 0]))
    dist_Y = abs(np.diff(fixation_XY[:, 1]))
    end_run_indices = list(np.where(np.logical_or(dist_X > x_thresh, dist_Y > y_thresh))[0] + 1)
    run_starts = [0] + end_run_indices
    run_ends = end_run_indices + [len(fixation_XY)]
    runs = [list(range(start, end)) for start, end in zip(run_starts, run_ends)]
    # 2. Determine starting run: the one with the greatest horizontal extent.
    longest_run_i = np.argmax([fixation_XY[run[-1], 0] - fixation_XY[run[0], 0] for run in runs])
    proto_lines[0] = runs.pop(longest_run_i)
    # 3. Group runs into proto lines, growing outward from the starting run.
    while runs:
        merger_on_this_iteration = False
        for proto_line_i, direction in [(min(proto_lines), -1), (max(proto_lines), 1)]:
            # Create new proto line above or below (depending on direction)
            proto_lines[proto_line_i + direction] = []
            # Get current proto line XY coordinates (if proto line is empty, get phantom coordinates)
            if proto_lines[proto_line_i]:
                proto_line_XY = fixation_XY[proto_lines[proto_line_i]]
            else:
                proto_line_XY = phantom_proto_lines[proto_line_i]
            # Compute differences between current proto line and all runs
            # (mean vertical offset of each run's fixations from the closest
            # proto-line fixation at a similar x position).
            run_differences = np.zeros(len(runs))
            for run_i, run in enumerate(runs):
                y_diffs = [y - proto_line_XY[np.argmin(abs(proto_line_XY[:, 0] - x)), 1] for x, y in fixation_XY[run]]
                run_differences[run_i] = np.mean(y_diffs)
            # Find runs that can be merged into this proto line
            merge_into_current = list(np.where(abs(run_differences) < w_thresh)[0])
            # Find runs that can be merged into the adjacent proto line
            merge_into_adjacent = list(
                np.where(
                    np.logical_and(
                        run_differences * direction >= w_thresh,
                        run_differences * direction < n_thresh,
                    )
                )[0]
            )
            # Perform mergers
            for index in merge_into_current:
                proto_lines[proto_line_i].extend(runs[index])
            for index in merge_into_adjacent:
                proto_lines[proto_line_i + direction].extend(runs[index])
            # If no mergers to the adjacent, create phantom line for the adjacent
            # (a single synthetic point one line_height away).
            if not merge_into_adjacent:
                average_x, average_y = np.mean(proto_line_XY, axis=0)
                adjacent_y = average_y + line_height * direction
                phantom_proto_lines[proto_line_i + direction] = np.array([[average_x, adjacent_y]])
            # Remove all runs that were merged on this iteration
            # (reverse order so earlier indices stay valid while deleting).
            for index in sorted(merge_into_current + merge_into_adjacent, reverse=True):
                del runs[index]
                merger_on_this_iteration = True
        # If no mergers were made, break the while loop
        if not merger_on_this_iteration:
            break
    # 4. Assign any leftover runs to the closest proto lines
    for run in runs:
        best_pl_distance = np.inf
        best_pl_assignemnt = None
        for proto_line_i in proto_lines:
            if proto_lines[proto_line_i]:
                proto_line_XY = fixation_XY[proto_lines[proto_line_i]]
            else:
                proto_line_XY = phantom_proto_lines[proto_line_i]
            y_diffs = [y - proto_line_XY[np.argmin(abs(proto_line_XY[:, 0] - x)), 1] for x, y in fixation_XY[run]]
            pl_distance = abs(np.mean(y_diffs))
            if pl_distance < best_pl_distance:
                best_pl_distance = pl_distance
                best_pl_assignemnt = proto_line_i
        proto_lines[best_pl_assignemnt].extend(run)
    # 5. Prune proto lines until there are as many as there are text lines,
    # folding the smaller of the outermost proto lines into its neighbor.
    while len(proto_lines) > len(line_Y):
        top, bot = min(proto_lines), max(proto_lines)
        if len(proto_lines[top]) < len(proto_lines[bot]):
            proto_lines[top + 1].extend(proto_lines[top])
            del proto_lines[top]
        else:
            proto_lines[bot - 1].extend(proto_lines[bot])
            del proto_lines[bot]
    # 6. Map proto lines to text lines (both in top-to-bottom order).
    for line_i, proto_line_i in enumerate(sorted(proto_lines)):
        fixation_XY[proto_lines[proto_line_i], 1] = line_Y[line_i]
    return fixation_XY[:, 1]
158
+
159
+
160
def attach(fixation_XY, line_Y):
    """Snap every fixation to its nearest text line.

    ``fixation_XY`` (an (n, 2) array) is modified in place; the corrected
    y column is returned.
    """
    for fixation in fixation_XY:
        nearest_line = np.argmin(abs(line_Y - fixation[1]))
        fixation[1] = line_Y[nearest_line]
    return fixation_XY[:, 1]
166
+
167
+
168
def chain(fixation_XY, midlines, x_thresh=192, y_thresh=32):
    """
    Adapted from Eyekit(https://github.com/jwcarr/eyekit/blob/350d055eecaa1581b03db5a847424825ffbb10f6/eyekit/_snap.py)
    implementation

    Chain consecutive fixations that are sufficiently close to each other and
    snap each chain to the text line closest to its mean y value. Default
    params: `x_thresh=192`, `y_thresh=32`. Original method implemented in
    [popEye](https://github.com/sascha2schroeder/popEye/).
    """
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    # A chain breaks wherever consecutive fixations jump too far in x or y.
    jump_X = abs(np.diff(fixation_XY[:, 0]))
    jump_Y = abs(np.diff(fixation_XY[:, 1]))
    chain_breaks = list(np.where((jump_X > x_thresh) | (jump_Y > y_thresh))[0] + 1)
    chain_starts = [0] + chain_breaks
    chain_ends = chain_breaks + [len(fixation_XY)]
    for start, end in zip(chain_starts, chain_ends):
        chain_mean_y = fixation_XY[start:end, 1].mean()
        nearest_line = np.argmin(abs(line_Y - chain_mean_y))
        fixation_XY[start:end, 1] = line_Y[nearest_line]
    return fixation_XY[:, 1]
195
+
196
+
197
def cluster(fixation_XY, line_Y):
    """Assign fixations to lines by k-means clustering their y values.

    Clusters the fixation y values into as many groups as there are text
    lines, orders the clusters vertically, and maps each ordered cluster to
    the corresponding line. ``fixation_XY`` is modified in place; the
    corrected y column is returned. Requires scikit-learn.
    """
    n_lines = len(line_Y)
    fixation_Y = fixation_XY[:, 1].reshape(-1, 1)
    labels = KMeans(n_lines, n_init=100, max_iter=300).fit_predict(fixation_Y)
    cluster_means = [fixation_Y[labels == k].mean() for k in range(n_lines)]
    clusters_by_height = np.argsort(cluster_means)
    for fixation_i, label in enumerate(labels):
        line_i = np.where(clusters_by_height == label)[0][0]
        fixation_XY[fixation_i, 1] = line_Y[line_i]
    return fixation_XY[:, 1]
207
+
208
+
209
def compare(fixation_XY, word_XY, x_thresh=512, n_nearest_lines=3):
    # COMPARE
    #
    # Lima Sanches, C., Kise, K., & Augereau, O. (2015). Eye gaze and text
    # line matching for reading analysis. In Adjunct proceedings of the
    # 2015 ACM International Joint Conference on Pervasive and
    # Ubiquitous Computing and proceedings of the 2015 ACM International
    # Symposium on Wearable Computers (pp. 1227–1233). Association for
    # Computing Machinery.
    #
    # https://doi.org/10.1145/2800835.2807936
    #
    # Splits the fixation sequence into gaze lines at large leftward x jumps
    # (< -x_thresh), then assigns each gaze line to whichever of its
    # n_nearest_lines closest text lines gives the lowest DTW alignment cost
    # between the gaze-line x values and that text line's word x values.
    # NOTE(review): assumes n_nearest_lines <= number of distinct lines in
    # word_XY — the caller (apply_classic_algo) clamps it; verify if called
    # directly. Returns the corrected y column of fixation_XY.
    line_Y = np.unique(word_XY[:, 1])
    n = len(fixation_XY)
    diff_X = np.diff(fixation_XY[:, 0])
    end_line_indices = list(np.where(diff_X < -x_thresh)[0] + 1)
    end_line_indices.append(n)
    start_of_line = 0
    for end_of_line in end_line_indices:
        gaze_line = fixation_XY[start_of_line:end_of_line]
        mean_y = np.mean(gaze_line[:, 1])
        lines_ordered_by_proximity = np.argsort(abs(line_Y - mean_y))
        nearest_line_I = lines_ordered_by_proximity[:n_nearest_lines]
        line_costs = np.zeros(n_nearest_lines)
        for candidate_i in range(n_nearest_lines):
            candidate_line_i = nearest_line_I[candidate_i]
            text_line = word_XY[word_XY[:, 1] == line_Y[candidate_line_i]]
            # Compare x profiles only (column 0) via dynamic time warping.
            dtw_cost, dtw_path = dynamic_time_warping(gaze_line[:, 0:1], text_line[:, 0:1])
            line_costs[candidate_i] = dtw_cost
        line_i = nearest_line_I[np.argmin(line_costs)]
        fixation_XY[start_of_line:end_of_line, 1] = line_Y[line_i]
        start_of_line = end_of_line
    return fixation_XY[:, 1]
241
+
242
+
243
def merge(fixation_XY, midlines, text_right_to_left=False, y_thresh=32, gradient_thresh=0.1, error_thresh=20):
    """
    Form a set of progressive sequences and then reduce the set to *m* by
    repeatedly merging those that appear to be on the same line. Merged
    sequences are then assigned to text lines in positional order. Default
    params: `y_thresh=32`, `gradient_thresh=0.1`, `error_thresh=20`. Requires
    NumPy. Original method by [Špakov et al. (2019)](https://doi.org/10.3758/s13428-018-1120-x).

    Merging runs in four phases with progressively looser constraints on
    minimum sequence length and fit quality, stopping once the number of
    sequences equals the number of text lines. Returns the corrected y
    column of ``fixation_XY``.
    """
    try:
        import numpy as np
    except ModuleNotFoundError as e:
        e.msg = "The merge method requires NumPy."
        raise
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    diff_X = np.diff(fixation_XY[:, 0])
    dist_Y = abs(np.diff(fixation_XY[:, 1]))
    # A "progressive sequence" ends wherever reading direction reverses or
    # the vertical jump exceeds y_thresh.
    if text_right_to_left:
        sequence_boundaries = list(np.where(np.logical_or(diff_X > 0, dist_Y > y_thresh))[0] + 1)
    else:
        sequence_boundaries = list(np.where(np.logical_or(diff_X < 0, dist_Y > y_thresh))[0] + 1)
    sequence_starts = [0] + sequence_boundaries
    sequence_ends = sequence_boundaries + [len(fixation_XY)]
    sequences = [list(range(start, end)) for start, end in zip(sequence_starts, sequence_ends)]
    for min_i, min_j, remove_constraints in [
        (3, 3, False),  # Phase 1
        (1, 3, False),  # Phase 2
        (1, 1, False),  # Phase 3
        (1, 1, True),  # Phase 4
    ]:
        while len(sequences) > len(line_Y):
            best_merger = None
            best_error = np.inf
            # Try every pair of sequences; the merge candidate is scored by
            # the RMS error of a straight-line fit through the combined
            # fixations (a good merge lies on one nearly-horizontal line).
            for i in range(len(sequences) - 1):
                if len(sequences[i]) < min_i:
                    continue  # first sequence too short, skip to next i
                for j in range(i + 1, len(sequences)):
                    if len(sequences[j]) < min_j:
                        continue  # second sequence too short, skip to next j
                    candidate_XY = fixation_XY[sequences[i] + sequences[j]]
                    gradient, intercept = np.polyfit(candidate_XY[:, 0], candidate_XY[:, 1], 1)
                    residuals = candidate_XY[:, 1] - (gradient * candidate_XY[:, 0] + intercept)
                    error = np.sqrt(sum(residuals**2) / len(candidate_XY))
                    if remove_constraints or (abs(gradient) < gradient_thresh and error < error_thresh):
                        if error < best_error:
                            best_merger = (i, j)
                            best_error = error
            if best_merger is None:
                break  # no possible mergers, break while and move to next phase
            merge_i, merge_j = best_merger
            merged_sequence = sequences[merge_i] + sequences[merge_j]
            sequences.append(merged_sequence)
            # Delete j before i (j > i) so the first index stays valid.
            del sequences[merge_j], sequences[merge_i]
    # Assign the surviving sequences to text lines in vertical order.
    mean_Y = [fixation_XY[sequence, 1].mean() for sequence in sequences]
    ordered_sequence_indices = np.argsort(mean_Y)
    for line_i, sequence_i in enumerate(ordered_sequence_indices):
        fixation_XY[sequences[sequence_i], 1] = line_Y[line_i]
    return fixation_XY[:, 1]
301
+
302
+
303
def regress(
    fixation_XY,
    midlines,
    slope_bounds=(-0.1, 0.1),
    offset_bounds=(-50, 50),
    std_bounds=(1, 20),
):
    """
    Find *m* regression lines that best fit the fixations and group fixations
    according to best fit regression lines, and then assign groups to text
    lines in positional order. Default params: `slope_bounds=(-0.1, 0.1)`,
    `offset_bounds=(-50, 50)`, `std_bounds=(1, 20)`. Requires SciPy.
    Original method by [Cohen (2013)](https://doi.org/10.3758/s13428-012-0280-3).

    Returns one snapped y value per fixation: each fixation is assigned to
    the line whose fitted regression line explains it best.
    """
    try:
        import numpy as np
        from scipy.optimize import minimize
        from scipy.stats import norm
    except ModuleNotFoundError as e:
        e.msg = "The regress method requires SciPy."
        raise
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    # density[f, l] = log-likelihood of fixation f under line l's fit;
    # filled as a side effect of fit_lines so the final call below leaves
    # the best-fit densities in place for the argmax.
    density = np.zeros((len(fixation_XY), len(line_Y)))

    def fit_lines(params):
        # The three unconstrained optimizer params are squashed through the
        # normal CDF into (slope, offset, std) within their bounds.
        k = slope_bounds[0] + (slope_bounds[1] - slope_bounds[0]) * norm.cdf(params[0])
        o = offset_bounds[0] + (offset_bounds[1] - offset_bounds[0]) * norm.cdf(params[1])
        s = std_bounds[0] + (std_bounds[1] - std_bounds[0]) * norm.cdf(params[2])
        predicted_Y_from_slope = fixation_XY[:, 0] * k
        line_Y_plus_offset = line_Y + o
        for line_i in range(len(line_Y)):
            fit_Y = predicted_Y_from_slope + line_Y_plus_offset[line_i]
            density[:, line_i] = norm.logpdf(fixation_XY[:, 1], fit_Y, s)
        # Negative total log-likelihood under each fixation's best line.
        return -sum(density.max(axis=1))

    best_fit = minimize(fit_lines, [0, 0, 0], method="powell")
    # Re-evaluate at the optimum so density reflects the best-fit params.
    fit_lines(best_fit.x)
    return line_Y[density.argmax(axis=1)]
342
+
343
+
344
def segment(fixation_XY, midlines, text_right_to_left=False):
    """
    Segment fixation sequence into *m* subsequences based on *m*–1 most-likely
    return sweeps, and then assign subsequences to text lines in chronological
    order. Requires NumPy. Original method by
    [Abdulin & Komogortsev (2015)](https://doi.org/10.1109/BTAS.2015.7358786).
    """
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    x_steps = np.diff(fixation_XY[:, 0])
    steps_by_size = np.argsort(x_steps)
    # The m-1 most extreme backward jumps (forward for RTL text) are taken
    # to be the return sweeps between lines.
    if text_right_to_left:
        sweep_indices = steps_by_size[-(len(line_Y) - 1) :]
    else:
        sweep_indices = steps_by_size[: len(line_Y) - 1]
    line_i = 0
    for fixation_i in range(len(fixation_XY)):
        fixation_XY[fixation_i, 1] = line_Y[line_i]
        if fixation_i in sweep_indices:
            line_i += 1
    return fixation_XY[:, 1]
370
+
371
+
372
def split(fixation_XY, midlines, text_right_to_left=False):
    """
    Split fixation sequence into subsequences based on best candidate return
    sweeps, and then assign subsequences to closest text lines. Requires
    SciPy. Original method by [Carr et al. (2022)](https://doi.org/10.3758/s13428-021-01554-0).

    The x displacements between consecutive fixations are k-means clustered
    into two groups (within-line saccades vs. return sweeps); each resulting
    gaze line is snapped to the text line nearest its mean y value. Returns
    the corrected y column of ``fixation_XY``.
    """
    try:
        import numpy as np
        from scipy.cluster.vq import kmeans2
    except ModuleNotFoundError as e:
        e.msg = "The split method requires SciPy."
        raise
    fixation_XY = np.array(fixation_XY)
    line_Y = np.array(midlines)
    diff_X = np.array(np.diff(fixation_XY[:, 0]), dtype=float).reshape(-1, 1)
    centers, clusters = kmeans2(diff_X, 2, iter=100, minit="++", missing="raise")
    # Return sweeps are the large negative (LTR) or positive (RTL) x jumps.
    if text_right_to_left:
        sweep_marker = np.argmax(centers)
    else:
        sweep_marker = np.argmin(centers)
    end_line_indices = list(np.where(clusters == sweep_marker)[0] + 1)
    end_line_indices.append(len(fixation_XY))
    start_of_line = 0
    for end_of_line in end_line_indices:
        mean_y = np.mean(fixation_XY[start_of_line:end_of_line, 1])
        line_i = np.argmin(abs(line_Y - mean_y))
        # Bug fix: assign only the y column — the original wrote line_Y[line_i]
        # into BOTH columns, clobbering the x coordinates as well.
        fixation_XY[start_of_line:end_of_line, 1] = line_Y[line_i]
        start_of_line = end_of_line
    return fixation_XY[:, 1]
401
+
402
+
403
def stretch(fixation_XY, midlines, stretch_bounds=(0.9, 1.1), offset_bounds=(-50, 50)):
    """
    Find a stretch factor and offset that results in a good alignment between
    the fixations and lines of text, and then assign the transformed fixations
    to the closest text lines. Default params: `stretch_bounds=(0.9, 1.1)`,
    `offset_bounds=(-50, 50)`. Requires SciPy.
    Original method by [Lohmeier (2015)](http://www.monochromata.de/master_thesis/ma1.3.pdf).
    """
    try:
        import numpy as np
        from scipy.optimize import minimize
    except ModuleNotFoundError as e:
        e.msg = "The stretch method requires SciPy."
        raise
    fixation_Y = np.array(fixation_XY)[:, 1]
    line_Y = np.array(midlines)
    corrected_Y = np.zeros(len(fixation_Y))

    def alignment_error(params):
        # Transform y values, snap each to its closest line, and score by
        # total snapping distance. corrected_Y is filled as a side effect.
        scale, offset = params
        transformed_Y = fixation_Y * scale + offset
        for i, y in enumerate(transformed_Y):
            corrected_Y[i] = line_Y[np.argmin(abs(line_Y - y))]
        return np.sum(np.abs(transformed_Y - corrected_Y))

    best_fit = minimize(alignment_error, [1, 0], method="powell", bounds=[stretch_bounds, offset_bounds])
    # Re-evaluate at the optimum so corrected_Y holds the best alignment.
    alignment_error(best_fit.x)
    return corrected_Y
432
+
433
+
434
def warp(fixation_XY, word_center_list):
    """
    Map fixations to word centers using [Dynamic Time
    Warping](https://en.wikipedia.org/wiki/Dynamic_time_warping). This finds a
    monotonically increasing mapping between fixations and words with the
    shortest overall distance, effectively resulting in *m* subsequences.
    Fixations are then assigned to the lines that their mapped words belong
    to, effectively assigning subsequences to text lines in chronological
    order. Requires NumPy.
    Original method by [Carr et al. (2022)](https://doi.org/10.3758/s13428-021-01554-0).

    ``fixation_XY`` is an (n, 2) array-like of fixations; ``word_center_list``
    a sequence of (x, y) word centers. Returns the corrected y column.
    """
    try:
        import numpy as np
    except ModuleNotFoundError as e:
        e.msg = "The warp method requires NumPy."
        raise
    fixation_XY = np.array(fixation_XY)
    word_XY = np.array([word_center for word_center in word_center_list])
    n1 = len(fixation_XY)
    n2 = len(word_XY)
    # Standard DTW cost matrix with an extra padded row/column of inf.
    cost = np.zeros((n1 + 1, n2 + 1))
    cost[0, :] = np.inf
    cost[:, 0] = np.inf
    cost[0, 0] = 0
    for fixation_i in range(n1):
        for word_i in range(n2):
            distance = np.sqrt(sum((fixation_XY[fixation_i] - word_XY[word_i]) ** 2))
            cost[fixation_i + 1, word_i + 1] = distance + min(
                cost[fixation_i, word_i + 1],
                cost[fixation_i + 1, word_i],
                cost[fixation_i, word_i],
            )
    cost = cost[1:, 1:]
    warping_path = [[] for _ in range(n1)]
    # Trace the optimal path back from the bottom-right corner. (Bug fix:
    # the original relied on the leaked for-loop variables still holding
    # n1-1 / n2-1 here; make the starting indices explicit.)
    fixation_i, word_i = n1 - 1, n2 - 1
    while fixation_i > 0 or word_i > 0:
        warping_path[fixation_i].append(word_i)
        possible_moves = [np.inf, np.inf, np.inf]
        if fixation_i > 0 and word_i > 0:
            possible_moves[0] = cost[fixation_i - 1, word_i - 1]
        if fixation_i > 0:
            possible_moves[1] = cost[fixation_i - 1, word_i]
        if word_i > 0:
            possible_moves[2] = cost[fixation_i, word_i - 1]
        best_move = np.argmin(possible_moves)
        if best_move == 0:
            fixation_i -= 1
            word_i -= 1
        elif best_move == 1:
            fixation_i -= 1
        else:
            word_i -= 1
    warping_path[0].append(0)
    # Each fixation takes the most common y among its mapped words
    # (ties resolved by Python's max over the set, as in the original).
    for fixation_i, words_mapped_to_fixation_i in enumerate(warping_path):
        candidate_Y = list(word_XY[words_mapped_to_fixation_i, 1])
        fixation_XY[fixation_i, 1] = max(set(candidate_Y), key=candidate_Y.count)
    return fixation_XY[:, 1]
490
+
491
+
492
def dynamic_time_warping(sequence1, sequence2):
    """Compute the DTW alignment between two sequences of vectors.

    Parameters
    ----------
    sequence1, sequence2 : array-like of shape (n, d)
        Sequences of d-dimensional points (rows must support numpy
        arithmetic).

    Returns
    -------
    tuple
        ``(cost, path)`` where ``cost`` is the total alignment cost and
        ``path[i]`` lists the indices of ``sequence2`` aligned to element
        ``i`` of ``sequence1``.
    """
    n1 = len(sequence1)
    n2 = len(sequence2)
    # Cost matrix padded with an inf row/column so the recurrence can read
    # "out of bounds" neighbors uniformly.
    dtw_cost = np.zeros((n1 + 1, n2 + 1))
    dtw_cost[0, :] = np.inf
    dtw_cost[:, 0] = np.inf
    dtw_cost[0, 0] = 0
    for i in range(n1):
        for j in range(n2):
            this_cost = np.sqrt(sum((sequence1[i] - sequence2[j]) ** 2))
            dtw_cost[i + 1, j + 1] = this_cost + min(dtw_cost[i, j + 1], dtw_cost[i + 1, j], dtw_cost[i, j])
    dtw_cost = dtw_cost[1:, 1:]
    dtw_path = [[] for _ in range(n1)]
    # Trace back from the bottom-right corner. (Bug fix: the original relied
    # on the leaked for-loop variables i and j still holding n1-1 / n2-1.)
    i, j = n1 - 1, n2 - 1
    while i > 0 or j > 0:
        dtw_path[i].append(j)
        possible_moves = [np.inf, np.inf, np.inf]
        if i > 0 and j > 0:
            possible_moves[0] = dtw_cost[i - 1, j - 1]
        if i > 0:
            possible_moves[1] = dtw_cost[i - 1, j]
        if j > 0:
            possible_moves[2] = dtw_cost[i, j - 1]
        best_move = np.argmin(possible_moves)
        if best_move == 0:
            i -= 1
            j -= 1
        elif best_move == 1:
            i -= 1
        else:
            j -= 1
    dtw_path[0].append(0)
    return dtw_cost[-1, -1], dtw_path
524
+
525
+
526
def wisdom_of_the_crowd(assignments):
    """
    For each fixation, choose the y-value with the most votes across multiple
    algorithms. In the event of a tie, the left-most algorithm is given
    priority.
    """
    stacked = np.column_stack(assignments)
    correction = []
    for votes in stacked:
        vote_list = list(votes)
        tallies = {y: vote_list.count(y) for y in set(vote_list)}
        top_count = max(tallies.values())
        winners = {y for y, count in tallies.items() if count == top_count}
        if len(winners) == 1:
            correction.append(next(iter(winners)))
        else:
            # Tie: the first (left-most) algorithm's vote among the winners.
            correction.append(next(y for y in votes if y in winners))
    return correction
eyekit_measures.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import eyekit as ek
3
+ import numpy as np
4
+ import pandas as pd
5
+ from PIL import Image
6
+
7
+
8
# Template for collecting per-interest-area reading measures. Each key names
# a function in `eyekit.measure`; the lists are filled with one value per
# word (or character) in get_eyekit_measures, which deep-copies this dict.
MEASURES_DICT = {
    "number_of_fixations": [],
    "initial_fixation_duration": [],
    "first_of_many_duration": [],
    "total_fixation_duration": [],
    "gaze_duration": [],
    "go_past_duration": [],
    "second_pass_duration": [],
    "initial_landing_position": [],
    "initial_landing_distance": [],
    "landing_distances": [],
    "number_of_regressions_in": [],
}
21
+
22
+
23
def get_fix_seq_and_text_block(
    dffix,
    trial,
    x_txt_start=None,
    y_txt_start=None,
    font_face="Courier New",
    font_size=None,
    line_height=None,
    use_corrected_fixations=True,
    correction_algo="warp",
):
    """Build an eyekit FixationSequence and TextBlock for one trial.

    Parameters
    ----------
    dffix : pandas.DataFrame with columns "x", "y", "corrected_start_time",
        "corrected_end_time" and, when a correction was applied,
        f"y_{correction_algo}".
    trial : dict-like with "chars_list", "line_heights", "trial_id",
        "filename" and optionally "display_coords", "font", "font_size".
    x_txt_start, y_txt_start : optional text-origin overrides; defaults come
        from the first character's bounding box.
    font_face, font_size, line_height : rendering parameters; missing values
        are filled from the trial or derived from the minimum line height.
    use_corrected_fixations : if True, use the algorithm-corrected y values.
    correction_algo : suffix of the corrected-y column, e.g. "warp".

    Returns
    -------
    (fixation_sequence, textblock, screen_size) on success, or the unchanged
    ``dffix`` if eyekit rejects the fixation tuples.

    Side effects: saves the fixation sequence and text block as JSON files
    under results/.
    """
    # Fixations whose corrected end time is not after the start time get a
    # 1 ms duration (and the uncorrected y) so eyekit accepts them.
    if use_corrected_fixations and correction_algo is not None:
        fixations_tuples = [
            (
                (x[1]["x"], x[1][f"y_{correction_algo}"], x[1]["corrected_start_time"], x[1]["corrected_end_time"])
                if x[1]["corrected_start_time"] < x[1]["corrected_end_time"]
                else (x[1]["x"], x[1]["y"], x[1]["corrected_start_time"], x[1]["corrected_end_time"] + 1)
            )
            for x in dffix.iterrows()
        ]
    else:
        fixations_tuples = [
            (
                (x[1]["x"], x[1]["y"], x[1]["corrected_start_time"], x[1]["corrected_end_time"])
                if x[1]["corrected_start_time"] < x[1]["corrected_end_time"]
                else (x[1]["x"], x[1]["y"], x[1]["corrected_start_time"], x[1]["corrected_end_time"] + 1)
            )
            for x in dffix.iterrows()
        ]

    try:
        fixation_sequence = ek.FixationSequence(fixations_tuples)
    except Exception as e:
        print(e)
        print(f"Creating fixation failed for {trial['trial_id']} {trial['filename']}")
        return dffix

    if "display_coords" in trial:
        display_coords = trial["display_coords"]
    else:
        display_coords = (0, 0, 1920, 1080)
    screen_size = ((display_coords[2] - display_coords[0]), (display_coords[3] - display_coords[1]))

    # Use the smallest observed line height as the canonical line spacing.
    y_diffs = np.unique(trial["line_heights"])
    if len(y_diffs) == 1:
        y_diff = y_diffs[0]
    else:
        y_diff = np.min(y_diffs)
    chars_list = trial["chars_list"]
    max_line = int(chars_list[-1]["assigned_line"])
    words_on_lines = {x: [] for x in range(int(max_line) + 1)}
    # Group characters by their assigned line. (Previously a side-effect
    # list comprehension; a plain loop states the intent.)
    for char_entry in chars_list:
        words_on_lines[char_entry["assigned_line"]].append(char_entry["char"])
    sentence_list = ["".join(chars) for chars in words_on_lines.values()]

    if x_txt_start is None:
        x_txt_start = float(chars_list[0]["char_xmin"])
    if y_txt_start is None:
        y_txt_start = float(chars_list[0]["char_ymax"])

    # NOTE: these fallbacks only trigger when the caller explicitly passes
    # font_face=None (the parameter default is "Courier New").
    if font_face is None and "font" in trial:
        font_face = trial["font"]
    elif font_face is None:
        font_face = "DejaVu Sans Mono"

    if font_size is None and "font_size" in trial:
        font_size = trial["font_size"]
    elif font_size is None:
        font_size = float(y_diff * 0.333)  # pixel to point conversion
    if line_height is None:
        line_height = float(y_diff)
    textblock = ek.TextBlock(
        sentence_list,
        position=(float(x_txt_start), float(y_txt_start)),
        font_face=font_face,
        line_height=line_height,
        font_size=font_size,
        anchor="left",
        align="left",
    )

    # eyekit_plot(textblock, fixation_sequence, screen_size)
    ek.io.save(fixation_sequence, f'results/fixation_sequence_eyekit_{trial["trial_id"]}.json', compress=False)
    ek.io.save(textblock, f'results/textblock_eyekit_{trial["trial_id"]}.json', compress=False)

    return fixation_sequence, textblock, screen_size
108
+
109
+
110
def eyekit_plot(textblock, fixation_sequence, screen_size):
    """Render the text block, word bounding boxes (hotpink) and fixation
    sequence to a temporary PNG and return it as a PIL image."""
    canvas = ek.vis.Image(*screen_size)
    canvas.draw_text_block(textblock)
    for word_box in textblock.words():
        canvas.draw_rectangle(word_box, color="hotpink")
    canvas.draw_fixation_sequence(fixation_sequence)
    canvas.save("temp_eyekit_img.png", crop_margin=200)
    return Image.open("temp_eyekit_img.png")
119
+
120
+
121
def plot_with_measure(textblock, fixation_sequence, screen_size, measure, use_characters=False):
    """Render the text block annotated with a per-word (or per-character)
    eyekit reading measure and return the result as a PIL image.

    Parameters
    ----------
    textblock : eyekit TextBlock.
    fixation_sequence : eyekit FixationSequence.
    screen_size : (width, height) of the drawing canvas.
    measure : name of a function in ``eyekit.measure``.
    use_characters : annotate characters instead of words when True.

    Side effects: writes "multiline_passage_piccol.png" to disk.
    """
    # Fix: a prior call to eyekit_plot() whose result was immediately
    # overwritten has been removed — it only wasted a render and wrote an
    # unused temp image.
    eyekitplot_img = ek.vis.Image(*screen_size)
    eyekitplot_img.draw_text_block(textblock)
    if use_characters:
        measure_results = getattr(ek.measure, measure)(textblock.characters(), fixation_sequence)
        enum = textblock.characters()
    else:
        measure_results = getattr(ek.measure, measure)(textblock.words(), fixation_sequence)
        enum = textblock.words()
    for unit in enum:
        eyekitplot_img.draw_rectangle(unit, color="lightseagreen")
        # Place the label just above the unit's bottom-right corner.
        x = unit.onset
        y = unit.y_br - 3
        label = f"{measure_results[unit.id]}"
        eyekitplot_img.draw_annotation((x, y), label, color="lightseagreen", font_face="Arial bold", font_size=15)
    eyekitplot_img.draw_fixation_sequence(fixation_sequence, color="gray")
    eyekitplot_img.save("multiline_passage_piccol.png", crop_margin=100)
    img_png = Image.open("multiline_passage_piccol.png")
    return img_png
142
+
143
+
144
def get_eyekit_measures(_txt, _seq, get_char_measures=False):
    """Compute all eyekit reading measures in MEASURES_DICT for every word
    (and optionally every character) of a text block.

    Returns (word_measures_df, character_measures_df); the second element is
    None unless ``get_char_measures`` is True. Both frames lead with the
    index column ("word_number" / "char_number") and the text column.
    """
    word_measures = copy.deepcopy(MEASURES_DICT)
    words = []
    for word in _txt.words():
        words.append(word.text)
        for measure_name in word_measures:
            word_measures[measure_name].append(getattr(ek.measure, measure_name)(word, _seq))
    word_measures_df = pd.DataFrame(word_measures)
    word_measures_df["word_number"] = np.arange(0, len(words))
    word_measures_df["word"] = words

    # Lead with the identifier columns, keeping measures in dict order.
    word_measures_df = word_measures_df[["word_number", "word", *word_measures.keys()]]

    if get_char_measures:
        char_measures = copy.deepcopy(MEASURES_DICT)

        characters = []
        for char in _txt.characters():
            characters.append(char.text)
            for measure_name in char_measures:
                char_measures[measure_name].append(getattr(ek.measure, measure_name)(char, _seq))
        character_measures_df = pd.DataFrame(char_measures)
        character_measures_df["char_number"] = np.arange(0, len(characters))
        character_measures_df["character"] = characters

        character_measures_df = character_measures_df[["char_number", "character", *char_measures.keys()]]
    else:
        character_measures_df = None
    return word_measures_df, character_measures_df
loss_functions.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as t
2
+
3
+
4
def macro_soft_f1(real_vals, predictions, reduction):
    """Soft (differentiable) macro-F1 loss: 1 - soft F1 per class.

    Based on https://towardsdatascience.com/the-unknown-benefits-of-using-a-soft-f1-loss-in-classification-systems-753902c0105d

    Parameters
    ----------
    real_vals : 0/1 target tensor, shape (batch, num_classes).
    predictions : predicted probabilities, same shape.
    reduction : "mean" averages over classes; any other value returns the
        per-class loss vector.
    """
    tp = (real_vals * predictions).sum(dim=0)
    fp = (predictions * (1 - real_vals)).sum(dim=0)
    fn = ((1 - predictions) * real_vals).sum(dim=0)
    # Epsilon keeps the ratio defined when a class has no support at all.
    soft_f1 = 2 * tp / (2 * tp + fn + fp + 1e-16)
    per_class_loss = 1 - soft_f1
    return t.mean(per_class_loss) if reduction == "mean" else per_class_loss
15
+
16
+
17
def coral_loss(logits, levels, importance_weights=None, reduction="mean"):
    """CORAL ordinal-regression loss.

    Cao, Mirjalili, and Raschka (2020), *Rank Consistent Ordinal Regression
    for Neural Networks with Application to Age Estimation*, Pattern
    Recognition Letters, https://doi.org/10.1016/j.patrec.2020.11.008
    Adapted from
    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/losses.py

    Parameters
    ----------
    logits : torch.tensor, shape (num_examples, num_classes-1)
        Outputs of the CORAL layer.
    levels : torch.tensor, same shape
        True labels as extended binary vectors.
    importance_weights : torch.tensor, shape (num_classes-1,), optional
        Per-threshold weights; None is equivalent to all-ones.
    reduction : "mean", "sum", or None
        How to reduce the per-example losses.

    Returns
    -------
    torch.tensor : scalar loss ("mean"/"sum") or per-example vector (None).

    Raises
    ------
    ValueError : on shape mismatch or an unknown ``reduction``.
    """
    if logits.shape != levels.shape:
        raise ValueError(
            "Please ensure that logits (%s) has the same shape as levels (%s). " % (logits.shape, levels.shape)
        )

    # log(sigmoid(z)) for positive thresholds; log(1-sigmoid(z)) = logsigmoid(z)-z
    # for negative ones — numerically stable form of binary cross-entropy.
    log_sig = t.nn.functional.logsigmoid(logits)
    term1 = log_sig * levels + (log_sig - logits) * (1 - levels)

    if importance_weights is not None:
        term1 *= importance_weights

    per_example = -t.sum(term1, dim=1)

    if reduction == "mean":
        return t.mean(per_example)
    if reduction == "sum":
        return t.sum(per_example)
    if reduction is None:
        return per_example
    raise ValueError('Invalid value for `reduction`. Should be "mean", ' '"sum", or None. Got %s' % reduction)
84
+
85
+
86
def corn_loss(logits, y_train, num_classes):
    """CORN (Conditional Ordinal Regression for Neural networks) loss.

    Implements the conditional training-set scheme: threshold task k is
    trained only on examples with label > k-1, with binary target
    "label > k". Adapted from
    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/losses.py

    Parameters
    ----------
    logits : torch.tensor, shape (num_examples, num_classes-1)
        Outputs of the CORN layer.
    y_train : torch.tensor, shape (num_examples,)
        Integer class labels starting at 0.
    num_classes : int
        Number of distinct class labels.

    Returns
    -------
    torch.tensor : scalar loss averaged over all conditional examples.
    """
    total_loss = 0.0
    total_examples = 0
    for task_index in range(num_classes - 1):
        # Conditional subset: only examples whose label exceeds task_index-1.
        mask = y_train > task_index - 1
        labels = (y_train[mask] > task_index).to(t.int64)

        if len(labels) < 1:
            continue

        total_examples += len(labels)
        pred = logits[mask, task_index]

        # Stable binary cross-entropy on the raw logits.
        log_sig = t.nn.functional.logsigmoid(pred)
        total_loss += -t.sum(log_sig * labels + (log_sig - pred) * (1 - labels))

    return total_loss / total_examples
151
+
152
+
153
def corn_label_from_logits(logits):
    """Decode predicted rank labels from CORN logits.

    Each column of ``logits`` is a conditional threshold; the unconditional
    probability of exceeding threshold k is the cumulative product of the
    sigmoids up to k. The predicted rank is the number of thresholds whose
    probability exceeds 0.5. Adapted from
    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/dataset.py

    Parameters
    ----------
    logits : torch.tensor, shape (n_examples, n_classes-1)

    Returns
    -------
    torch.tensor, shape (n_examples,) : integer rank labels.
    """
    cumulative_probas = t.cumprod(t.sigmoid(logits), dim=1)
    return t.sum(cumulative_probas > 0.5, dim=1)
models.py ADDED
@@ -0,0 +1,897 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import os
3
+ from typing import Any
4
+ from pytorch_lightning.utilities.types import LRSchedulerTypeUnion
5
+ import torch as t
6
+ from torch import nn
7
+ import numpy as np
8
+ import transformers
9
+ import pytorch_lightning as plight
10
+ import torchmetrics
11
+ import einops as eo
12
+ from loss_functions import coral_loss, corn_loss, corn_label_from_logits, macro_soft_f1
13
+
14
+ t.set_float32_matmul_precision("medium")
15
+ global_settings = dict(try_using_torch_compile=False)
16
+
17
+
18
class EnsembleModel(plight.LightningModule):
    """Ensemble over two groups of trained models: one group fed the
    unnormalized batch, the other the normalized batch.

    Outputs are merged either through a learned linear combiner over the
    concatenated per-model class outputs, or by a simple average when
    ``use_simple_average=True``.
    """

    def __init__(self, models_without_norm_df, models_with_norm_df, learning_rate=0.0002, use_simple_average=False):
        super().__init__()
        # Wrap the member models in ModuleLists so their parameters register.
        self.models_without_norm = nn.ModuleList(list(models_without_norm_df))
        self.models_with_norm = nn.ModuleList(list(models_with_norm_df))
        self.learning_rate = learning_rate
        self.use_simple_average = use_simple_average

        if not self.use_simple_average:
            # Learned combiner: concatenated outputs of every member -> one
            # set of class logits (assumes all members share num_classes).
            self.combiner = nn.Linear(
                self.models_with_norm[0].num_classes * (len(self.models_with_norm) + len(self.models_without_norm)),
                self.models_with_norm[0].num_classes,
            )

    def forward(self, x):
        # x is a pair: (batch for unnormalized models, batch for normalized
        # models). x_unnormed[-1] is assumed to hold the targets — TODO confirm.
        x_unnormed, x_normed = x
        if not self.use_simple_average:
            out_unnormed = t.cat([model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm], dim=-1)
            out_normed = t.cat([model.model_step(x_normed, 0)[0] for model in self.models_with_norm], dim=-1)
            out_avg = self.combiner(t.cat((out_unnormed, out_normed), dim=-1))
        else:
            out_unnormed = [model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm]
            out_normed = [model.model_step(x_normed, 0)[0] for model in self.models_with_norm]

            # NOTE(review): .mean(-1) over the stacked outputs already averages
            # across members, so the extra division by 2 halves the result —
            # confirm this is intentional (argmax-based preds are unaffected).
            out_avg = (t.stack(out_unnormed + out_normed, dim=-1) / 2).mean(-1)
        return {"out_avg": out_avg, "out_unnormed": out_unnormed, "out_normed": out_normed}, x_unnormed[-1]

    def training_step(self, batch, batch_idx):
        out, y = self(batch)
        # Reuse the first normalized member's loss on the ensemble output.
        loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0])
        self.log("train_loss", loss, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        out, y = self(batch)
        # Delegate pred/target post-processing to the first normalized member.
        preds, y_onecold, ignore_index_val = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y)
        acc = torchmetrics.functional.accuracy(
            preds,
            y_onecold.to(t.long),
            ignore_index=ignore_index_val,
            num_classes=self.models_with_norm[0].num_classes,
            task="multiclass",
        )
        self.log("acc", acc * 100, prog_bar=True, sync_dist=True)
        loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0])
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        # Returns (class predictions, raw output dict, one-cold targets).
        out, y = self(batch)
        preds, y_onecold, ignore_index_val = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y)
        return preds, out, y_onecold

    def configure_optimizers(self):
        # Plain Adam over all ensemble parameters (members + combiner).
        return t.optim.Adam(self.parameters(), lr=self.learning_rate)
73
+
74
+
75
class TimmHeadReplace(nn.Module):
    """Replacement head for a timm backbone: optional 2D adaptive pooling
    followed by a flatten, or a pure pass-through (``all_identity=True``)."""

    def __init__(self, pooling=None, in_channels=512, pooling_output_dimension=1, all_identity=False) -> None:
        super().__init__()

        if all_identity:
            # Pass features through unchanged (used to strip timm's own head).
            self.head = nn.Identity()
            self.pooling = None
        else:
            self.pooling = pooling
            if pooling is not None:
                self.pooling_output_dimension = pooling_output_dimension
                if self.pooling == "AdaptiveAvgPool2d":
                    self.pooling_layer = nn.AdaptiveAvgPool2d(pooling_output_dimension)
                elif self.pooling == "AdaptiveMaxPool2d":
                    self.pooling_layer = nn.AdaptiveMaxPool2d(pooling_output_dimension)
            self.head = nn.Flatten()

    def forward(self, x, pre_logits=False):
        # NOTE(review): `pre_logits` mirrors the timm head API but is unused here.
        if self.pooling is not None:
            if self.pooling == "stack_avg_max_attn":
                # NOTE(review): this branch iterates self.pooling_layer as a
                # collection, but __init__ never builds pooling_layer for this
                # value — it would raise AttributeError if reached; confirm
                # whether a caller assigns it externally.
                x = t.cat([layer(x) for layer in self.pooling_layer], dim=-1)
            else:
                x = self.pooling_layer(x)
        return self.head(x)
99
+
100
+
101
class CVModel(nn.Module):
    """Wraps a pretrained timm vision backbone to predict a per-fixation
    sequence of outputs from a plotted trial image.

    The backbone's feature vector is mapped to ``max_seq_length`` values,
    then each value is projected to ``num_classes`` logits (or kept scalar
    for the ordinal-regression loss).
    """

    def __init__(
        self,
        modelname,
        in_shape,
        num_classes,
        loss_func,
        last_activation: str,
        input_padding_val=10,
        char_dims=2,
        max_seq_length=1000,
    ) -> None:
        super().__init__()
        self.modelname = modelname
        self.loss_func = loss_func
        self.in_shape = in_shape
        self.char_dims = char_dims
        self.x_shape = in_shape
        self.last_activation = last_activation
        self.max_seq_length = max_seq_length
        self.num_classes = num_classes
        # OrdinalRegLoss regresses a single scalar per position instead of
        # per-class logits.
        if self.loss_func == "OrdinalRegLoss":
            self.out_shape = 1
        else:
            self.out_shape = num_classes

        # num_classes=0 asks timm for the headless feature extractor.
        self.cv_model = timm.create_model(modelname, pretrained=True, num_classes=0)
        self.cv_model.classifier = nn.Identity()
        # Probe the backbone once to discover its feature dimension.
        with t.inference_mode():
            test_out = self.cv_model(t.ones(self.in_shape, dtype=t.float32))
        self.cv_model_out_dim = test_out.shape[1]
        # NOTE(review): `classifier` is attached as an attribute of cv_model
        # but applied manually in forward(); whether timm's own forward also
        # applies it depends on the architecture — confirm for `modelname`.
        self.cv_model.classifier = nn.Sequential(nn.Flatten(), nn.Linear(self.cv_model_out_dim, self.max_seq_length))
        if self.out_shape == 1:
            self.logit_norm = nn.Identity()
            self.out_project = nn.Identity()
        else:
            self.logit_norm = nn.LayerNorm(self.max_seq_length)
            self.out_project = nn.Linear(1, self.out_shape)

        if last_activation == "Softmax":
            self.final_activation = nn.Softmax(dim=-1)
        elif last_activation == "Sigmoid":
            self.final_activation = nn.Sigmoid()
        elif last_activation == "LogSigmoid":
            self.final_activation = nn.LogSigmoid()
        elif last_activation == "Identity":
            self.final_activation = nn.Identity()
        else:
            raise NotImplementedError(f"{last_activation} not implemented")

    def forward(self, x):
        # Some dataloaders hand over [image, ...]; only the image is used.
        if isinstance(x, list):
            x = x[0]
        x = self.cv_model(x)
        # Map features to one value per sequence position, then project each
        # position to the output shape.
        x = self.cv_model.classifier(x).unsqueeze(-1)
        x = self.out_project(x)
        return self.final_activation(x)
158
+
159
+
160
+ class LitModel(plight.LightningModule):
161
+ def __init__(
162
+ self,
163
+ in_shape: tuple,
164
+ hidden_dim: int,
165
+ num_attention_heads: int,
166
+ num_layers: int,
167
+ loss_func: str,
168
+ learning_rate: float,
169
+ weight_decay: float,
170
+ cfg: dict,
171
+ use_lr_warmup: bool,
172
+ use_reduce_on_plateau: bool,
173
+ track_gradient_histogram=False,
174
+ register_forw_hook=False,
175
+ char_dims=2,
176
+ ) -> None:
177
+ super().__init__()
178
+ if "only_use_2nd_input_stream" not in cfg:
179
+ cfg["only_use_2nd_input_stream"] = False
180
+
181
+ if "gamma_step_size" not in cfg:
182
+ cfg["gamma_step_size"] = 5
183
+ if "gamma_step_factor" not in cfg:
184
+ cfg["gamma_step_factor"] = 0.5
185
+ self.save_hyperparameters(
186
+ dict(
187
+ in_shape=in_shape,
188
+ hidden_dim=hidden_dim,
189
+ num_attention_heads=num_attention_heads,
190
+ num_layers=num_layers,
191
+ loss_func=loss_func,
192
+ learning_rate=learning_rate,
193
+ cfg=cfg,
194
+ x_shape=in_shape,
195
+ num_classes=cfg["num_classes"],
196
+ use_lr_warmup=use_lr_warmup,
197
+ num_warmup_steps=cfg["num_warmup_steps"],
198
+ use_reduce_on_plateau=use_reduce_on_plateau,
199
+ weight_decay=weight_decay,
200
+ track_gradient_histogram=track_gradient_histogram,
201
+ register_forw_hook=register_forw_hook,
202
+ char_dims=char_dims,
203
+ remove_timm_classifier_head_pooling=cfg["remove_timm_classifier_head_pooling"],
204
+ change_pooling_for_timm_head_to=cfg["change_pooling_for_timm_head_to"],
205
+ chars_conv_pooling_out_dim=cfg["chars_conv_pooling_out_dim"],
206
+ )
207
+ )
208
+ self.model_to_use = cfg["model_to_use"]
209
+ self.num_classes = cfg["num_classes"]
210
+ self.x_shape = in_shape
211
+ self.in_shape = in_shape
212
+ self.hidden_dim = hidden_dim
213
+ self.num_attention_heads = num_attention_heads
214
+ self.num_layers = num_layers
215
+
216
+ self.use_lr_warmup = use_lr_warmup
217
+ self.num_warmup_steps = cfg["num_warmup_steps"]
218
+ self.warmup_exponent = cfg["warmup_exponent"]
219
+
220
+ self.use_reduce_on_plateau = use_reduce_on_plateau
221
+ self.loss_func = loss_func
222
+ self.learning_rate = learning_rate
223
+ self.weight_decay = weight_decay
224
+ self.using_one_hot_targets = cfg["one_hot_y"]
225
+ self.track_gradient_histogram = track_gradient_histogram
226
+ self.register_forw_hook = register_forw_hook
227
+ if self.loss_func == "OrdinalRegLoss":
228
+ self.ord_reg_loss_max = cfg["ord_reg_loss_max"]
229
+ self.ord_reg_loss_min = cfg["ord_reg_loss_min"]
230
+
231
+ self.num_lin_layers = cfg["num_lin_layers"]
232
+ self.linear_activation = cfg["linear_activation"]
233
+ self.last_activation = cfg["last_activation"]
234
+
235
+ self.max_seq_length = cfg["manual_max_sequence_for_model"]
236
+
237
+ self.use_char_embed_info = cfg["use_embedded_char_pos_info"]
238
+
239
+ self.method_chars_into_model = cfg["method_chars_into_model"]
240
+ self.source_for_pretrained_cv_model = cfg["source_for_pretrained_cv_model"]
241
+ self.method_to_include_char_positions = cfg["method_to_include_char_positions"]
242
+
243
+ self.char_dims = char_dims
244
+ self.char_sequence_length = cfg["max_len_chars_list"] if self.use_char_embed_info else 0
245
+
246
+ self.chars_conv_lr_reduction_factor = cfg["chars_conv_lr_reduction_factor"]
247
+ if self.use_char_embed_info:
248
+ self.chars_bert_reduction_factor = cfg["chars_bert_reduction_factor"]
249
+
250
+ self.use_in_projection_bias = cfg["use_in_projection_bias"]
251
+ self.add_layer_norm_to_in_projection = cfg["add_layer_norm_to_in_projection"]
252
+
253
+ self.hidden_dropout_prob = cfg["hidden_dropout_prob"]
254
+ self.layer_norm_after_in_projection = cfg["layer_norm_after_in_projection"]
255
+ self.method_chars_into_model = cfg["method_chars_into_model"]
256
+ self.input_padding_val = cfg["input_padding_val"]
257
+ self.cv_char_modelname = cfg["cv_char_modelname"]
258
+ self.char_plot_shape = cfg["char_plot_shape"]
259
+
260
+ self.remove_timm_classifier_head_pooling = cfg["remove_timm_classifier_head_pooling"]
261
+ self.change_pooling_for_timm_head_to = cfg["change_pooling_for_timm_head_to"]
262
+ self.chars_conv_pooling_out_dim = cfg["chars_conv_pooling_out_dim"]
263
+
264
+ self.add_layer_norm_to_char_mlp = cfg["add_layer_norm_to_char_mlp"]
265
+ if "profile_torch_run" in cfg:
266
+ self.profile_torch_run = cfg["profile_torch_run"]
267
+ else:
268
+ self.profile_torch_run = False
269
+ if self.loss_func == "OrdinalRegLoss":
270
+ self.out_shape = 1
271
+ else:
272
+ self.out_shape = cfg["num_classes"]
273
+
274
+ if not self.hparams.cfg["only_use_2nd_input_stream"]:
275
+ if (
276
+ self.method_chars_into_model == "dense"
277
+ and self.use_char_embed_info
278
+ and self.method_to_include_char_positions == "concat"
279
+ ):
280
+ self.project = nn.Linear(self.x_shape[-1], self.hidden_dim // 2, bias=self.use_in_projection_bias)
281
+ elif (
282
+ self.method_chars_into_model == "bert"
283
+ and self.use_char_embed_info
284
+ and self.method_to_include_char_positions == "concat"
285
+ ):
286
+ self.hidden_dim_chars = self.hidden_dim // 2
287
+ self.project = nn.Linear(self.x_shape[-1], self.hidden_dim_chars, bias=self.use_in_projection_bias)
288
+ elif (
289
+ self.method_chars_into_model == "resnet"
290
+ and self.method_to_include_char_positions == "concat"
291
+ and self.use_char_embed_info
292
+ ):
293
+ self.project = nn.Linear(self.x_shape[-1], self.hidden_dim // 2, bias=self.use_in_projection_bias)
294
+ elif self.model_to_use == "cv_only_model":
295
+ self.project = nn.Identity()
296
+ else:
297
+ self.project = nn.Linear(self.x_shape[-1], self.hidden_dim, bias=self.use_in_projection_bias)
298
+ if self.add_layer_norm_to_in_projection:
299
+ self.project = nn.Sequential(
300
+ nn.Linear(self.project.in_features, self.project.out_features, bias=self.use_in_projection_bias),
301
+ nn.LayerNorm(self.project.out_features),
302
+ )
303
+
304
+ if hasattr(self, "project") and "posix" in os.name and global_settings["try_using_torch_compile"]:
305
+ self.project = t.compile(self.project)
306
+
307
+ if self.use_char_embed_info:
308
+ self._create_char_model()
309
+
310
+ if self.layer_norm_after_in_projection:
311
+ if self.hparams.cfg["only_use_2nd_input_stream"]:
312
+ self.layer_norm_in = nn.LayerNorm(self.hidden_dim // 2)
313
+ else:
314
+ self.layer_norm_in = nn.LayerNorm(self.hidden_dim)
315
+
316
+ if "posix" in os.name and global_settings["try_using_torch_compile"]:
317
+ self.layer_norm_in = t.compile(self.layer_norm_in)
318
+
319
+ self._create_main_seq_model(cfg)
320
+
321
+ if register_forw_hook:
322
+ self.register_hooks()
323
+ if self.hparams.cfg["only_use_2nd_input_stream"]:
324
+ linear_in_dim = self.hidden_dim // 2
325
+ else:
326
+ linear_in_dim = self.hidden_dim
327
+
328
+ if self.num_lin_layers == 1:
329
+ self.linear = nn.Linear(linear_in_dim, self.out_shape)
330
+ else:
331
+ lin_layers = []
332
+ for _ in range(self.num_lin_layers - 1):
333
+ lin_layers.extend(
334
+ [
335
+ nn.Linear(linear_in_dim, linear_in_dim),
336
+ getattr(nn, self.linear_activation)(),
337
+ ]
338
+ )
339
+ self.linear = nn.Sequential(*lin_layers, nn.Linear(linear_in_dim, self.out_shape))
340
+
341
+ if "posix" in os.name and global_settings["try_using_torch_compile"]:
342
+ self.linear = t.compile(self.linear)
343
+
344
+ if self.last_activation == "Softmax":
345
+ self.final_activation = nn.Softmax(dim=-1)
346
+ elif self.last_activation == "Sigmoid":
347
+ self.final_activation = nn.Sigmoid()
348
+ elif self.last_activation == "Identity":
349
+ self.final_activation = nn.Identity()
350
+ else:
351
+ raise NotImplementedError(f"{self.last_activation} not implemented")
352
+
353
+ if self.profile_torch_run:
354
+ self.profilerr = t.profiler.profile(
355
+ schedule=t.profiler.schedule(wait=1, warmup=10, active=10, repeat=1),
356
+ on_trace_ready=t.profiler.tensorboard_trace_handler("tblogs"),
357
+ with_stack=True,
358
+ record_shapes=True,
359
+ profile_memory=False,
360
+ )
361
+
362
+ def _create_main_seq_model(self, cfg):
363
+ if self.hparams.cfg["only_use_2nd_input_stream"]:
364
+ hidden_dim = self.hidden_dim // 2
365
+ else:
366
+ hidden_dim = self.hidden_dim
367
+ if self.model_to_use == "BERT":
368
+ self.bert_config = transformers.BertConfig(
369
+ vocab_size=self.x_shape[-1],
370
+ hidden_size=hidden_dim,
371
+ num_hidden_layers=self.num_layers,
372
+ intermediate_size=hidden_dim,
373
+ num_attention_heads=self.num_attention_heads,
374
+ max_position_embeddings=self.max_seq_length,
375
+ )
376
+ self.bert_model = transformers.BertModel(self.bert_config)
377
+ elif self.model_to_use == "cv_only_model":
378
+ self.bert_model = CVModel(
379
+ modelname=cfg["cv_modelname"],
380
+ in_shape=self.in_shape,
381
+ num_classes=cfg["num_classes"],
382
+ loss_func=cfg["loss_function"],
383
+ last_activation=cfg["last_activation"],
384
+ input_padding_val=cfg["input_padding_val"],
385
+ char_dims=self.char_dims,
386
+ max_seq_length=cfg["manual_max_sequence_for_model"],
387
+ )
388
+ else:
389
+ raise NotImplementedError(f"{self.model_to_use} not implemented")
390
+ if "posix" in os.name and global_settings["try_using_torch_compile"]:
391
+ self.bert_model = t.compile(self.bert_model)
392
+ return 0
393
+
394
def _create_char_model(self):
    """Build the sub-network that turns character/word-position information
    into an embedding merged with the fixation-sequence stream.

    The strategy is selected by ``self.method_chars_into_model``:
      - "dense":  two small linear projections over the char-coordinate grid
      - "bert":   a reduced-width BertForSequenceClassification whose CLS
                  output summarises the char sequence
      - "resnet": a pretrained CV backbone applied to a rendered image of
                  the character layout
    Returns 0 (sentinel success value, mirroring the other _create_* methods).
    """
    if self.method_chars_into_model == "dense":
        # Project each char's feature vector down to a single scalar.
        self.chars_project_0 = nn.Linear(self.char_dims, 1, bias=self.use_in_projection_bias)
        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.chars_project_0 = t.compile(self.chars_project_0)
        if self.method_to_include_char_positions == "concat":
            # Half the hidden width: the other half is the fixation embedding.
            self.chars_project_1 = nn.Linear(
                self.char_sequence_length, self.hidden_dim // 2, bias=self.use_in_projection_bias
            )
        else:
            self.chars_project_1 = nn.Linear(
                self.char_sequence_length, self.hidden_dim, bias=self.use_in_projection_bias
            )

        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.chars_project_1 = t.compile(self.chars_project_1)
    elif not self.method_chars_into_model == "resnet":
        # Non-dense, non-resnet path (i.e. "bert"): project raw char features
        # to the reduced char hidden size.
        # NOTE(review): this reads self.hidden_dim_chars before the fallback
        # computation below; it appears to rely on the attribute having been
        # set elsewhere (e.g. in __init__) — confirm.
        self.chars_project = nn.Linear(self.char_dims, self.hidden_dim_chars, bias=self.use_in_projection_bias)
        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.chars_project = t.compile(self.chars_project)

    if self.method_chars_into_model == "bert":
        if not hasattr(self, "hidden_dim_chars"):
            # Shrink the char-BERT relative to the main model, but never to a
            # width of 1 or less.
            if self.hidden_dim // self.chars_bert_reduction_factor > 1:
                self.hidden_dim_chars = self.hidden_dim // self.chars_bert_reduction_factor
            else:
                self.hidden_dim_chars = self.hidden_dim
        # Keep the per-head width identical to the main model's heads.
        self.num_attention_heads_chars = self.hidden_dim_chars // (self.hidden_dim // self.num_attention_heads)
        self.chars_bert_config = transformers.BertConfig(
            vocab_size=self.x_shape[-1],
            hidden_size=self.hidden_dim_chars,
            num_hidden_layers=self.num_layers,
            intermediate_size=self.hidden_dim_chars,
            num_attention_heads=self.num_attention_heads_chars,
            max_position_embeddings=self.char_sequence_length + 1,  # +1 for the CLS slot
            num_labels=1,
        )
        self.chars_bert = transformers.BertForSequenceClassification(self.chars_bert_config)

        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.chars_bert = t.compile(self.chars_bert)
        # Expand the scalar classification output back to the char hidden size.
        self.chars_project_class_output = nn.Linear(1, self.hidden_dim_chars, bias=self.use_in_projection_bias)
        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.chars_project_class_output = t.compile(self.chars_project_class_output)
    elif self.method_chars_into_model == "resnet":
        if self.source_for_pretrained_cv_model == "timm":
            self.chars_conv = timm.create_model(
                self.cv_char_modelname,
                pretrained=True,
                num_classes=0,  # remove classifier nn.Linear
            )
            if self.remove_timm_classifier_head_pooling:
                # Strip timm's pooling head, then probe the output shape to
                # decide whether a replacement pooling layer is needed.
                self.chars_conv.head = TimmHeadReplace(all_identity=True)
                with t.inference_mode():
                    test_out = self.chars_conv(
                        t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32)
                    )
                if test_out.ndim > 3:
                    self.chars_conv.head = TimmHeadReplace(
                        self.change_pooling_for_timm_head_to,
                        test_out.shape[1],
                    )
        elif self.source_for_pretrained_cv_model == "huggingface":
            self.chars_conv = transformers.AutoModelForImageClassification.from_pretrained(self.cv_char_modelname)
        elif self.source_for_pretrained_cv_model == "torch_hub":
            # cv_char_modelname encodes "repo,model" for torch.hub.load.
            self.chars_conv = t.hub.load(*self.cv_char_modelname.split(","))

        # Drop whatever classification head the backbone came with so it
        # yields features instead of class logits.
        if hasattr(self.chars_conv, "classifier"):
            self.chars_conv.classifier = nn.Identity()
        elif hasattr(self.chars_conv, "cls_classifier"):
            self.chars_conv.cls_classifier = nn.Identity()
        elif hasattr(self.chars_conv, "fc"):
            self.chars_conv.fc = nn.Identity()

        if hasattr(self.chars_conv, "distillation_classifier"):
            self.chars_conv.distillation_classifier = nn.Identity()
        # Probe the backbone once to learn its output feature dimensionality.
        with t.inference_mode():
            test_out = self.chars_conv(
                t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32)
            )
        if hasattr(test_out, "last_hidden_state"):
            self.chars_conv_out_dim = test_out.last_hidden_state.shape[1]
        elif hasattr(test_out, "logits"):
            self.chars_conv_out_dim = test_out.logits.shape[1]
        elif isinstance(test_out, list):
            self.chars_conv_out_dim = test_out[0].shape[1]
        else:
            self.chars_conv_out_dim = test_out.shape[1]

        # Map backbone features to half the hidden width (the other half is
        # the fixation embedding when concatenating).
        char_lin_layers = [nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2)]
        if self.add_layer_norm_to_char_mlp:
            char_lin_layers.append(nn.LayerNorm(self.hidden_dim // 2))
        self.chars_classifier = nn.Sequential(*char_lin_layers)
        if hasattr(self.chars_conv, "distillation_classifier"):
            self.chars_conv.distillation_classifier = nn.Sequential(
                nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2)
            )

        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.chars_classifier = t.compile(self.chars_classifier)
        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            self.chars_conv = t.compile(self.chars_conv)
    return 0
497
+
498
def register_hooks(self):
    """Attach a forward hook to every sub-module that logs its activations
    as TensorBoard histograms (only for loggers exposing ``add_histogram``).

    Source was diff-mangled; reconstructed as runnable Python.
    """

    def _make_activation_hook(tag_prefix):
        # Closure over ``tag_prefix`` so every module logs under its own tag.
        def _hook(module, inputs, output):
            # Plain tensors have .detach(); tuple/None outputs are skipped.
            if hasattr(output, "detach"):
                for logger in self.loggers:
                    if hasattr(logger.experiment, "add_histogram"):
                        logger.experiment.add_histogram(
                            tag=f"{tag_prefix}_{str(list(output.shape))}",
                            values=output.detach(),
                            global_step=self.trainer.global_step,
                        )

        return _hook

    for module_name, module in self.named_modules():
        module.register_forward_hook(_make_activation_hook(f"act_{module_name}"))
514
+
515
def on_after_backward(self) -> None:
    """Lightning hook: after backward, optionally log gradient histograms.

    Only active when ``self.track_gradient_histogram`` is set, and throttled
    to every 200th global step to keep logging overhead low.
    """
    if self.track_gradient_histogram:
        if self.trainer.global_step % 200 == 0:
            for logger in self.loggers:
                # Only TensorBoard-style loggers expose add_histogram.
                if hasattr(logger.experiment, "add_histogram"):
                    for layer_id, layer in dict([*self.named_modules()]).items():
                        parameters = layer.parameters()
                        for idx2, p in enumerate(parameters):
                            grad_val = p.grad
                            # Params without a gradient (frozen/unused) are skipped.
                            if grad_val is not None:
                                grad_name = f"grad_{idx2}_{layer_id}_{str(list(p.grad.shape))}"
                                logger.experiment.add_histogram(
                                    tag=grad_name, values=grad_val, global_step=self.trainer.global_step
                                )

    return super().on_after_backward()
531
+
532
+ def _fold_in_seq_dim(self, out, y):
533
+ batch_size, seq_len, num_classes = out.shape
534
+ out = eo.rearrange(out, "b s c -> (b s) c", s=seq_len)
535
+ if y is None:
536
+ return out, None
537
+ if len(y.shape) > 2:
538
+ y = eo.rearrange(y, "b s c -> (b s) c", s=seq_len)
539
+ else:
540
+ y = eo.rearrange(y, "b s -> (b s)", s=seq_len)
541
+ return out, y
542
+
543
+ def _get_loss(self, out, y, batch):
544
+ attention_mask = batch[-2]
545
+ if self.loss_func == "BCELoss":
546
+ if self.last_activation == "Identity":
547
+ loss = t.nn.functional.binary_cross_entropy_with_logits(out, y, reduction="none")
548
+ else:
549
+ loss = t.nn.functional.binary_cross_entropy(out, y, reduction="none")
550
+
551
+ replace_tensor = t.zeros(loss[1, 1, :].shape, device=loss.device, dtype=loss.dtype, requires_grad=False)
552
+ loss[~attention_mask.bool()] = replace_tensor
553
+ loss = loss.mean()
554
+ elif self.loss_func == "CrossEntropyLoss":
555
+ if len(out.shape) > 2:
556
+ out, y = self._fold_in_seq_dim(out, y)
557
+ loss = t.nn.functional.cross_entropy(out, y, reduction="mean", ignore_index=-100)
558
+ else:
559
+ loss = t.nn.functional.cross_entropy(out, y, reduction="mean", ignore_index=-100)
560
+
561
+ elif self.loss_func == "OrdinalRegLoss":
562
+ loss = t.nn.functional.mse_loss(out, y, reduction="none")
563
+ loss = loss[attention_mask.bool()].sum() * 10.0 / attention_mask.sum()
564
+ elif self.loss_func == "macro_soft_f1":
565
+ loss = macro_soft_f1(y, out, reduction="mean")
566
+ elif self.loss_func == "coral_loss":
567
+ loss = coral_loss(out, y)
568
+ elif self.loss_func == "corn_loss":
569
+ out, y = self._fold_in_seq_dim(out, y)
570
+ loss = corn_loss(out, y.squeeze(), self.out_shape)
571
+ else:
572
+ raise ValueError("Loss Function not reckognized")
573
+ return loss
574
+
575
def training_step(self, batch, batch_idx):
    """One optimization step: forward pass, loss, and train-loss logging.

    Source was diff-mangled; reconstructed as runnable Python.
    """
    # Advance the torch profiler schedule when profiling is enabled.
    if self.profile_torch_run:
        self.profilerr.step()
    predictions, targets = self.model_step(batch, batch_idx)
    loss = self._get_loss(predictions, targets, batch)
    self.log("train_loss", loss, on_epoch=True, on_step=True, sync_dist=True)
    return loss
582
+
583
def forward(*args):
    # Delegates to the module-level ``forward`` function defined later in this
    # file: at call time the bare name ``forward`` resolves through module
    # globals, not to this method. ``args[0]`` is ``self``; the remaining
    # positional args are forwarded as a tuple, which ``prep_model_input``
    # unwraps when it has length 1.
    # NOTE(review): this name shadowing is fragile — renaming the module-level
    # function would turn this into infinite recursion. Confirm intended.
    return forward(args[0], args[1:])
585
+
586
def model_step(self, batch, batch_idx):
    """Run the forward pass; return (model output, targets).

    By convention the targets are the last element of the batch tuple.
    Source was diff-mangled; reconstructed as runnable Python.
    """
    predictions = self.forward(batch)
    targets = batch[-1]
    return predictions, targets
589
+
590
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_closure,
):
    """Lightning hook wrapping the optimizer update.

    Applies manual learning-rate warmup for the first
    ``self.num_warmup_steps`` global steps (skipped for OneCycleLR, which
    manages warmup itself), and periodically logs the lr of each param group.
    """
    optimizer.step(closure=optimizer_closure)

    if self.use_lr_warmup and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR":
        if self.trainer.global_step < self.num_warmup_steps:
            # Scale factor ramps from ~0 to 1, shaped by warmup_exponent.
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.num_warmup_steps) ** self.warmup_exponent
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.learning_rate
        # NOTE(review): nesting reconstructed from a diff rendering — this
        # logging block may originally have been at the outer level (i.e. also
        # active without warmup); confirm against the repository history.
        if self.trainer.global_step % 10 == 0 or self.trainer.global_step == 0:
            for idx, pg in enumerate(optimizer.param_groups):
                self.log(f"lr_{idx}", pg["lr"], prog_bar=True, sync_dist=True)
607
+
608
def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Any | None) -> None:
    """Step the LR scheduler, holding it back while manual warmup is active.

    During manual warmup (used unless the schedule is OneCycleLR, which
    handles warmup itself) the scheduler is frozen until
    ``num_warmup_steps`` has been passed; otherwise it steps on every call,
    forwarding the monitored metric when one is provided.

    Fix: the original duplicated the metric/no-metric stepping logic in both
    branches; collapsed into a single guard + step.
    """
    in_manual_warmup = (
        self.use_lr_warmup
        and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR"
        and self.trainer.global_step <= self.num_warmup_steps
    )
    if in_manual_warmup:
        # optimizer_step's warmup owns the lr for now; skip scheduler stepping.
        return
    if metric is None:
        scheduler.step()
    else:
        scheduler.step(metric)
620
+
621
+ def _get_preds_reals(self, out, y):
622
+ if self.loss_func == "corn_loss":
623
+ seq_len = out.shape[1]
624
+ out, y = self._fold_in_seq_dim(out, y)
625
+ preds = corn_label_from_logits(out)
626
+ preds = eo.rearrange(preds, "(b s) -> b s", s=seq_len)
627
+ if y is not None:
628
+ y = eo.rearrange(y.squeeze(), "(b s) -> b s", s=seq_len)
629
+
630
+ elif self.loss_func == "OrdinalRegLoss":
631
+ preds = out * (self.ord_reg_loss_max - self.ord_reg_loss_min)
632
+ preds = (preds + self.ord_reg_loss_min).round().to(t.long)
633
+
634
+ else:
635
+ preds = t.argmax(out, dim=-1)
636
+ if y is None:
637
+ return preds, y, -100
638
+ else:
639
+ if self.using_one_hot_targets:
640
+ y_onecold = t.argmax(y, dim=-1)
641
+ ignore_index_val = 0
642
+ elif self.loss_func == "OrdinalRegLoss":
643
+ y_onecold = (y * self.num_classes).round().to(t.long)
644
+
645
+ y_onecold = y * (self.ord_reg_loss_max - self.ord_reg_loss_min)
646
+ y_onecold = (y_onecold + self.ord_reg_loss_min).round().to(t.long)
647
+ ignore_index_val = t.min(y_onecold).to(t.long)
648
+ else:
649
+ y_onecold = y
650
+ ignore_index_val = -100
651
+
652
+ if len(preds.shape) > len(y_onecold.shape):
653
+ preds = preds.squeeze()
654
+ return preds, y_onecold, ignore_index_val
655
+
656
def validation_step(self, batch, batch_idx):
    """Validation: log accuracy (as a percentage) and the validation loss;
    return the loss.

    Source was diff-mangled; reconstructed as runnable Python.
    """
    model_out, targets = self.model_step(batch, batch_idx)
    preds, y_onecold, ignore_index_val = self._get_preds_reals(model_out, targets)

    if self.loss_func == "OrdinalRegLoss":
        # torchmetrics' accuracy does not fit these regression-style labels,
        # so compute plain accuracy over the non-ignored positions by hand.
        flat_targets = y_onecold.flatten()
        keep = flat_targets != ignore_index_val
        flat_preds = preds.flatten()[keep]
        flat_targets = flat_targets[keep]
        acc = (flat_preds == flat_targets).sum() / len(flat_targets)
    else:
        acc = torchmetrics.functional.accuracy(
            preds,
            y_onecold.to(t.long),
            ignore_index=ignore_index_val,
            num_classes=self.num_classes,
            task="multiclass",
        )
    self.log("acc", acc * 100, prog_bar=True, sync_dist=True)

    loss = self._get_loss(model_out, targets, batch)
    self.log("val_loss", loss, prog_bar=True, sync_dist=True)

    return loss
678
+
679
def predict_step(self, batch, batch_idx):
    """Inference: return integer predictions alongside the decoded targets.

    Source was diff-mangled; reconstructed as runnable Python.
    """
    model_out, targets = self.model_step(batch, batch_idx)
    preds, y_onecold, _ignore_index_val = self._get_preds_reals(model_out, targets)
    return preds, y_onecold
683
+
684
def configure_optimizers(self):
    """Create the AdamW optimizer (with a reduced lr for the pretrained char
    CV backbone) and, depending on cfg, an LR scheduler.

    Returns either a Lightning optimizer dict (optimizer + scheduler config)
    or a bare [optimizer] list when no scheduler is configured.
    """
    params = list(self.named_parameters())

    def is_chars_conv(n):
        # Pretrained conv-backbone params get a reduced lr — but not the
        # freshly initialised classifier head on top of it.
        if "chars_conv" not in n:
            return False
        if "chars_conv" in n and "classifier" in n:
            return False
        else:
            return True

    grouped_parameters = [
        {
            "params": [p for n, p in params if is_chars_conv(n)],
            "lr": self.learning_rate / self.chars_conv_lr_reduction_factor,
            "weight_decay": self.weight_decay,
        },
        {
            "params": [p for n, p in params if not is_chars_conv(n)],
            "lr": self.learning_rate,
            "weight_decay": self.weight_decay,
        },
    ]
    opti = t.optim.AdamW(grouped_parameters, lr=self.learning_rate, weight_decay=self.weight_decay)
    if self.use_reduce_on_plateau:
        opti_dict = {
            "optimizer": opti,
            "lr_scheduler": {
                "scheduler": t.optim.lr_scheduler.ReduceLROnPlateau(opti, mode="min", patience=2, factor=0.5),
                "monitor": "val_loss",
                "frequency": 1,
                "interval": "epoch",
            },
        }
        return opti_dict
    else:
        cfg = self.hparams["cfg"]
        if cfg["use_reduce_on_plateau"]:
            # Plateau scheduling is handled via the instance attribute above;
            # with only the cfg flag set, no extra scheduler is attached.
            scheduler = None
        elif cfg["lr_scheduling"] == "multistep":
            scheduler = t.optim.lr_scheduler.MultiStepLR(
                opti, milestones=cfg["multistep_milestones"], gamma=cfg["gamma_multistep"], verbose=False
            )
            interval = "step" if cfg["use_training_steps_for_end_and_lr_decay"] else "epoch"
        elif cfg["lr_scheduling"] == "StepLR":
            scheduler = t.optim.lr_scheduler.StepLR(
                opti, step_size=cfg["gamma_step_size"], gamma=cfg["gamma_step_factor"]
            )
            interval = "step" if cfg["use_training_steps_for_end_and_lr_decay"] else "epoch"
        elif cfg["lr_scheduling"] == "anneal":
            # T_max=250 is hard-coded; presumably tuned to the training length.
            scheduler = t.optim.lr_scheduler.CosineAnnealingLR(
                opti, 250, eta_min=cfg["min_lr_anneal"], last_epoch=-1, verbose=False
            )
            interval = "step"
        elif cfg["lr_scheduling"] == "ExponentialLR":
            scheduler = t.optim.lr_scheduler.ExponentialLR(opti, gamma=cfg["lr_sched_exp_fac"])
            interval = "step"
        else:
            scheduler = None
        if scheduler is None:
            return [opti]
        else:
            opti_dict = {
                "optimizer": opti,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "global_step",
                    "frequency": 1,
                    "interval": interval,
                },
            }
            return opti_dict
756
+
757
def on_fit_start(self) -> None:
    """Lightning hook: start the torch profiler when profiling is enabled."""
    if self.profile_torch_run:
        self.profilerr.start()
    return super().on_fit_start()
761
+
762
def on_fit_end(self) -> None:
    """Lightning hook: stop the torch profiler when profiling was enabled."""
    if self.profile_torch_run:
        self.profilerr.stop()
    return super().on_fit_end()
766
+
767
+
768
def prep_model_input(self, batch):
    """Unpack a dataloader batch and build the embedded model input.

    Depending on cfg, merges the fixation-sequence embedding with a
    character-position embedding produced by the "dense", "bert" or "resnet"
    char sub-model. Returns (x_embedded, attention_mask).
    Defined at module level; ``self`` is the LightningModule instance.
    """
    if len(batch) == 1:
        # Some dataloaders wrap the batch tuple in a 1-element container.
        batch = batch[0]
    if self.use_char_embed_info:
        if len(batch) == 5:
            x, chars_coords, ims, attention_mask, _ = batch
        elif batch[1].ndim == 4:
            # A 4-d second element is the rendered char image stack.
            x, ims, attention_mask, _ = batch
        else:
            x, chars_coords, attention_mask, _ = batch
            padding_list = None  # NOTE(review): assigned but never used — confirm.
    else:
        if len(batch) > 3:
            x = batch[0]
            y = batch[-1]
            attention_mask = batch[1]
        else:
            x, attention_mask, y = batch

    if self.model_to_use != "cv_only_model" and not self.hparams.cfg["only_use_2nd_input_stream"]:
        x_embedded = self.project(x)
    else:
        # CV-only / 2nd-stream-only models consume the raw input.
        x_embedded = x
    if self.use_char_embed_info:
        if self.method_chars_into_model == "dense":
            # NOTE(review): bool_mask is True at PADDING positions, yet it is
            # multiplied into the projection (keeping padded entries, zeroing
            # real ones) — looks inverted; confirm intended.
            bool_mask = chars_coords == self.input_padding_val
            bool_mask = bool_mask[:, :, 0]
            chars_coords_projected = self.chars_project_0(chars_coords).squeeze(-1)
            chars_coords_projected = chars_coords_projected * bool_mask
            if self.chars_project_1.in_features == chars_coords_projected.shape[-1]:
                chars_coords_projected = self.chars_project_1(chars_coords_projected)
            else:
                # Fallback when the char sequence length differs from the
                # projection's expected input width.
                chars_coords_projected = chars_coords_projected.mean(dim=-1)
                chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[2])
        elif self.method_chars_into_model == "bert":
            # Mask is 1 for real chars; a leading 1 is prepended for the CLS slot.
            chars_mask = chars_coords != self.input_padding_val
            chars_mask = t.cat(
                (
                    t.ones(chars_mask[:, :1, 0].shape, dtype=t.long, device=chars_coords.device),
                    chars_mask[:, :, 0].to(t.long),
                ),
                dim=1,
            )
            chars_coords_projected = self.chars_project(chars_coords)

            position_ids = t.arange(
                0, chars_coords_projected.shape[1] + 1, dtype=t.long, device=chars_coords_projected.device
            )
            token_type_ids = t.zeros(
                (chars_coords_projected.size()[0], chars_coords_projected.size()[1] + 1),
                dtype=t.long,
                device=chars_coords_projected.device,
            )  # +1 for CLS
            chars_coords_projected = t.cat(
                (t.ones_like(chars_coords_projected[:, :1, :]), chars_coords_projected), dim=1
            )  # to add CLS token
            chars_coords_projected = self.chars_bert(
                position_ids=position_ids,
                inputs_embeds=chars_coords_projected,
                token_type_ids=token_type_ids,
                attention_mask=chars_mask,
            )
            # Keep only the CLS summary, whatever output type the model returns.
            if hasattr(chars_coords_projected, "last_hidden_state"):
                chars_coords_projected = chars_coords_projected.last_hidden_state[:, 0, :]
            elif hasattr(chars_coords_projected, "logits"):
                chars_coords_projected = chars_coords_projected.logits
            else:
                chars_coords_projected = chars_coords_projected.hidden_states[-1][:, 0, :]
        elif self.method_chars_into_model == "resnet":
            chars_conv_out = self.chars_conv(ims)
            if isinstance(chars_conv_out, list):
                chars_conv_out = chars_conv_out[0]
            if hasattr(chars_conv_out, "logits"):
                chars_conv_out = chars_conv_out.logits
            chars_coords_projected = self.chars_classifier(chars_conv_out)

        # Broadcast the per-sequence char summary across the fixation steps.
        # NOTE(review): nesting reconstructed from a diff rendering — this
        # tail may originally have been nested inside the "resnet" branch;
        # confirm against the repository history.
        chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[1], 1)
        if hasattr(self, "chars_project_class_output"):
            chars_coords_projected = self.chars_project_class_output(chars_coords_projected)

        if self.hparams.cfg["only_use_2nd_input_stream"]:
            x_embedded = chars_coords_projected
        elif self.method_to_include_char_positions == "concat":
            x_embedded = t.cat((x_embedded, chars_coords_projected), dim=-1)
        else:
            x_embedded = x_embedded + chars_coords_projected
    return x_embedded, attention_mask
855
+
856
+
857
def forward(self, batch):
    """Full forward pass: embed the batch, run the sequence backbone, and
    return (activated) per-position class logits.

    Defined at module level; the class's ``forward`` method delegates here
    with ``self`` as the first argument.
    """
    prepped_input = prep_model_input(self, batch)

    # NOTE(review): prep_model_input returns a 2-tuple
    # (x_embedded, attention_mask); the 4-way unpack in the first branch
    # looks unreachable/legacy — confirm.
    if len(batch) > 5:
        x_embedded, padding_list, attention_mask, attention_mask_for_prediction = prepped_input
    elif len(batch) > 2:
        x_embedded, attention_mask = prepped_input
    else:
        x_embedded = prepped_input[0]
        attention_mask = prepped_input[-1]

    position_ids = t.arange(0, x_embedded.shape[1], dtype=t.long, device=x_embedded.device)
    token_type_ids = t.zeros(x_embedded.size()[:-1], dtype=t.long, device=x_embedded.device)

    if self.layer_norm_after_in_projection:
        x_embedded = self.layer_norm_in(x_embedded)

    # Dispatch on backbone family: each expects different call signatures.
    if self.model_to_use == "LSTM":
        bert_out = self.bert_model(x_embedded)
    elif self.model_to_use in ["ProphetNet", "T5", "FunnelModel"]:
        bert_out = self.bert_model(inputs_embeds=x_embedded, attention_mask=attention_mask)
    elif self.model_to_use == "xBERT":
        bert_out = self.bert_model(x_embedded, mask=attention_mask.to(bool))
    elif self.model_to_use == "cv_only_model":
        bert_out = self.bert_model(x_embedded)
    else:
        bert_out = self.bert_model(
            position_ids=position_ids,
            inputs_embeds=x_embedded,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
    # Normalise the various HF / plain-tensor return types to logits.
    if hasattr(bert_out, "last_hidden_state"):
        last_hidden_state = bert_out.last_hidden_state
        out = self.linear(last_hidden_state)
    elif hasattr(bert_out, "logits"):
        out = bert_out.logits
    else:
        out = bert_out
    out = self.final_activation(out)
    return out
models/BERT_20240104-223349_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00430.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c4ae65e81c722f3732563942ab40447a186869bebb1bbc8433a782805e73ac3
3
+ size 86691676
models/BERT_20240104-233803_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00719.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7588696e4afc4c8ffb0ff361d9566b7b360c61a3bb6fd6fcb484942b6d2568b
3
+ size 86692053
models/BERT_20240107-152040_loop_restrict_sim_data_to_4000_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00515.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:815b5500a1ae0a04bb55ae58c3896f07981757a2e1a2adf2cbc8a346551d88df
3
+ size 86686270
models/BERT_20240108-000344_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00706.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2e56e1e33da611622315995e0cdf4db5aad6a086420401ca3ee95393b8977ac
3
+ size 86692053
models/BERT_20240108-011230_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00560.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f060242cf0bc494d2908e0e99e9d411c9a9b131443cff91bb245229dad2f783
3
+ size 86691676
models/BERT_20240109-090419_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00518.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbf23ac7baa88a957e1782158bd7a32aedcfcb0527b203079191ac259ec146c5
3
+ size 86692053
models/BERT_20240122-183729_loop_normalize_by_line_height_and_width_True_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00523.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fb7c8238752af51b64a23291080bb30edf9e090defcb2ec4015ddc8d543a9de
3
+ size 86691740
models/BERT_20240122-194041_loop_normalize_by_line_height_and_width_False_dataset_folder_idx_evaluation_8_epoch=41-val_loss=0.00462.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54fedcc5bdeda01bfae26bafcb7542c766807f1af9da7731aaa7ed38e93743d8
3
+ size 86692117
models/BERT_fin_exp_20240104-223349.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_layer_norm_to_char_mlp: true
2
+ add_layer_norm_to_in_projection: false
3
+ add_line_overlap_feature: true
4
+ add_normalised_values_as_features: false
5
+ change_pooling_for_timm_head_to: AdaptiveAvgPool2d
6
+ char_dims: 0
7
+ char_plot_shape:
8
+ - 224
9
+ - 224
10
+ chars_bert_reduction_factor: 4
11
+ chars_conv_lr_reduction_factor: 1
12
+ chars_conv_pooling_out_dim: 1
13
+ convert_posix: false
14
+ convert_winpath: false
15
+ cv_char_modelname: coatnet_nano_rw_224
16
+ cv_modelname: null
17
+ early_stopping_patience: 15
18
+ gamma_multistep: null
19
+ gamma_step_factor: 0.5
20
+ gamma_step_size: 3000
21
+ head_multiplication_factor: 64
22
+ hidden_dim_bert: 512
23
+ hidden_dropout_prob: 0.0
24
+ im_partial_string: fixations_chars_channel_sep
25
+ input_padding_val: 10
26
+ last_activation: Identity
27
+ layer_norm_after_in_projection: true
28
+ linear_activation: GELU
29
+ load_best_checkpoint_at_end: false
30
+ loss_function: corn_loss
31
+ lr: 0.0004
32
+ lr_initial: '0.0004'
33
+ lr_sched_exp_fac: null
34
+ lr_scheduling: StepLR
35
+ manual_max_sequence_for_model: 500
36
+ max_len_chars_list: 0
37
+ max_seq_length: 500
38
+ method_chars_into_model: resnet
39
+ method_to_include_char_positions: concat
40
+ min_lr_anneal: 1e-6
41
+ model_to_use: BERT
42
+ multistep_milestones: null
43
+ n_layers_BERT: 4
44
+ norm_by_char_averages: false
45
+ norm_by_line_width: false
46
+ norm_coords_by_letter_min_x_y: false
47
+ normalize_by_line_height_and_width: true
48
+ num_attention_heads: 8
49
+ num_classes: 16
50
+ num_lin_layers: 1
51
+ num_warmup_steps: 3000
52
+ one_hot_y: false
53
+ ord_reg_loss_max: 16
54
+ ord_reg_loss_min: -1
55
+ padding_at_end: true
56
+ plot_histogram: true
57
+ plot_learning_curves: true
58
+ precision: 16-mixed
59
+ prediction_only: false
60
+ pretrained_model_name_to_load: null
61
+ profile_torch_run: false
62
+ reload_model: false
63
+ reload_model_date: null
64
+ remove_eval_idx_from_train_idx: true
65
+ remove_timm_classifier_head_pooling: true
66
+ sample_cols:
67
+ - x
68
+ - y
69
+ sample_means:
70
+ - 0.7326
71
+ - 6.6381
72
+ - 2.4717
73
+ sample_std:
74
+ - 0.2778
75
+ - 1.882
76
+ - 1.8562
77
+ sample_std_unscaled:
78
+ - 285.193
79
+ - 131.1842
80
+ - 1.8562
81
+ save_weights_only: true
82
+ set_max_seq_len_manually: true
83
+ set_num_classes_manually: true
84
+ source_for_pretrained_cv_model: timm
85
+ target_padding_number: -100
86
+ track_activations_via_hook: false
87
+ track_gradient_histogram: false
88
+ use_char_bounding_boxes: true
89
+ use_early_stopping: false
90
+ use_embedded_char_pos_info: true
91
+ use_fixation_duration_information: false
92
+ use_in_projection_bias: false
93
+ use_lr_warmup: true
94
+ use_pupil_size_information: false
95
+ use_reduce_on_plateau: false
96
+ use_start_time_as_input_col: false
97
+ use_training_steps_for_end_and_lr_decay: true
98
+ use_words_coords: false
99
+ warmup_exponent: 1
100
+ weight_decay: 0.0
models/BERT_fin_exp_20240104-233803.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_layer_norm_to_char_mlp: true
2
+ add_layer_norm_to_in_projection: false
3
+ add_line_overlap_feature: true
4
+ add_normalised_values_as_features: false
5
+ change_pooling_for_timm_head_to: AdaptiveAvgPool2d
6
+ char_dims: 0
7
+ char_plot_shape:
8
+ - 224
9
+ - 224
10
+ chars_bert_reduction_factor: 4
11
+ chars_conv_lr_reduction_factor: 1
12
+ chars_conv_pooling_out_dim: 1
13
+ convert_posix: false
14
+ convert_winpath: false
15
+ cv_char_modelname: coatnet_nano_rw_224
16
+ cv_modelname: null
17
+ early_stopping_patience: 15
18
+ gamma_multistep: null
19
+ gamma_step_factor: 0.5
20
+ gamma_step_size: 3000
21
+ head_multiplication_factor: 64
22
+ hidden_dim_bert: 512
23
+ hidden_dropout_prob: 0.0
24
+ im_partial_string: fixations_chars_channel_sep
25
+ input_padding_val: 10
26
+ last_activation: Identity
27
+ layer_norm_after_in_projection: true
28
+ linear_activation: GELU
29
+ load_best_checkpoint_at_end: false
30
+ loss_function: corn_loss
31
+ lr: 0.0004
32
+ lr_initial: '0.0004'
33
+ lr_sched_exp_fac: null
34
+ lr_scheduling: StepLR
35
+ manual_max_sequence_for_model: 500
36
+ max_len_chars_list: 0
37
+ max_seq_length: 500
38
+ method_chars_into_model: resnet
39
+ method_to_include_char_positions: concat
40
+ min_lr_anneal: 1e-6
41
+ model_to_use: BERT
42
+ multistep_milestones: null
43
+ n_layers_BERT: 4
44
+ norm_by_char_averages: false
45
+ norm_by_line_width: false
46
+ norm_coords_by_letter_min_x_y: false
47
+ normalize_by_line_height_and_width: false
48
+ num_attention_heads: 8
49
+ num_classes: 16
50
+ num_lin_layers: 1
51
+ num_warmup_steps: 3000
52
+ one_hot_y: false
53
+ ord_reg_loss_max: 16
54
+ ord_reg_loss_min: -1
55
+ padding_at_end: true
56
+ plot_histogram: true
57
+ plot_learning_curves: true
58
+ precision: 16-mixed
59
+ prediction_only: false
60
+ pretrained_model_name_to_load: null
61
+ profile_torch_run: false
62
+ reload_model: false
63
+ reload_model_date: null
64
+ remove_eval_idx_from_train_idx: true
65
+ remove_timm_classifier_head_pooling: true
66
+ sample_cols:
67
+ - x
68
+ - y
69
+ sample_means:
70
+ - 710.6114
71
+ - 473.7518
72
+ - 2.4717
73
+ sample_std:
74
+ - 285.1937
75
+ - 131.1842
76
+ - 1.8562
77
+ sample_std_unscaled:
78
+ - 285.193
79
+ - 131.1842
80
+ - 1.8562
81
+ save_weights_only: true
82
+ set_max_seq_len_manually: true
83
+ set_num_classes_manually: true
84
+ source_for_pretrained_cv_model: timm
85
+ target_padding_number: -100
86
+ track_activations_via_hook: false
87
+ track_gradient_histogram: false
88
+ use_char_bounding_boxes: true
89
+ use_early_stopping: false
90
+ use_embedded_char_pos_info: true
91
+ use_fixation_duration_information: false
92
+ use_in_projection_bias: false
93
+ use_lr_warmup: true
94
+ use_pupil_size_information: false
95
+ use_reduce_on_plateau: false
96
+ use_start_time_as_input_col: false
97
+ use_training_steps_for_end_and_lr_decay: true
98
+ use_words_coords: false
99
+ warmup_exponent: 1
100
+ weight_decay: 0.0
models/BERT_fin_exp_20240107-152040.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_layer_norm_to_char_mlp: true
2
+ add_layer_norm_to_in_projection: false
3
+ add_line_overlap_feature: true
4
+ add_normalised_values_as_features: false
5
+ change_pooling_for_timm_head_to: AdaptiveAvgPool2d
6
+ char_dims: 0
7
+ char_plot_shape:
8
+ - 224
9
+ - 224
10
+ chars_bert_reduction_factor: 4
11
+ chars_conv_lr_reduction_factor: 1
12
+ chars_conv_pooling_out_dim: 1
13
+ convert_posix: false
14
+ convert_winpath: false
15
+ cv_char_modelname: coatnet_nano_rw_224
16
+ cv_modelname: null
17
+ early_stopping_patience: 15
18
+ gamma_multistep: null
19
+ gamma_step_factor: 0.5
20
+ gamma_step_size: 3000
21
+ head_multiplication_factor: 64
22
+ hidden_dim_bert: 512
23
+ hidden_dropout_prob: 0.0
24
+ im_partial_string: fixations_chars_channel_sep
25
+ input_padding_val: 10
26
+ last_activation: Identity
27
+ layer_norm_after_in_projection: true
28
+ linear_activation: GELU
29
+ load_best_checkpoint_at_end: false
30
+ loss_function: corn_loss
31
+ lr: 0.0004
32
+ lr_initial: '0.0004'
33
+ lr_sched_exp_fac: null
34
+ lr_scheduling: StepLR
35
+ manual_max_sequence_for_model: 500
36
+ max_len_chars_list: 0
37
+ max_seq_length: 500
38
+ method_chars_into_model: resnet
39
+ method_to_include_char_positions: concat
40
+ min_lr_anneal: 1e-6
41
+ model_to_use: BERT
42
+ multistep_milestones: null
43
+ n_layers_BERT: 4
44
+ norm_by_char_averages: false
45
+ norm_by_line_width: false
46
+ norm_coords_by_letter_min_x_y: true
47
+ normalize_by_line_height_and_width: true
48
+ num_attention_heads: 8
49
+ num_classes: 16
50
+ num_lin_layers: 1
51
+ num_warmup_steps: 3000
52
+ one_hot_y: false
53
+ ord_reg_loss_max: 16
54
+ ord_reg_loss_min: -1
55
+ padding_at_end: true
56
+ plot_histogram: true
57
+ plot_learning_curves: true
58
+ precision: 16-mixed
59
+ prediction_only: false
60
+ pretrained_model_name_to_load: null
61
+ profile_torch_run: false
62
+ reload_model: false
63
+ reload_model_date: null
64
+ remove_eval_idx_from_train_idx: true
65
+ remove_timm_classifier_head_pooling: true
66
+ sample_cols:
67
+ - x
68
+ - y
69
+ sample_means:
70
+ - 0.4423
71
+ - 3.1164
72
+ - 2.4717
73
+ sample_std:
74
+ - 0.2778
75
+ - 1.882
76
+ - 1.8562
77
+ sample_std_unscaled:
78
+ - 285.193
79
+ - 131.1842
80
+ - 1.8562
81
+ save_weights_only: true
82
+ set_max_seq_len_manually: true
83
+ set_num_classes_manually: true
84
+ source_for_pretrained_cv_model: timm
85
+ target_padding_number: -100
86
+ track_activations_via_hook: false
87
+ track_gradient_histogram: false
88
+ use_char_bounding_boxes: true
89
+ use_early_stopping: false
90
+ use_embedded_char_pos_info: true
91
+ use_fixation_duration_information: false
92
+ use_in_projection_bias: false
93
+ use_lr_warmup: true
94
+ use_pupil_size_information: false
95
+ use_reduce_on_plateau: false
96
+ use_start_time_as_input_col: false
97
+ use_training_steps_for_end_and_lr_decay: true
98
+ use_words_coords: false
99
+ warmup_exponent: 1
100
+ weight_decay: 0.0
models/BERT_fin_exp_20240108-000344.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_layer_norm_to_char_mlp: true
2
+ add_layer_norm_to_in_projection: false
3
+ add_line_overlap_feature: true
4
+ add_normalised_values_as_features: false
5
+ change_pooling_for_timm_head_to: AdaptiveAvgPool2d
6
+ char_dims: 0
7
+ char_plot_shape:
8
+ - 224
9
+ - 224
10
+ chars_bert_reduction_factor: 4
11
+ chars_conv_lr_reduction_factor: 1
12
+ chars_conv_pooling_out_dim: 1
13
+ convert_posix: false
14
+ convert_winpath: true
15
+ cv_char_modelname: coatnet_nano_rw_224
16
+ cv_modelname: null
17
+ early_stopping_patience: 15
18
+ gamma_multistep: null
19
+ gamma_step_factor: 0.5
20
+ gamma_step_size: 3000
21
+ head_multiplication_factor: 64
22
+ hidden_dim_bert: 512
23
+ hidden_dropout_prob: 0.0
24
+ im_partial_string: fixations_chars_channel_sep
25
+ input_padding_val: 10
26
+ last_activation: Identity
27
+ layer_norm_after_in_projection: true
28
+ linear_activation: GELU
29
+ load_best_checkpoint_at_end: false
30
+ loss_function: corn_loss
31
+ lr: 0.0004
32
+ lr_initial: '0.0004'
33
+ lr_sched_exp_fac: null
34
+ lr_scheduling: StepLR
35
+ manual_max_sequence_for_model: 500
36
+ max_len_chars_list: 0
37
+ max_seq_length: 500
38
+ method_chars_into_model: resnet
39
+ method_to_include_char_positions: concat
40
+ min_lr_anneal: 1e-6
41
+ model_to_use: BERT
42
+ multistep_milestones: null
43
+ n_layers_BERT: 4
44
+ norm_by_char_averages: false
45
+ norm_by_line_width: false
46
+ norm_coords_by_letter_min_x_y: true
47
+ normalize_by_line_height_and_width: false
48
+ num_attention_heads: 8
49
+ num_classes: 16
50
+ num_lin_layers: 1
51
+ num_warmup_steps: 3000
52
+ one_hot_y: false
53
+ ord_reg_loss_max: 16
54
+ ord_reg_loss_min: -1
55
+ padding_at_end: true
56
+ plot_histogram: true
57
+ plot_learning_curves: true
58
+ precision: 16-mixed
59
+ prediction_only: false
60
+ pretrained_model_name_to_load: null
61
+ profile_torch_run: false
62
+ reload_model: false
63
+ reload_model_date: null
64
+ remove_eval_idx_from_train_idx: true
65
+ remove_timm_classifier_head_pooling: true
66
+ sample_cols:
67
+ - x
68
+ - y
69
+ sample_means:
70
+ - 455.5905
71
+ - 218.0598
72
+ - 2.4717
73
+ sample_std:
74
+ - 285.1936
75
+ - 131.1842
76
+ - 1.8562
77
+ sample_std_unscaled:
78
+ - 285.1939
79
+ - 131.1844
80
+ - 1.8562
81
+ save_weights_only: true
82
+ set_max_seq_len_manually: true
83
+ set_num_classes_manually: true
84
+ source_for_pretrained_cv_model: timm
85
+ target_padding_number: -100
86
+ track_activations_via_hook: false
87
+ track_gradient_histogram: false
88
+ use_char_bounding_boxes: true
89
+ use_early_stopping: false
90
+ use_embedded_char_pos_info: true
91
+ use_fixation_duration_information: false
92
+ use_in_projection_bias: false
93
+ use_lr_warmup: true
94
+ use_pupil_size_information: false
95
+ use_reduce_on_plateau: false
96
+ use_start_time_as_input_col: false
97
+ use_training_steps_for_end_and_lr_decay: true
98
+ use_words_coords: false
99
+ warmup_exponent: 1
100
+ weight_decay: 0.0
models/BERT_fin_exp_20240108-011230.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_layer_norm_to_char_mlp: true
2
+ add_layer_norm_to_in_projection: false
3
+ add_line_overlap_feature: true
4
+ add_normalised_values_as_features: false
5
+ change_pooling_for_timm_head_to: AdaptiveAvgPool2d
6
+ char_dims: 0
7
+ char_plot_shape:
8
+ - 224
9
+ - 224
10
+ chars_bert_reduction_factor: 4
11
+ chars_conv_lr_reduction_factor: 1
12
+ chars_conv_pooling_out_dim: 1
13
+ convert_posix: false
14
+ convert_winpath: true
15
+ cv_char_modelname: coatnet_nano_rw_224
16
+ cv_modelname: null
17
+ early_stopping_patience: 15
18
+ gamma_multistep: null
19
+ gamma_step_factor: 0.5
20
+ gamma_step_size: 3000
21
+ head_multiplication_factor: 64
22
+ hidden_dim_bert: 512
23
+ hidden_dropout_prob: 0.0
24
+ im_partial_string: fixations_chars_channel_sep
25
+ input_padding_val: 10
26
+ last_activation: Identity
27
+ layer_norm_after_in_projection: true
28
+ linear_activation: GELU
29
+ load_best_checkpoint_at_end: false
30
+ loss_function: corn_loss
31
+ lr: 0.0004
32
+ lr_initial: '0.0004'
33
+ lr_sched_exp_fac: null
34
+ lr_scheduling: StepLR
35
+ manual_max_sequence_for_model: 500
36
+ max_len_chars_list: 0
37
+ max_seq_length: 500
38
+ method_chars_into_model: resnet
39
+ method_to_include_char_positions: concat
40
+ min_lr_anneal: 1e-6
41
+ model_to_use: BERT
42
+ multistep_milestones: null
43
+ n_layers_BERT: 4
44
+ norm_by_char_averages: false
45
+ norm_by_line_width: false
46
+ norm_coords_by_letter_min_x_y: true
47
+ normalize_by_line_height_and_width: true
48
+ num_attention_heads: 8
49
+ num_classes: 16
50
+ num_lin_layers: 1
51
+ num_warmup_steps: 3000
52
+ one_hot_y: false
53
+ ord_reg_loss_max: 16
54
+ ord_reg_loss_min: -1
55
+ padding_at_end: true
56
+ plot_histogram: true
57
+ plot_learning_curves: true
58
+ precision: 16-mixed
59
+ prediction_only: false
60
+ pretrained_model_name_to_load: null
61
+ profile_torch_run: false
62
+ reload_model: false
63
+ reload_model_date: null
64
+ remove_eval_idx_from_train_idx: true
65
+ remove_timm_classifier_head_pooling: true
66
+ sample_cols:
67
+ - x
68
+ - y
69
+ sample_means:
70
+ - 0.4423
71
+ - 3.1164
72
+ - 2.4717
73
+ sample_std:
74
+ - 0.2778
75
+ - 1.882
76
+ - 1.8562
77
+ sample_std_unscaled:
78
+ - 285.1939
79
+ - 131.1844
80
+ - 1.8562
81
+ save_weights_only: true
82
+ set_max_seq_len_manually: true
83
+ set_num_classes_manually: true
84
+ source_for_pretrained_cv_model: timm
85
+ target_padding_number: -100
86
+ track_activations_via_hook: false
87
+ track_gradient_histogram: false
88
+ use_char_bounding_boxes: true
89
+ use_early_stopping: false
90
+ use_embedded_char_pos_info: true
91
+ use_fixation_duration_information: false
92
+ use_in_projection_bias: false
93
+ use_lr_warmup: true
94
+ use_pupil_size_information: false
95
+ use_reduce_on_plateau: false
96
+ use_start_time_as_input_col: false
97
+ use_training_steps_for_end_and_lr_decay: true
98
+ use_words_coords: false
99
+ warmup_exponent: 1
100
+ weight_decay: 0.0
models/BERT_fin_exp_20240109-090419.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_layer_norm_to_char_mlp: true
2
+ add_layer_norm_to_in_projection: false
3
+ add_line_overlap_feature: true
4
+ add_normalised_values_as_features: false
5
+ change_pooling_for_timm_head_to: AdaptiveAvgPool2d
6
+ char_dims: 0
7
+ char_plot_shape:
8
+ - 224
9
+ - 224
10
+ chars_bert_reduction_factor: 4
11
+ chars_conv_lr_reduction_factor: 1
12
+ chars_conv_pooling_out_dim: 1
13
+ convert_posix: false
14
+ convert_winpath: true
15
+ cv_char_modelname: coatnet_nano_rw_224
16
+ cv_modelname: null
17
+ early_stopping_patience: 15
18
+ gamma_multistep: null
19
+ gamma_step_factor: 0.5
20
+ gamma_step_size: 3000
21
+ head_multiplication_factor: 64
22
+ hidden_dim_bert: 512
23
+ hidden_dropout_prob: 0.0
24
+ im_partial_string: fixations_chars_channel_sep
25
+ input_padding_val: 10
26
+ last_activation: Identity
27
+ layer_norm_after_in_projection: true
28
+ linear_activation: GELU
29
+ load_best_checkpoint_at_end: false
30
+ loss_function: corn_loss
31
+ lr: 0.0004
32
+ lr_initial: '0.0004'
33
+ lr_sched_exp_fac: null
34
+ lr_scheduling: StepLR
35
+ manual_max_sequence_for_model: 500
36
+ max_len_chars_list: 0
37
+ max_seq_length: 500
38
+ method_chars_into_model: resnet
39
+ method_to_include_char_positions: concat
40
+ min_lr_anneal: 1e-6
41
+ model_to_use: BERT
42
+ multistep_milestones: null
43
+ n_layers_BERT: 4
44
+ norm_by_char_averages: false
45
+ norm_by_line_width: false
46
+ norm_coords_by_letter_min_x_y: true
47
+ normalize_by_line_height_and_width: false
48
+ num_attention_heads: 8
49
+ num_classes: 16
50
+ num_lin_layers: 1
51
+ num_warmup_steps: 3000
52
+ one_hot_y: false
53
+ ord_reg_loss_max: 16
54
+ ord_reg_loss_min: -1
55
+ padding_at_end: true
56
+ plot_histogram: true
57
+ plot_learning_curves: true
58
+ precision: 16-mixed
59
+ prediction_only: false
60
+ pretrained_model_name_to_load: null
61
+ profile_torch_run: false
62
+ reload_model: false
63
+ reload_model_date: null
64
+ remove_eval_idx_from_train_idx: true
65
+ remove_timm_classifier_head_pooling: true
66
+ sample_cols:
67
+ - x
68
+ - y
69
+ sample_means:
70
+ - 455.708
71
+ - 217.8342
72
+ - 2.4706
73
+ sample_std:
74
+ - 285.2534
75
+ - 131.0263
76
+ - 1.8542
77
+ sample_std_unscaled:
78
+ - 285.2527
79
+ - 131.0262
80
+ - 1.8543
81
+ save_weights_only: true
82
+ set_max_seq_len_manually: true
83
+ set_num_classes_manually: true
84
+ source_for_pretrained_cv_model: timm
85
+ target_padding_number: -100
86
+ track_activations_via_hook: false
87
+ track_gradient_histogram: false
88
+ use_char_bounding_boxes: true
89
+ use_early_stopping: false
90
+ use_embedded_char_pos_info: true
91
+ use_fixation_duration_information: false
92
+ use_in_projection_bias: false
93
+ use_lr_warmup: true
94
+ use_pupil_size_information: false
95
+ use_reduce_on_plateau: false
96
+ use_start_time_as_input_col: false
97
+ use_training_steps_for_end_and_lr_decay: true
98
+ use_words_coords: false
99
+ warmup_exponent: 1
100
+ weight_decay: 0.0
models/BERT_fin_exp_20240122-183729.yaml ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_layer_norm_to_char_mlp: true
2
+ add_layer_norm_to_in_projection: false
3
+ add_line_overlap_feature: true
4
+ add_normalised_values_as_features: false
5
+ add_woc_feature: false
6
+ change_pooling_for_timm_head_to: AdaptiveAvgPool2d
7
+ char_dims: 0
8
+ char_plot_shape:
9
+ - 224
10
+ - 224
11
+ chars_bert_reduction_factor: 4
12
+ chars_conv_lr_reduction_factor: 1
13
+ chars_conv_pooling_out_dim: 1
14
+ convert_posix: false
15
+ convert_winpath: false
16
+ cv_char_modelname: coatnet_nano_rw_224
17
+ cv_modelname: null
18
+ early_stopping_patience: 15
19
+ gamma_multistep: null
20
+ gamma_step_factor: 0.5
21
+ gamma_step_size: 3000
22
+ head_multiplication_factor: 64
23
+ hidden_dim_bert: 512
24
+ hidden_dropout_prob: 0.0
25
+ im_partial_string: fixations_chars_channel_sep
26
+ input_padding_val: 10
27
+ last_activation: Identity
28
+ layer_norm_after_in_projection: true
29
+ linear_activation: GELU
30
+ load_best_checkpoint_at_end: false
31
+ loss_function: corn_loss
32
+ lr: 0.0004
33
+ lr_initial: '0.0004'
34
+ lr_sched_exp_fac: null
35
+ lr_scheduling: StepLR
36
+ manual_max_sequence_for_model: 500
37
+ max_len_chars_list: 0
38
+ max_seq_length: 500
39
+ method_chars_into_model: resnet
40
+ method_to_include_char_positions: concat
41
+ min_lr_anneal: 1e-6
42
+ model_to_use: BERT
43
+ multistep_milestones: null
44
+ n_layers_BERT: 4
45
+ norm_by_char_averages: false
46
+ norm_by_line_width: false
47
+ norm_coords_by_letter_min_x_y: true
48
+ normalize_by_line_height_and_width: true
49
+ num_attention_heads: 8
50
+ num_classes: 16
51
+ num_lin_layers: 1
52
+ num_warmup_steps: 3000
53
+ one_hot_y: false
54
+ only_use_2nd_input_stream: false
55
+ ord_reg_loss_max: 16
56
+ ord_reg_loss_min: -1
57
+ padding_at_end: true
58
+ plot_histogram: true
59
+ plot_learning_curves: true
60
+ precision: 16-mixed
61
+ prediction_only: false
62
+ pretrained_model_name_to_load: null
63
+ profile_torch_run: false
64
+ reload_model: false
65
+ reload_model_date: null
66
+ remove_eval_idx_from_train_idx: true
67
+ remove_timm_classifier_head_pooling: true
68
+ sample_cols:
69
+ - x
70
+ - y
71
+ sample_means:
72
+ - 0.4433
73
+ - 2.9599
74
+ - 2.3264
75
+ sample_std:
76
+ - 0.2782
77
+ - 1.7872
78
+ - 1.7619
79
+ sample_std_unscaled:
80
+ - 287.0107
81
+ - 124.4113
82
+ - 1.7619
83
+ save_weights_only: true
84
+ set_max_seq_len_manually: true
85
+ set_num_classes_manually: true
86
+ source_for_pretrained_cv_model: timm
87
+ target_padding_number: -100
88
+ track_activations_via_hook: false
89
+ track_gradient_histogram: false
90
+ use_char_bounding_boxes: true
91
+ use_early_stopping: false
92
+ use_embedded_char_pos_info: true
93
+ use_fixation_duration_information: false
94
+ use_in_projection_bias: false
95
+ use_lr_warmup: true
96
+ use_pupil_size_information: false
97
+ use_reduce_on_plateau: false
98
+ use_start_time_as_input_col: false
99
+ use_training_steps_for_end_and_lr_decay: true
100
+ use_words_coords: false
101
+ warmup_exponent: 1
102
+ weight_decay: 0.0
models/BERT_fin_exp_20240122-194041.yaml ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_layer_norm_to_char_mlp: true
2
+ add_layer_norm_to_in_projection: false
3
+ add_line_overlap_feature: true
4
+ add_normalised_values_as_features: false
5
+ add_woc_feature: false
6
+ change_pooling_for_timm_head_to: AdaptiveAvgPool2d
7
+ char_dims: 0
8
+ char_plot_shape:
9
+ - 224
10
+ - 224
11
+ chars_bert_reduction_factor: 4
12
+ chars_conv_lr_reduction_factor: 1
13
+ chars_conv_pooling_out_dim: 1
14
+ convert_posix: false
15
+ convert_winpath: false
16
+ cv_char_modelname: coatnet_nano_rw_224
17
+ cv_modelname: null
18
+ early_stopping_patience: 15
19
+ gamma_multistep: null
20
+ gamma_step_factor: 0.5
21
+ gamma_step_size: 3000
22
+ head_multiplication_factor: 64
23
+ hidden_dim_bert: 512
24
+ hidden_dropout_prob: 0.0
25
+ im_partial_string: fixations_chars_channel_sep
26
+ input_padding_val: 10
27
+ last_activation: Identity
28
+ layer_norm_after_in_projection: true
29
+ linear_activation: GELU
30
+ load_best_checkpoint_at_end: false
31
+ loss_function: corn_loss
32
+ lr: 0.0004
33
+ lr_initial: '0.0004'
34
+ lr_sched_exp_fac: null
35
+ lr_scheduling: StepLR
36
+ manual_max_sequence_for_model: 500
37
+ max_len_chars_list: 0
38
+ max_seq_length: 500
39
+ method_chars_into_model: resnet
40
+ method_to_include_char_positions: concat
41
+ min_lr_anneal: 1e-6
42
+ model_to_use: BERT
43
+ multistep_milestones: null
44
+ n_layers_BERT: 4
45
+ norm_by_char_averages: false
46
+ norm_by_line_width: false
47
+ norm_coords_by_letter_min_x_y: true
48
+ normalize_by_line_height_and_width: false
49
+ num_attention_heads: 8
50
+ num_classes: 16
51
+ num_lin_layers: 1
52
+ num_warmup_steps: 3000
53
+ one_hot_y: false
54
+ only_use_2nd_input_stream: false
55
+ ord_reg_loss_max: 16
56
+ ord_reg_loss_min: -1
57
+ padding_at_end: true
58
+ plot_histogram: true
59
+ plot_learning_curves: true
60
+ precision: 16-mixed
61
+ prediction_only: false
62
+ pretrained_model_name_to_load: null
63
+ profile_torch_run: false
64
+ reload_model: false
65
+ reload_model_date: null
66
+ remove_eval_idx_from_train_idx: true
67
+ remove_timm_classifier_head_pooling: true
68
+ sample_cols:
69
+ - x
70
+ - y
71
+ sample_means:
72
+ - 459.3367
73
+ - 206.88
74
+ - 2.3264
75
+ sample_std:
76
+ - 287.0111
77
+ - 124.4113
78
+ - 1.7619
79
+ sample_std_unscaled:
80
+ - 287.0107
81
+ - 124.4113
82
+ - 1.7619
83
+ save_weights_only: true
84
+ set_max_seq_len_manually: true
85
+ set_num_classes_manually: true
86
+ source_for_pretrained_cv_model: timm
87
+ target_padding_number: -100
88
+ track_activations_via_hook: false
89
+ track_gradient_histogram: false
90
+ use_char_bounding_boxes: true
91
+ use_early_stopping: false
92
+ use_embedded_char_pos_info: true
93
+ use_fixation_duration_information: false
94
+ use_in_projection_bias: false
95
+ use_lr_warmup: true
96
+ use_pupil_size_information: false
97
+ use_reduce_on_plateau: false
98
+ use_start_time_as_input_col: false
99
+ use_training_steps_for_end_and_lr_decay: true
100
+ use_words_coords: false
101
+ warmup_exponent: 1
102
+ weight_decay: 0.0
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets
2
+ einops
3
+ matplotlib
4
+ numpy
5
+ pandas
6
+ PyYAML
7
+ seaborn
8
+ tqdm
9
+ transformers==4.30.2
10
+ tensorboard
11
+ torchmetrics
12
+ pytorch-lightning
13
+ scikit-learn
14
+ plotly
15
+ lovely-tensors
16
+ timm
17
+ openpyxl
18
+ torch==2.*
19
+ pydantic==1.10
20
+ streamlit
21
+ pycairo
22
+ eyekit
23
+ stqdm
24
+ jellyfish
25
+ icecream
run_in_notebook.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,2016 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import zipfile
2
+ import os
3
+ import plotly.express as px
4
+ import plotly.graph_objects as go
5
+ from torch.utils.data.dataloader import DataLoader as dl
6
+ import yaml
7
+ from io import StringIO
8
+ import torch as t
9
+ import numpy as np
10
+ import pandas as pd
11
+ from torch.utils.data import Dataset as torch_dset
12
+ from PIL import Image
13
+ import torchvision.transforms.functional as tvfunc
14
+ import json
15
+ from matplotlib import pyplot as plt
16
+ import matplotlib.patches as patches
17
+ from matplotlib.font_manager import FontProperties
18
+ import pathlib as pl
19
+ import matplotlib as mpl
20
+ import streamlit as st
21
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
22
+ import einops as eo
23
+ import copy
24
+
25
+ # import stqdm
26
+ from tqdm.auto import tqdm
27
+ import time
28
+ import requests
29
+
30
+ from matplotlib.patches import Rectangle
31
+ from matplotlib import font_manager
32
+ from models import LitModel, EnsembleModel
33
+ from loss_functions import corn_label_from_logits
34
+ import classic_correction_algos as calgo
35
+ import analysis_funcs as anf
36
+
37
+ TEMP_FOLDER = pl.Path("results")
38
+ AVAILABLE_FONTS = [x.name for x in font_manager.fontManager.ttflist]
39
+ PLOTS_FOLDER = pl.Path("plots")
40
+ TEMP_FIGURE_STIMULUS_PATH = PLOTS_FOLDER / "temp_matplotlib_plot_stimulus.png"
41
+ all_fonts = [x.name for x in font_manager.fontManager.ttflist]
42
+ mpl.use("agg")
43
+
44
+ DIST_MODELS_FOLDER = pl.Path("models")
45
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
46
+ IMAGENET_STD = [0.229, 0.224, 0.225]
47
+ gradio_plots = pl.Path("plots")
48
+
49
+ event_strs = [
50
+ "EFIX",
51
+ "EFIX R",
52
+ "EFIX L",
53
+ "SSACC",
54
+ "ESACC",
55
+ "SFIX",
56
+ "MSG",
57
+ "SBLINK",
58
+ "EBLINK",
59
+ "BUTTON",
60
+ "INPUT",
61
+ "END",
62
+ "START",
63
+ "DISPLAY ON",
64
+ ]
65
+ names_dict = {
66
+ "SSACC": {"Descr": "Start of Saccade", "Pattern": "SSACC <eye > <stime>"},
67
+ "ESACC": {
68
+ "Descr": "End of Saccade",
69
+ "Pattern": "ESACC <eye > <stime> <etime > <dur> <sxp > <syp> <exp > <eyp> <ampl > <pv >",
70
+ },
71
+ "SFIX": {"Descr": "Start of Fixation", "Pattern": "SFIX <eye > <stime>"},
72
+ "EFIX": {"Descr": "End of Fixation", "Pattern": "EFIX <eye > <stime> <etime > <dur> <axp > <ayp> <aps >"},
73
+ "SBLINK": {"Descr": "Start of Blink", "Pattern": "SBLINK <eye > <stime>"},
74
+ "EBLINK": {"Descr": "End of Blink", "Pattern": "EBLINK <eye > <stime> <etime > <dur>"},
75
+ "DISPLAY ON": {"Descr": "Actual start of Trial", "Pattern": "DISPLAY ON"},
76
+ }
77
+ metadata_strs = ["DISPLAY COORDS", "GAZE_COORDS", "FRAMERATE"]
78
+
79
+ ALGO_CHOICES = st.session_state["ALGO_CHOICES"] = [
80
+ "warp",
81
+ "regress",
82
+ "compare",
83
+ "attach",
84
+ "segment",
85
+ "split",
86
+ "stretch",
87
+ "chain",
88
+ "slice",
89
+ "cluster",
90
+ "merge",
91
+ "Wisdom_of_Crowds",
92
+ "DIST",
93
+ "DIST-Ensemble",
94
+ "Wisdom_of_Crowds_with_DIST",
95
+ "Wisdom_of_Crowds_with_DIST_Ensemble",
96
+ ]
97
+ COLORS = px.colors.qualitative.Alphabet
98
+
99
+
100
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that can serialise numpy arrays, Paths and Streamlit uploads.

    Adapted from https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
    """

    def default(self, obj):
        # Numpy arrays are emitted as plain (possibly nested) Python lists.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Paths and uploaded-file handles are stored via their string form.
        # NOTE: the `or` chain is deliberate so UploadedFile is only resolved
        # when the Path check fails.
        if isinstance(obj, pl.Path) or isinstance(obj, UploadedFile):
            return str(obj)
        # Anything else falls through to the stock encoder (raises TypeError).
        return super().default(obj)
109
+
110
+
111
class DSet(torch_dset):
    """Dataset of fixation sequences with optional char coordinates, stimulus
    images and attention masks.

    Each item is a tuple assembled in a fixed order:
    (in_sequence, [chars_center_coords], [image], [attention_mask], out_categories)
    where the bracketed entries are present only when the corresponding
    constructor argument enables them.
    """

    def __init__(
        self,
        in_sequence: t.Tensor,
        chars_center_coords_padded: t.Tensor,
        out_categories: t.Tensor,
        trialslist: list,
        padding_list: list = None,
        padding_at_end: bool = False,
        return_images_for_conv: bool = False,
        im_partial_string: str = "fixations_chars_channel_sep",
        input_im_shape=[224, 224],
    ) -> None:
        super().__init__()

        self.in_sequence = in_sequence
        self.chars_center_coords_padded = chars_center_coords_padded
        self.out_categories = out_categories
        self.trialslist = trialslist
        self.padding_list = padding_list
        self.padding_at_end = padding_at_end
        self.return_images_for_conv = return_images_for_conv
        self.input_im_shape = input_im_shape
        if return_images_for_conv:
            self.im_partial_string = im_partial_string
            # Derive the image path for each trial from its plot file name.
            self.plot_files = [
                str(x["plot_file"]).replace("fixations_words", im_partial_string) for x in self.trialslist
            ]

    def _attention_mask(self, index):
        """Build a 0/1 mask over the sequence; zeros mark the padded positions."""
        mask = t.ones(self.in_sequence[index].shape[:-1], dtype=t.long)
        pad = self.padding_list[index]
        if self.padding_at_end:
            if pad > 0:
                mask[-pad:] = 0
        else:
            mask[:pad] = 0
        return mask

    def _load_image(self, index):
        """Open, resize (if needed) and ImageNet-normalise the stimulus image."""
        im = Image.open(self.plot_files[index])
        if [im.size[1], im.size[0]] != self.input_im_shape:
            im = tvfunc.resize(im, self.input_im_shape)
        return tvfunc.normalize(tvfunc.to_tensor(im), IMAGENET_MEAN, IMAGENET_STD)

    def __getitem__(self, index):
        # Assemble the optional parts in the canonical order documented above.
        parts = [self.in_sequence[index]]
        if self.chars_center_coords_padded is not None:
            parts.append(self.chars_center_coords_padded[index])
        if self.return_images_for_conv:
            parts.append(self._load_image(index))
        if self.padding_list is not None:
            parts.append(self._attention_mask(index))
        parts.append(self.out_categories[index])
        return tuple(parts)

    def __len__(self):
        # Tensors carry their length in the leading dimension; anything else
        # (e.g. a list of variable-length sequences) uses len().
        if isinstance(self.in_sequence, t.Tensor):
            return self.in_sequence.shape[0]
        return len(self.in_sequence)
201
+
202
+
203
def download_url(url, target_filename):
    """Download *url* and write the response body to *target_filename*.

    Returns 0 unconditionally (kept for backward compatibility with callers
    that check the return value).
    """
    r = requests.get(url)
    # Fixed: use a context manager so the file handle is closed deterministically
    # (the original `open(...).write(...)` left closing to the garbage collector).
    with open(target_filename, "wb") as f:
        f.write(r.content)
    return 0
207
+
208
+
209
def asc_to_trial_ids(asc_file, close_gap_between_words=True):
    """Parse an .asc eye-tracking file and key its paragraph trials by trial id.

    Returns (trials_by_ids, lines) as produced by file_to_trials_and_lines.
    """
    if "logger" in st.session_state:
        st.session_state["logger"].debug("asc_to_trial_ids entered")
    asc_encoding = "ISO-8859-15"  # first of the supported encodings (UTF-8 is the alternative)
    trials_dict, lines = file_to_trials_and_lines(
        asc_file, asc_encoding, close_gap_between_words=close_gap_between_words
    )

    # Index the paragraph trials by their trial_id field.
    trials_by_ids = {}
    for idx in trials_dict["paragraph_trials"]:
        trial = trials_dict[idx]
        trials_by_ids[trial["trial_id"]] = trial

    if hasattr(asc_file, "name") and "logger" in st.session_state:
        st.session_state["logger"].info(f"Found {len(trials_by_ids)} trials in {asc_file.name}.")
    return trials_by_ids, lines
222
+
223
+
224
def get_trials_list(asc_file=None, close_gap_between_words=True):
    """Resolve an asc file (argument or Streamlit session state) into its trials.

    Returns (trial_keys, trials_by_ids, lines, asc_file), or None when no
    asc file is available.
    """
    if "logger" in st.session_state:
        st.session_state["logger"].debug("get_trials_list entered")

    # Fixed: identity comparison with None instead of `== None` (PEP 8); also
    # membership test directly on the mapping instead of `.keys()`.
    if asc_file is None:
        if "single_asc_file" in st.session_state and st.session_state["single_asc_file"] is not None:
            asc_file = st.session_state["single_asc_file"]
        else:
            if "logger" in st.session_state:
                st.session_state["logger"].warning("Asc file is None")
            return None

    if hasattr(asc_file, "name"):
        if "logger" in st.session_state:
            st.session_state["logger"].info(f"get_trials_list entered with asc_file {asc_file.name}")

    trials_by_ids, lines = asc_to_trial_ids(asc_file, close_gap_between_words=close_gap_between_words)
    trial_keys = list(trials_by_ids.keys())

    return trial_keys, trials_by_ids, lines, asc_file
244
+
245
+
246
def save_trial_to_json(trial, savename):
    """Serialise *trial* to *savename* as UTF-8 JSON, dropping the dffix frame.

    The "dffix" entry (a DataFrame) is removed before dumping because it is
    not JSON-serialisable and is exported separately as CSV.
    """
    trial.pop("dffix", None)
    with open(savename, "w", encoding="utf-8") as fp:
        json.dump(trial, fp, ensure_ascii=False, indent=4, cls=NumpyEncoder)
251
+
252
+
253
def export_csv(dffix, trial):
    """Write the fixation dataframe (CSV) and trial metadata (JSON) to TEMP_FOLDER.

    Returns (csv_name, trial_name), the two file paths that were written.
    """
    # Some callers hand over a {"value": DataFrame} wrapper; unwrap it first.
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    stem = TEMP_FOLDER.joinpath(pl.Path(trial["fname"]).stem)
    trial_id = trial["trial_id"]
    trial_name = f"{stem}_{trial_id}_trial_info.json"
    csv_name = f"{stem}_{trial_id}.csv"

    dffix.to_csv(csv_name)
    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Saved processed data as {csv_name}")

    save_trial_to_json(trial, trial_name)
    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Saved processed trial data as {trial_name}")

    return csv_name, trial_name
268
+
269
+
270
def get_all_classic_preds(dffix, trial, classic_algos_cfg):
    """Run every classic correction algorithm over *dffix*.

    Returns the dataframe (augmented with one ``y_<algo>`` column per
    algorithm) and the list of corrected y arrays in config order.
    """
    corrections = []
    # Deep-copy the config so algorithms cannot mutate the shared settings.
    for algo_name, params in copy.deepcopy(classic_algos_cfg).items():
        dffix = calgo.apply_classic_algo(dffix, trial, algo_name, params)
        corrections.append(np.asarray(dffix.loc[:, f"y_{algo_name}"]))
    return dffix, corrections
276
+
277
+
278
def apply_woc(dffix, trial, corrections, algo_choice):
    """Combine *corrections* via wisdom-of-the-crowd and store the result.

    Adds three columns to *dffix*: the corrected y values, the rounded
    per-fixation correction, and the resulting line number index.
    """
    corrected_Y = calgo.wisdom_of_the_crowd(corrections)
    col = f"y_{algo_choice}"
    dffix.loc[:, col] = corrected_Y
    dffix[f"{col}_correction"] = (dffix.loc[:, col] - dffix.loc[:, "y"]).round(1)
    # Map each corrected y back to its line index within the trial.
    dffix.loc[:, f"line_num_{col}"] = [trial["y_char_unique"].index(y) for y in corrected_Y]
    return dffix
286
+
287
+
288
def calc_xdiff_ydiff(line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False):
    """Estimate the horizontal and vertical spacing between text lines.

    Parameters
    ----------
    line_xcoords_no_pad : sequence of line x coordinates.
    line_ycoords_no_pad : sequence of line y coordinates.
    line_heights : sequence of line heights; the first entry is used as the
        vertical spacing when all lines share a single y coordinate.
    allow_multiple_values : when True, return the full array of distinct
        spacings instead of collapsing to the minimum.

    Returns
    -------
    (x_diff, y_diff) : scalars, or arrays of distinct diffs when
    ``allow_multiple_values`` is True.
    """
    x_diffs = np.unique(np.diff(line_xcoords_no_pad))
    if len(x_diffs) == 1:
        x_diff = x_diffs[0]
    elif len(x_diffs) == 0 and not allow_multiple_values:
        # Fixed: a single x coordinate produced an empty array and np.min
        # crashed; mirror the y-branch below and report zero spacing.
        x_diff = 0
    elif not allow_multiple_values:
        x_diff = np.min(x_diffs)
    else:
        x_diff = x_diffs

    # All lines at the same y: fall back to the first line height.
    if np.unique(line_ycoords_no_pad).shape[0] == 1:
        return x_diff, line_heights[0]
    y_diffs = np.unique(np.diff(line_ycoords_no_pad))
    if len(y_diffs) == 1:
        y_diff = y_diffs[0]
    elif len(y_diffs) == 0:
        y_diff = 0
    elif not allow_multiple_values:
        y_diff = np.min(y_diffs)
    else:
        y_diff = y_diffs
    return x_diff, y_diff
309
+
310
+
311
def add_words(trial, close_gap_between_words=True):
    """Reconstruct word bounding boxes from a trial's per-character boxes.

    Walks trial["chars_list"] and closes the current word whenever a
    separator character, a leftward x jump (line wrap), or the final
    character is reached.  Optionally widens adjacent word boxes on the same
    line so they meet halfway across the separating gap.

    Returns the list of word dicts (text, box corners, centre, line index).
    """
    chars_list_reconstructed = []
    words_list = []
    word_start_idx = 0
    chars_df = pd.DataFrame(trial["chars_list"])
    chars_df["char_width"] = chars_df.char_xmax - chars_df.char_xmin
    # Mean width of the space characters.
    # NOTE(review): space_width is computed but never used below.
    space_width = chars_df.loc[chars_df["char"] == " ", "char_width"].mean()

    for idx, char_dict in enumerate(trial["chars_list"]):
        on_line_num = char_dict["assigned_line"]  # NOTE(review): never read
        chars_list_reconstructed.append(char_dict)
        # A word closes on separators, on the x coordinate moving left
        # (wrap to a new line), or on the trial's final character.
        if (
            char_dict["char"] in [" ", ",", ";", ".", ":"]
            or (
                len(chars_list_reconstructed) > 2
                and (chars_list_reconstructed[-1]["char_xmin"] < chars_list_reconstructed[-2]["char_xmin"])
            )
            or len(chars_list_reconstructed) == len(trial["chars_list"])
        ):
            triggered = True  # NOTE(review): flag is never read afterwards
            word_xmin = chars_list_reconstructed[word_start_idx]["char_xmin"]
            # The word box ends at the char before the one that closed it.
            word_xmax = chars_list_reconstructed[-2]["char_xmax"]
            word_ymin = chars_list_reconstructed[word_start_idx]["char_ymin"]
            word_ymax = chars_list_reconstructed[word_start_idx]["char_ymax"]
            word_x_center = (word_xmax - word_xmin) / 2 + word_xmin
            word_y_center = (word_ymax - word_ymin) / 2 + word_ymin
            # NOTE: the comprehension's idx shadows the loop variable.
            word = "".join(
                [
                    chars_list_reconstructed[idx]["char"]
                    for idx in range(word_start_idx, len(chars_list_reconstructed) - 1)
                ]
            )
            assigned_line = chars_list_reconstructed[word_start_idx]["assigned_line"]

            word_dict = dict(
                word=word,
                word_xmin=word_xmin,
                word_xmax=word_xmax,
                word_ymin=word_ymin,
                word_ymax=word_ymax,
                word_x_center=word_x_center,
                word_y_center=word_y_center,
                assigned_line=assigned_line,
            )
            # Next word starts at the closing char itself unless it was a space.
            if char_dict["char"] != " ":
                word_start_idx = idx
            else:
                word_start_idx = idx + 1
            words_list.append(word_dict)
    else:
        # NOTE(review): indentation reconstructed as a for/else (runs once
        # after the loop completes; there is no break).  word_dict and
        # assigned_line refer to the last word that was closed — confirm.
        triggered = False
        # If the trial's final character did not end up as the last letter of
        # the last word, append it as a one-character word of its own.
        last_letter_in_word = word_dict["word"][-1]
        last_letter_in_chars_list_reconstructed = char_dict["char"]
        if last_letter_in_word != last_letter_in_chars_list_reconstructed:
            word_dict = dict(
                word=char_dict["char"],
                word_xmin=char_dict["char_xmin"],
                word_xmax=char_dict["char_xmax"],
                word_ymin=char_dict["char_ymin"],
                word_ymax=char_dict["char_ymax"],
                word_x_center=char_dict["char_x_center"],
                word_y_center=char_dict["char_y_center"],
                assigned_line=assigned_line,
            )
            words_list.append(word_dict)

    if close_gap_between_words:
        # Widen neighbouring boxes on the same line to meet mid-gap.
        for widx in range(1, len(words_list)):
            if words_list[widx]["assigned_line"] == words_list[widx - 1]["assigned_line"]:
                word_sep_half_width = (words_list[widx]["word_xmin"] - words_list[widx - 1]["word_xmax"]) / 2
                words_list[widx - 1]["word_xmax"] = words_list[widx - 1]["word_xmax"] + word_sep_half_width
                words_list[widx]["word_xmin"] = words_list[widx]["word_xmin"] - word_sep_half_width

    return words_list
385
+
386
+
387
def asc_lines_to_trials_by_trail_id(
    lines: list, paragraph_trials_only=False, fname: str = "", close_gap_between_words=True
) -> dict:
    """Parse the lines of an EyeLink .asc file into a per-trial dictionary.

    First pass: split ``lines`` into trials on TRIALID / TRIAL_RESULT markers,
    classify each trial by the first letter of its id (F=question,
    P=practice, else paragraph), and pick up global metadata (display
    coordinates, framerate).  Second pass: for every kept trial, extract
    per-character stimulus boxes from REGION CHAR lines, assign characters to
    text lines, compute grid spacing and derive word boxes via add_words().

    Returns a dict keyed by trial index plus bookkeeping keys
    (display_coords, fps, max_trial_idx, paragraph_trials, ...).
    """
    if hasattr(fname, "name"):
        fname = fname.name
    fps = -999  # sentinel: framerate not found yet
    display_coords = -999  # int sentinel; replaced by a 4-tuple when found
    trials_dict = dict(paragraph_trials=[], paragraph_trial_IDs=[])
    trial_idx = -1
    removed_trial_ids = []
    for idx, l in enumerate(lines):
        parts = l.strip().split(" ")
        if "TRIALID" in l:
            trial_id = parts[-1]
            trial_idx += 1
            # The first letter of the trial id encodes the trial type.
            if trial_id[0] == "F":
                trial_is = "question"
            elif trial_id[0] == "P":
                trial_is = "practice"
            else:
                trial_is = "paragraph"
                trials_dict["paragraph_trials"].append(trial_idx)
                trials_dict["paragraph_trial_IDs"].append(trial_id)
            trials_dict[trial_idx] = dict(trial_id=trial_id, trial_id_idx=idx, trial_is=trial_is, filename=fname)
            last_trial_skipped = False

        elif "TRIAL_RESULT" in l or "stop_trial" in l:
            trials_dict[trial_idx]["trial_result_idx"] = idx
            trials_dict[trial_idx]["trial_result_timestamp"] = int(parts[0].split("\t")[1])
            if len(parts) > 2:
                trials_dict[trial_idx]["trial_result_number"] = int(parts[2])
        elif "DISPLAY COORDS" in l and isinstance(display_coords, int):
            # Only the first occurrence is kept (the sentinel is an int).
            display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
        elif "GAZE_COORDS" in l and isinstance(display_coords, int):
            display_coords = (float(parts[-4]), float(parts[-3]), float(parts[-2]), float(parts[-1]))
        elif "FRAMERATE" in l:
            # metadata_strs is a module-level constant not visible in this
            # chunk; metadata_strs[2] is presumably the FRAMERATE token —
            # TODO confirm.
            l_idx = parts.index(metadata_strs[2])
            fps = float(parts[l_idx + 1])
        elif "TRIAL ABORTED" in l or "TRIAL REPEATED" in l:
            if not last_trial_skipped:
                if trial_is == "paragraph":
                    trials_dict["paragraph_trials"].remove(trial_idx)
                # Reuse this index so the repeated run overwrites the trial.
                trial_idx -= 1
                removed_trial_ids.append(trial_id)
                last_trial_skipped = True

    if paragraph_trials_only:
        trials_dict_temp = trials_dict.copy()
        # NOTE(review): this keep-list drops "paragraph_trial_IDs" as well —
        # confirm that is intended.
        for k in trials_dict_temp.keys():
            if k not in ["paragraph_trials"] + trials_dict_temp["paragraph_trials"]:
                trials_dict.pop(k)
        if len(trials_dict_temp["paragraph_trials"]):
            trial_idx = trials_dict_temp["paragraph_trials"][-1]
        else:
            return trials_dict
    trials_dict["display_coords"] = display_coords
    trials_dict["fps"] = fps
    trials_dict["max_trial_idx"] = trial_idx
    enum = trials_dict["paragraph_trials"] if "paragraph_trials" in trials_dict.keys() else range(len(trials_dict))
    for trial_idx in enum:
        if trial_idx not in trials_dict.keys():
            continue
        chars_list = []
        if "display_coords" not in trials_dict[trial_idx].keys():
            trials_dict[trial_idx]["display_coords"] = trials_dict["display_coords"]
        trial_start_idx = trials_dict[trial_idx]["trial_id_idx"]
        trial_end_idx = trials_dict[trial_idx]["trial_result_idx"]
        trial_lines = lines[trial_start_idx:trial_end_idx]
        for idx, l in enumerate(trial_lines):
            parts = l.strip().split(" ")
            if "START" in l and " MSG" not in l:
                # Samples begin a fixed 7 lines after the START marker.
                trials_dict[trial_idx]["start_idx"] = trial_start_idx + idx + 7
                trials_dict[trial_idx]["start_time"] = int(parts[0].split("\t")[1])
            elif "END" in l and "ENDBUTTON" not in l and " MSG" not in l:
                trials_dict[trial_idx]["end_idx"] = trial_start_idx + idx - 2
                trials_dict[trial_idx]["end_time"] = int(parts[0].split("\t")[1])
            elif "SYNCTIME" in l:
                trials_dict[trial_idx]["synctime"] = trial_start_idx + idx
                trials_dict[trial_idx]["synctime_time"] = int(parts[0].split("\t")[1])
            elif "GAZE TARGET OFF" in l:
                trials_dict[trial_idx]["gaze_targ_off_time"] = int(parts[0].split("\t")[1])
            elif "GAZE TARGET ON" in l:
                trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
            elif "DISPLAY_SENTENCE" in l:  # some .asc files seem to use this
                trials_dict[trial_idx]["gaze_targ_on_time"] = int(parts[0].split("\t")[1])
            elif "REGION CHAR" in l:
                rg_idx = parts.index("CHAR")
                # Three layouts occur: extra token (the char is a space),
                # a truncated line whose coordinates spill onto the next
                # line, and the normal single-line case.
                if len(parts[rg_idx:]) > 8:
                    char = " "
                    idx_correction = 1
                elif len(parts[rg_idx:]) == 3:
                    char = " "
                    if "REGION CHAR" not in trial_lines[idx + 1]:
                        # Continue parsing from the spill-over line.
                        parts = trial_lines[idx + 1].strip().split(" ")
                        idx_correction = -rg_idx - 4
                else:
                    char = parts[rg_idx + 3]
                    idx_correction = 0
                try:
                    char_dict = {
                        "char": char,
                        "char_xmin": float(parts[rg_idx + 4 + idx_correction]),
                        "char_ymin": float(parts[rg_idx + 5 + idx_correction]),
                        "char_xmax": float(parts[rg_idx + 6 + idx_correction]),
                        "char_ymax": float(parts[rg_idx + 7 + idx_correction]),
                    }
                    char_dict["char_y_center"] = (char_dict["char_ymax"] - char_dict["char_ymin"]) / 2 + char_dict[
                        "char_ymin"
                    ]
                    char_dict["char_x_center"] = (char_dict["char_xmax"] - char_dict["char_xmin"]) / 2 + char_dict[
                        "char_xmin"
                    ]
                    chars_list.append(char_dict)
                except Exception as e:
                    # Malformed REGION CHAR lines are skipped with a warning.
                    if "logger" in st.session_state:
                        st.session_state["logger"].warning(f"char_dict creation failed for parts {parts}")
                    if "logger" in st.session_state:
                        st.session_state["logger"].warning(e)

        # Trial time zero: gaze-target onset when present, else START time.
        if "gaze_targ_on_time" in trials_dict[trial_idx]:
            trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["gaze_targ_on_time"]
        else:
            trials_dict[trial_idx]["trial_start_time"] = trials_dict[trial_idx]["start_time"]

        if len(chars_list) > 0:
            # Assign each character to a text line via its vertical centre.
            line_ycoords = []
            for idx in range(len(chars_list)):
                chars_list[idx]["char_line_y"] = (
                    chars_list[idx]["char_ymax"] - chars_list[idx]["char_ymin"]
                ) / 2 + chars_list[idx]["char_ymin"]
                if chars_list[idx]["char_line_y"] not in line_ycoords:
                    line_ycoords.append(chars_list[idx]["char_line_y"])
            for idx in range(len(chars_list)):
                chars_list[idx]["assigned_line"] = line_ycoords.index(chars_list[idx]["char_line_y"])

            line_heights = [x["char_ymax"] - x["char_ymin"] for x in chars_list]
            line_xcoords_all = [x["char_x_center"] for x in chars_list]
            line_xcoords_no_pad = np.unique(line_xcoords_all)

            line_ycoords_all = [x["char_y_center"] for x in chars_list]
            line_ycoords_no_pad = np.unique(line_ycoords_all)

            trials_dict[trial_idx]["x_char_unique"] = list(line_xcoords_no_pad)
            trials_dict[trial_idx]["y_char_unique"] = list(line_ycoords_no_pad)
            x_diff, y_diff = calc_xdiff_ydiff(
                line_xcoords_no_pad, line_ycoords_no_pad, line_heights, allow_multiple_values=False
            )
            trials_dict[trial_idx]["x_diff"] = float(x_diff)
            trials_dict[trial_idx]["y_diff"] = float(y_diff)
            trials_dict[trial_idx]["num_char_lines"] = len(line_ycoords_no_pad)
            trials_dict[trial_idx]["line_heights"] = line_heights
            trials_dict[trial_idx]["chars_list"] = chars_list

            words_list = add_words(trials_dict[trial_idx], close_gap_between_words=close_gap_between_words)
            trials_dict[trial_idx]["words_list"] = words_list

    return trials_dict
544
+
545
+
546
def file_to_trials_and_lines(uploaded_file, asc_encoding: str = "ISO-8859-15", close_gap_between_words=True):
    """Read an .asc file (path or upload object) and parse it into trials.

    Returns (trials_dict, lines).  Paragraph-only filtering is always passed
    to asc_lines_to_trials_by_trail_id; afterwards each paragraph trial with
    characters also gets its per-line sentences and the full stimulus text.
    """
    # Accept either a filesystem path or an uploaded file-like object.
    if isinstance(uploaded_file, str) or isinstance(uploaded_file, pl.Path):
        with open(uploaded_file, "r", encoding=asc_encoding) as f:
            lines = f.readlines()
    else:
        stringio = StringIO(uploaded_file.getvalue().decode(asc_encoding))
        loaded_str = stringio.read()
        lines = loaded_str.split("\n")
    trials_dict = asc_lines_to_trials_by_trail_id(
        lines, True, uploaded_file, close_gap_between_words=close_gap_between_words
    )

    # Fallback: rebuild the paragraph index if the parser did not provide it.
    if "paragraph_trials" not in trials_dict.keys() and "trial_is" in trials_dict[0].keys():
        paragraph_trials = []
        for k in range(trials_dict["max_trial_idx"]):
            if trials_dict[k]["trial_is"] == "paragraph":
                paragraph_trials.append(k)
        trials_dict["paragraph_trials"] = paragraph_trials

    enum = (
        trials_dict["paragraph_trials"]
        if "paragraph_trials" in trials_dict.keys()
        else range(trials_dict["max_trial_idx"])
    )
    for k in enum:
        if "chars_list" in trials_dict[k].keys():
            max_line = trials_dict[k]["chars_list"][-1]["assigned_line"]
            # Collect each line's characters, then join them into sentences.
            words_on_lines = {x: [] for x in range(max_line + 1)}
            [words_on_lines[x["assigned_line"]].append(x["char"]) for x in trials_dict[k]["chars_list"]]
            sentence_list = ["".join([s for s in v]) for idx, v in words_on_lines.items()]
            # NOTE(review): no "\n" is inserted between the first and second
            # sentence (lines a,b,c give "ab\nc") — confirm whether
            # "\n".join(sentence_list) was intended.
            text = sentence_list[0] + "\n".join([x for x in sentence_list[1:]])
            trials_dict[k]["sentence_list"] = sentence_list
            trials_dict[k]["text"] = text
            trials_dict[k]["max_line"] = max_line

    return trials_dict, lines
582
+
583
+
584
def get_plot_props(trial, available_fonts):
    """Resolve font, font size, dpi and screen resolution for plotting a trial.

    Falls back to "DejaVu Sans Mono" when the trial's font is missing or not
    installed, to size 21 when the trial carries no font info, and to
    1920x1080 when it has no display coordinates.
    """
    fallback_font = "DejaVu Sans Mono"
    if "font" in trial:
        font_size = trial["font_size"]
        font = trial["font"] if trial["font"] in available_fonts else fallback_font
    else:
        font, font_size = fallback_font, 21
    dpi = 100
    if "display_coords" in trial:
        screen_res = (trial["display_coords"][2], trial["display_coords"][3])
    else:
        screen_res = (1920, 1080)
    return font, font_size, dpi, screen_res
599
+
600
+
601
def trial_to_dfs(
    trial: dict, lines: list, use_synctime: bool = False, save_lines_to_txt=False, cut_out_outer_fixations=False
):
    """Extract raw gaze samples and fixation events for one trial.

    ``trial`` is the per-trial dict produced by the .asc parsing (provides
    the slice indices and trial_start_time); ``lines`` are the raw .asc
    lines.  The recorded eye (L/R) is auto-detected from the first EFIX
    event; samples inside blinks and event-marker lines are skipped.

    Returns (samples_df, fixations_df, trial); fixations_df is empty when
    the trial contains no completed fixations.

    Fix over the previous revision: the no-fixations branch held a dangling
    tuple expression (``df, pd.DataFrame(), trial``) instead of a return, so
    execution fell through and crashed on ``dffix.x`` when
    ``cut_out_outer_fixations`` was set on an empty dataframe.  The outer
    filter is now applied only when fixations exist, and the counters are
    always written to ``trial``.
    """
    # Choose the sample window: SYNCTIME onwards, or the START/END markers.
    if use_synctime and "synctime" in trial:
        idx0, idxend = trial["synctime"] + 1, trial["trial_result_idx"]
    else:
        idx0, idxend = trial["start_idx"], trial["end_idx"]

    line_dicts = []
    fixations_dicts = []
    blink_started = False

    fixation_started = False
    efix_count = 0
    sfix_count = 0
    sblink_count = 0

    if save_lines_to_txt:
        # Debug helper: dump the window plus 500 lines of context each side.
        with open("Lines_plus500.txt", "w") as f:
            f.writelines(lines[idx0 - 500 : idxend + 500])

    # Detect which eye was recorded from the first fixation-end event.
    eye_to_use = "R"
    for l in lines[idx0 : idxend + 1]:
        if "EFIX R" in l:
            eye_to_use = "R"
            break
        elif "EFIX L" in l:
            eye_to_use = "L"
            break

    for l in lines[idx0 : idxend + 1]:
        parts = [x.strip() for x in l.split("\t")]
        if f"EFIX {eye_to_use}" in l:
            efix_count += 1
            if fixation_started:
                if parts[1] == "." and parts[2] == ".":
                    continue  # fixation event with missing data
                fixations_dicts.append(
                    {
                        "start_time": float(parts[0].split()[-1].strip()),
                        "end_time": float(parts[1].strip()),
                        "duration": float(parts[2].strip()),
                        "x": float(parts[3].strip()),
                        "y": float(parts[4].strip()),
                        "pupil_size": float(parts[5].strip()),
                    }
                )
                if len(fixations_dicts) >= 2:
                    assert (
                        fixations_dicts[-1]["start_time"] > fixations_dicts[-2]["start_time"]
                    ), "start times not in order"
                fixation_started = False

        elif f"SFIX {eye_to_use}" in l:
            sfix_count += 1
            fixation_started = True
        elif f"SBLINK {eye_to_use}" in l:
            sblink_count += 1
            blink_started = True
        # Raw samples: any line outside a blink that is not an event marker
        # (event_strs is a module-level list of marker substrings).
        if not blink_started and not any([True for x in event_strs if x in l]):
            if len(parts) < 3 or (parts[1] == "." and parts[2] == "."):
                continue
            line_dicts.append(
                {
                    "idx": float(parts[0].strip()),
                    "x": float(parts[1].strip()),
                    "y": float(parts[2].strip()),
                    "p": float(parts[3].strip()),
                }
            )

        elif f"EBLINK {eye_to_use}" in l:
            blink_started = False

    df = pd.DataFrame(line_dicts)
    dffix = pd.DataFrame(fixations_dicts)
    if len(fixations_dicts) > 0:
        dffix["corrected_start_time"] = dffix.start_time - trial["trial_start_time"]
        dffix["corrected_end_time"] = dffix.end_time - trial["trial_start_time"]
        dffix["fix_duration"] = dffix.corrected_end_time.values - dffix.corrected_start_time.values
        assert all(np.diff(dffix["corrected_start_time"]) > 0), "start times not in order"
        if cut_out_outer_fixations:
            # Drop fixations far outside the stimulus area (hard-coded bounds).
            dffix = dffix[(dffix.x > -10) & (dffix.y > -10) & (dffix.x < 1050) & (dffix.y < 800)]
    trial["efix_count"] = efix_count
    trial["eye_to_use"] = eye_to_use
    trial["sfix_count"] = sfix_count
    trial["sblink_count"] = sblink_count
    return df, dffix, trial
693
+
694
+
695
def get_save_path(fpath, fname_ending):
    """Return ``<gradio_plots>/<fpath stem>_<fname_ending>.png``."""
    return gradio_plots.joinpath(f"{fpath.stem}_{fname_ending}.png")
698
+
699
+
700
def save_im_load_convert(fpath, fig, fname_ending, mode):
    """Save ``fig`` as a PNG, reload it converted to ``mode``, and re-save.

    ``mode`` is a PIL mode string (e.g. "L" for greyscale, "RGB").
    Returns the converted PIL image.
    """
    out_path = get_save_path(fpath, fname_ending)
    fig.savefig(out_path)
    converted = Image.open(out_path).convert(mode)
    converted.save(out_path)
    return converted
706
+
707
+
708
def get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, dffix=None, prefix="word"):
    """Create a borderless, full-bleed figure/axes pair sized to the screen.

    Axis limits come from the fixation extremes when ``dffix`` is given,
    otherwise from the word/char centre extremes padded by the margins.
    The y-axis is inverted to match screen coordinates.
    """
    fig = plt.figure(figsize=(screen_res[0] / dpi, screen_res[1] / dpi), dpi=dpi)
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    if dffix is None:
        y_vals = words_df[f"{prefix}_y_center"]
        x_vals = words_df[f"{prefix}_x_center"]
        ax.set_ylim((y_vals.min() - y_margin, y_vals.max() + y_margin))
        ax.set_xlim((x_vals.min() - x_margin, x_vals.max() + x_margin))
    else:
        ax.set_ylim((dffix.y.min(), dffix.y.max()))
        ax.set_xlim((dffix.x.min(), dffix.x.max()))
    ax.invert_yaxis()
    fig.add_axes(ax)
    return fig, ax
721
+
722
+
723
def plot_text_boxes_fixations(
    fpath,
    dpi,
    screen_res,
    data_dir_sub,
    set_font_size: bool,
    font_size: int,
    use_words: bool,
    save_channel_repeats: bool,
    save_combo_grey_and_rgb: bool,
    dffix=None,
    trial=None,
):
    """Render fixations and word/char boxes to a set of PNGs next to ``fpath``.

    Produces a 3-channel image whose channels are (text, filled boxes,
    fixation scatter), optionally an RGB + grey combo and per-channel repeat
    images.  ``dffix``/``trial`` are loaded from ``fpath`` (fixations csv and
    sibling ``_trial.json``) when not supplied.  ``data_dir_sub`` is kept for
    interface compatibility but unused here.

    Fixes over the previous revision:
    - ``font`` is now always defined (previously unbound — NameError — when
      ``set_font_size`` was False).
    - the channel-separated combo image is saved to its computed
      ``save_path`` instead of overwriting the input ``fpath``.
    """
    if isinstance(fpath, str):
        fpath = pl.Path(fpath)
    prefix = "word" if use_words else "char"
    if dffix is None:
        dffix = pd.read_csv(fpath)
    if trial is None:
        json_fpath = str(fpath).replace("_fixations.csv", "_trial.json")
        with open(json_fpath, "r") as f:
            trial = json.load(f)
    words_df = pd.DataFrame(trial[f"{prefix}s_list"])
    # Box edges (boxes are drawn with a one pixel offset below).
    x_lefts = words_df[f"{prefix}_xmin"]
    x_rights = words_df[f"{prefix}_xmax"]
    y_tops = words_df[f"{prefix}_ymax"]
    y_bottoms = words_df[f"{prefix}_ymin"]

    if f"{prefix}_x_center" not in words_df.columns:
        words_df[f"{prefix}_x_center"] = (words_df[f"{prefix}_xmax"] - words_df[f"{prefix}_xmin"]) / 2 + words_df[
            f"{prefix}_xmin"
        ]
        words_df[f"{prefix}_y_center"] = (words_df[f"{prefix}_ymax"] - words_df[f"{prefix}_ymin"]) / 2 + words_df[
            f"{prefix}_ymin"
        ]

    x_margin = words_df[f"{prefix}_x_center"].mean() / 8
    y_margin = words_df[f"{prefix}_y_center"].mean() / 4
    # Fade older fixations: alpha ramps linearly from 0.25 to 1 in time order.
    times = dffix.corrected_start_time - dffix.corrected_start_time.min()
    times = times / times.max()
    times = np.linspace(0.25, 1, len(times))

    font = "monospace"  # fix: previously only bound when set_font_size was True
    if not set_font_size:
        font_size = trial["font_size"] * 27 // dpi

    font_props = FontProperties(family=font, style="normal", size=font_size)
    if save_combo_grey_and_rgb:
        # RGB combo: blue fixations, red boxes, green text.
        fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
        ax.scatter(dffix.x, dffix.y, alpha=times, facecolor="b")
        for idx in range(len(x_lefts)):
            box_w = x_rights[idx] - x_lefts[idx]
            box_h = y_tops[idx] - y_bottoms[idx]
            rect = patches.Rectangle(
                (x_lefts[idx] - 1, y_bottoms[idx] - 1),
                box_w,
                box_h,
                alpha=0.9,
                linewidth=0.8,
                edgecolor="r",
                facecolor="none",
            )  # seems to need one pixel offset
            ax.text(
                words_df[f"{prefix}_x_center"][idx],
                words_df[f"{prefix}_y_center"][idx],
                words_df[prefix][idx],
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=font_props,
                color="g",
            )
            ax.add_patch(rect)
        fname_ending = f"{prefix}s_combo_rgb"
        words_combo_rgb_im = save_im_load_convert(fpath, fig, fname_ending, "RGB")
        plt.close("all")

        # Greyscale combo of boxes + text + fixations.
        fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
        ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
        for idx in range(len(x_lefts)):
            box_w = x_rights[idx] - x_lefts[idx]
            box_h = y_tops[idx] - y_bottoms[idx]
            rect = patches.Rectangle(
                (x_lefts[idx] - 1, y_bottoms[idx] - 1),
                box_w,
                box_h,
                alpha=0.9,
                linewidth=0.8,
                edgecolor="k",
                facecolor="none",
            )  # seems to need one pixel offset
            ax.text(
                words_df[f"{prefix}_x_center"][idx],
                words_df[f"{prefix}_y_center"][idx],
                words_df[prefix][idx],
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=font_props,
            )
            ax.add_patch(rect)
        fname_ending = f"{prefix}s_combo_grey"
        words_combo_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")
        plt.close("all")

    # Channel 1: the text alone.
    fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
    ax.scatter(words_df[f"{prefix}_x_center"], words_df[f"{prefix}_y_center"], s=1, facecolor="k", alpha=0.01)
    for idx in range(len(x_lefts)):
        ax.text(
            words_df[f"{prefix}_x_center"][idx],
            words_df[f"{prefix}_y_center"][idx],
            words_df[prefix][idx],
            horizontalalignment="center",
            verticalalignment="center",
            fontproperties=font_props,
        )
    fname_ending = f"{prefix}s_grey"
    words_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")

    plt.close("all")

    # Channel 2: filled boxes alone.
    fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
    ax.scatter(words_df[f"{prefix}_x_center"], words_df[f"{prefix}_y_center"], s=1, facecolor="k", alpha=0.1)
    for idx in range(len(x_lefts)):
        box_w = x_rights[idx] - x_lefts[idx]
        box_h = y_tops[idx] - y_bottoms[idx]
        rect = patches.Rectangle(
            (x_lefts[idx] - 1, y_bottoms[idx] - 1), box_w, box_h, alpha=0.9, linewidth=1, edgecolor="k", facecolor="grey"
        )  # seems to need one pixel offset
        ax.add_patch(rect)
    fname_ending = f"{prefix}_boxes_grey"
    word_boxes_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")

    plt.close("all")

    # Channel 3: fixation scatter alone.
    fig, ax = get_fig_ax(screen_res, dpi, words_df, x_margin, y_margin, prefix=prefix)
    ax.scatter(dffix.x, dffix.y, facecolor="k", alpha=times)
    fname_ending = "fix_scatter_grey"
    fix_scatter_grey_im = save_im_load_convert(fpath, fig, fname_ending, "L")

    plt.close("all")

    # Stack the three greyscale renderings into one RGB-like image.
    arr_combo = np.stack(
        [
            np.asarray(words_grey_im),
            np.asarray(word_boxes_grey_im),
            np.asarray(fix_scatter_grey_im),
        ],
        axis=2,
    )

    im_combo = Image.fromarray(arr_combo)
    fname_ending = f"{prefix}s_channel_sep"

    save_path = get_save_path(fpath, fname_ending)
    print(f"save_path for im combo is {save_path}")
    im_combo.save(save_path)  # fix: previously overwrote the input fpath

    if save_channel_repeats:
        # Each channel repeated three times, saved as its own image.
        arr_combo = np.stack([np.asarray(words_grey_im)] * 3, axis=2)
        im_combo = Image.fromarray(arr_combo)
        fname_ending = f"{prefix}s_channel_repeat"

        save_path = get_save_path(fpath, fname_ending)
        im_combo.save(save_path)

        arr_combo = np.stack([np.asarray(word_boxes_grey_im)] * 3, axis=2)

        im_combo = Image.fromarray(arr_combo)
        fname_ending = f"{prefix}boxes_channel_repeat"

        save_path = get_save_path(fpath, fname_ending)
        im_combo.save(save_path)

        arr_combo = np.stack([np.asarray(fix_scatter_grey_im)] * 3, axis=2)

        im_combo = Image.fromarray(arr_combo)
        fname_ending = "fix_channel_repeat"

        save_path = get_save_path(fpath, fname_ending)
        im_combo.save(save_path)
909
+
910
+
911
def add_line_overlaps_to_sample(trial, sample):
    """Append a line-overlap feature column to ``sample``.

    For each fixation row, the y coordinate (column 1) is matched against the
    unique (char_ymin, char_ymax) bands of the trial's characters; the index
    of the first containing band becomes the feature, or -1.0 if none
    contains it.  Returns ``sample`` with the extra column concatenated.
    """
    char_df = pd.DataFrame(trial["chars_list"])
    # Unique keeps order of first appearance, so band index == line index.
    y_bands = list(zip(char_df.char_ymin.unique(), char_df.char_ymax.unique()))
    overlap_vals = []
    for row in sample:
        y = row[1]
        matched = -1
        for band_idx, (ymin, ymax) in enumerate(y_bands):
            if ymin <= y <= ymax:
                matched = band_idx
                break
        overlap_vals.append(t.tensor(matched, dtype=t.float32))
    overlap_col = t.stack(overlap_vals, dim=0).unsqueeze(1)
    return t.cat([sample, overlap_col], dim=1)
925
+
926
+
927
def norm_coords_by_letter_min_x_y(
    sample_idx: int,
    trialslist: list,
    samplelist: list,
    chars_center_coords_list: list = None,
):
    """Shift one sample so the text's top-left character corner is the origin.

    Subtracts the minimum char x from feature 0 and the minimum char y from
    feature 1 of ``samplelist[sample_idx]`` (other features untouched) and
    applies the same shift to the matching char-centre coordinates when
    provided.  Also caches the unique char x-minima on the trial dict.
    Returns the (mutated) trialslist, samplelist and centre list.
    """
    chars_df = pd.DataFrame(trialslist[sample_idx]["chars_list"])
    trialslist[sample_idx]["x_char_unique"] = chars_df.char_xmin.unique()

    sample = samplelist[sample_idx]
    offset = t.zeros((1, sample.shape[1]), dtype=sample.dtype, device=sample.device)
    offset[0, 0] = chars_df.char_xmin.min()
    offset[0, 1] = chars_df.char_ymin.min()

    samplelist[sample_idx] = sample - offset

    if chars_center_coords_list is not None:
        xy_offset = offset.squeeze(0)[:2]
        centers = chars_center_coords_list[sample_idx]
        # Box-corner coordinates (xmin,ymin,xmax,ymax) get the shift twice.
        if centers.shape[-1] == xy_offset.shape[-1] * 2:
            centers[:, :2] -= xy_offset
            centers[:, 2:] -= xy_offset
        else:
            centers -= xy_offset
    return trialslist, samplelist, chars_center_coords_list
955
+
956
+
957
def norm_coords_by_letter_positions(
    sample_idx: int,
    trialslist: list,
    samplelist: list,
    meanlist: list = None,
    stdlist: list = None,
    return_mean_std_lists=False,
    norm_by_char_averages=False,
    chars_center_coords_list: list = None,
    add_normalised_values_as_features=False,
):
    """Scale one sample's x/y features by the stimulus text dimensions.

    Feature 0 (x) is divided by the text line width (or average char width
    when ``norm_by_char_averages``) and feature 1 (y) by the minimum line
    height (or average char height).  Optionally the normalised values are
    appended as extra feature columns instead of replacing the originals,
    the matching char-centre coordinates are scaled too, and per-sample
    mean/std vectors are collected.

    Returns (trialslist, samplelist, meanlist, stdlist, centres) when
    ``return_mean_std_lists`` else (trialslist, samplelist, centres).

    Fixes over the previous revision: the NaN assertions used ``~`` on
    Python bools (``~True == -2`` is truthy, so they could never fire) and
    the second assertion re-checked ``mean_val`` instead of ``std_val``.
    """
    chars_df = pd.DataFrame(trialslist[sample_idx]["chars_list"])
    trialslist[sample_idx]["x_char_unique"] = chars_df.char_xmin.unique()

    min_x_chars = chars_df.char_xmin.min()
    max_x_chars = chars_df.char_xmax.max()

    norm_vector_multi = t.ones(
        (1, samplelist[sample_idx].shape[1]), dtype=samplelist[sample_idx].dtype, device=samplelist[sample_idx].device
    )
    if norm_by_char_averages:
        chars_list = trialslist[sample_idx]["chars_list"]
        char_widths = np.asarray([x["char_xmax"] - x["char_xmin"] for x in chars_list])
        char_heights = np.asarray([x["char_ymax"] - x["char_ymin"] for x in chars_list])
        # Zero-sized characters are excluded from the averages.
        char_widths_average = np.mean(char_widths[char_widths > 0])
        char_heights_average = np.mean(char_heights[char_heights > 0])

        norm_vector_multi[0, 0] = norm_vector_multi[0, 0] * char_widths_average
        norm_vector_multi[0, 1] = norm_vector_multi[0, 1] * char_heights_average

    else:
        line_height = min(np.unique(trialslist[sample_idx]["line_heights"]))
        line_width = max_x_chars - min_x_chars
        norm_vector_multi[0, 0] = norm_vector_multi[0, 0] * line_width
        norm_vector_multi[0, 1] = norm_vector_multi[0, 1] * line_height
    assert not t.any(t.isnan(norm_vector_multi)), "Nan found in char norming vector"

    norm_vector_multi = norm_vector_multi.squeeze(0)
    if add_normalised_values_as_features:
        # Keep only the entries that actually normalise something (!= 1).
        norm_vector_multi = norm_vector_multi[norm_vector_multi != 1]
        normed_features = samplelist[sample_idx][:, : norm_vector_multi.shape[0]] / norm_vector_multi
        samplelist[sample_idx] = t.cat([samplelist[sample_idx], normed_features], dim=1)
    else:
        samplelist[sample_idx] = samplelist[sample_idx] / norm_vector_multi  # in case time or pupil size is included
    if chars_center_coords_list is not None:
        norm_vector_multi = norm_vector_multi[:2]
        if chars_center_coords_list[sample_idx].shape[-1] == norm_vector_multi.shape[-1] * 2:
            chars_center_coords_list[sample_idx][:, :2] /= norm_vector_multi
            chars_center_coords_list[sample_idx][:, 2:] /= norm_vector_multi
        else:
            chars_center_coords_list[sample_idx] /= norm_vector_multi
    if return_mean_std_lists:
        mean_val = samplelist[sample_idx].mean(axis=0).cpu().numpy()
        meanlist.append(mean_val)
        std_val = samplelist[sample_idx].std(axis=0).cpu().numpy()
        stdlist.append(std_val)
        # Fixed: `not ...any()` instead of the always-truthy `~any(...)`,
        # and the second check now actually validates std_val.
        assert not np.isnan(mean_val).any(), "Nan found in mean_val"
        assert not np.isnan(std_val).any(), "Nan found in std_val"

        return trialslist, samplelist, meanlist, stdlist, chars_center_coords_list
    return trialslist, samplelist, chars_center_coords_list
1018
+
1019
+
1020
def remove_compile_from_model(model):
    """Unwrap torch.compile'd submodules of ``model`` back to their originals.

    When ``model.project`` carries an ``_orig_mod`` attribute, all six known
    submodules are replaced by their uncompiled originals; otherwise the
    model is left untouched and a message is printed.  Returns the model.
    """
    if not hasattr(model.project, "_orig_mod"):
        print(f"remove_compile_from_model not done since model.project {model.project} has no orig_mod")
        return model
    model.project = model.project._orig_mod
    model.chars_conv = model.chars_conv._orig_mod
    model.chars_classifier = model.chars_classifier._orig_mod
    model.layer_norm_in = model.layer_norm_in._orig_mod
    model.bert_model = model.bert_model._orig_mod
    model.linear = model.linear._orig_mod
    return model
1031
+
1032
+
1033
def remove_compile_from_dict(state_dict):
    """Strip torch.compile's "._orig_mod." infix from every state-dict key.

    Mutates ``state_dict`` in place and returns it, so checkpoints saved
    from a compiled model can be loaded into an uncompiled one.
    """
    for old_key in list(state_dict):
        state_dict[old_key.replace("._orig_mod.", ".")] = state_dict.pop(old_key)
    return state_dict
1038
+
1039
+
1040
def add_text_to_ax(
    chars_list,
    ax,
    font_to_use="DejaVu Sans Mono",
    fontsize=21,
    prefix="char",
    plot_boxes=True,
    plot_text=True,
    box_annotations=None,
):
    """Draw character/word glyphs and/or their bounding boxes onto ``ax``.

    Each dict in ``chars_list`` provides ``<prefix>_xmin/xmax/ymin/ymax``,
    ``<prefix>_x_center/_y_center`` and the text under key ``prefix``.
    Optional ``box_annotations`` (one per box) are drawn under each box at
    two-thirds of ``fontsize``.  Returns None.
    """
    if not plot_boxes and not plot_text:
        return None
    text_font = FontProperties(family=font_to_use, style="normal", size=fontsize)
    annot_font = FontProperties(family=font_to_use, style="normal", size=fontsize / 1.5)
    pairs = (
        ((item, None) for item in chars_list)
        if box_annotations is None
        else zip(chars_list, box_annotations)
    )
    for item, annotation in pairs:
        left, bottom = item[f"{prefix}_xmin"], item[f"{prefix}_ymin"]
        width = item[f"{prefix}_xmax"] - item[f"{prefix}_xmin"]
        height = item[f"{prefix}_ymax"] - item[f"{prefix}_ymin"]
        if plot_text:
            ax.text(
                item[f"{prefix}_x_center"],
                item[f"{prefix}_y_center"],
                item[prefix],
                horizontalalignment="center",
                verticalalignment="center",
                fontproperties=text_font,
            )
        if plot_boxes:
            ax.add_patch(Rectangle((left, bottom), width, height, edgecolor="grey", facecolor="none", lw=0.8, alpha=0.4))
            if box_annotations is not None:
                ax.annotate(
                    str(annotation),
                    (left + width / 2, bottom),
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontproperties=annot_font,
                )
1080
+ )
1081
+
1082
+
1083
def plot_fixations_and_text(
    dffix: pd.DataFrame,
    trial: dict,
    plot_prefix="chars_",
    show=False,
    returnfig=False,
    save=False,
    savelocation="plot.png",
    font_to_use="DejaVu Sans Mono",
    fontsize=20,
    plot_classic=True,
    plot_boxes=True,
    plot_text=True,
    fig_size=(14, 8),
    dpi=300,
    turn_axis_on=True,
    algo_choice="slice",
):
    """Plot raw fixations over the stimulus text, optionally with corrections.

    Draws the trial's character/word boxes via add_text_to_ax, the raw
    fixations as black crosses and — when the ``algo_choice`` columns exist —
    the y-corrected fixations as green stars with arrows pointing from the
    corrected position back to the raw one.  Depending on the flags the
    figure is saved, shown, returned, or closed (returning None).
    """
    fig, ax = plt.subplots(1, 1, figsize=fig_size, tight_layout=True, dpi=dpi)
    if f"{plot_prefix}list" in trial.keys():
        add_text_to_ax(
            trial[f"{plot_prefix}list"],
            ax,
            font_to_use,
            fontsize=fontsize,
            prefix=plot_prefix[:-2],  # e.g. "chars_" -> "char"
            plot_boxes=plot_boxes,
            plot_text=plot_text,
        )
    ax.plot(dffix.x, dffix.y, "kX", label="Raw Fixations", alpha=0.9)

    if plot_classic and f"line_num_{algo_choice}" in dffix.columns:
        ax.scatter(
            dffix.x,
            dffix[f"y_{algo_choice}"],
            marker="*",
            color="tab:green",
            label=f"{algo_choice} Prediction",
            alpha=0.9,
        )
        # Arrows run from each corrected fixation to its raw counterpart.
        for x_before, y_before, x_after, y_after in zip(
            dffix.x.values, dffix[f"y_{algo_choice}"].values, dffix.x, dffix.y
        ):
            arr_delta_x = x_after - x_before
            arr_delta_y = y_after - y_before
            ax.arrow(x_before, y_before, arr_delta_x, arr_delta_y, color="tab:green", alpha=0.6)
    ax.set_ylabel("y (pixel)")
    ax.set_xlabel("x (pixel)")

    ax.invert_yaxis()  # screen coordinates: y grows downwards
    ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
    if not turn_axis_on:
        ax.axis("off")
    if save:
        plt.savefig(savelocation, dpi=dpi)
    if show:
        plt.show()
    if returnfig:
        return fig
    else:
        plt.close()
        return None
1145
+
1146
+
1147
def make_folders(gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
    """Ensure the three working directories exist (pathlib.Path args); return 0."""
    for folder in (gradio_temp_folder, gradio_temp_unzipped_folder, gradio_plots):
        folder.mkdir(exist_ok=True)
    return 0
1152
+
1153
+
1154
def get_classic_cfg(fname):
    """Load the classic correction-algorithm configuration from a JSON file.

    Args:
        fname: Path to the JSON config file (e.g. ``algo_cfgs_all.json``).

    Returns:
        dict mapping algorithm names (e.g. "slice") to their config dicts.
    """
    # The original read the whole file and re-assigned two values to
    # themselves; json.load on the handle is equivalent and the no-op
    # statements are dropped.
    with open(fname, "r") as f:
        classic_algos_cfg = json.load(f)
    return classic_algos_cfg
1161
+
1162
+
1163
def find_and_load_model(model_date="20240104-223349"):
    """Locate a model config + checkpoint by its date tag and load the model.

    Args:
        model_date: Timestamp string embedded in the config/checkpoint filenames.

    Returns:
        (model, model_cfg) on success, (None, None) if either the yaml config
        or the checkpoint file cannot be found.
    """
    model_cfg_file = list(DIST_MODELS_FOLDER.glob(f"*{model_date}*.yaml"))
    if len(model_cfg_file) == 0:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"No model cfg yaml found for {model_date}")
        return None, None
    model_cfg_file = model_cfg_file[0]
    with open(model_cfg_file) as f:
        model_cfg = yaml.safe_load(f)

    model_cfg["system_type"] = "linux"
    # BUG FIX: the original indexed [0] on the glob result without checking,
    # raising IndexError when the checkpoint is missing; now fail softly like
    # the missing-yaml path above.
    model_files = list(pl.Path("models").glob(f"*{model_date}*.ckpt"))
    if len(model_files) == 0:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"No model checkpoint found for {model_date}")
        return None, None
    model = load_model(model_files[0], model_cfg)

    return model, model_cfg
1178
+
1179
+
1180
def load_model(model_file, cfg):
    """Instantiate a LitModel and load weights from a checkpoint file.

    Prefers the hyper-parameter config stored inside the checkpoint
    (``hyper_parameters.cfg``); falls back to the supplied ``cfg``.

    Args:
        model_file: Path to a PyTorch Lightning ``.ckpt`` file.
        cfg: Fallback model configuration dict.

    Returns:
        The model in eval mode with frozen parameters, or None if the
        checkpoint could not be read.
    """
    try:
        # map_location="cpu" so GPU-trained checkpoints load on any machine.
        model_loaded = t.load(model_file, map_location="cpu")
        if "hyper_parameters" in model_loaded.keys():
            model_cfg_temp = model_loaded["hyper_parameters"]["cfg"]
        else:
            model_cfg_temp = cfg
        model_state_dict = model_loaded["state_dict"]
    except Exception as e:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(e)
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"Failed to load {model_file}")
        return None
    model = LitModel(
        [1, 500, 3],
        model_cfg_temp["hidden_dim_bert"],
        model_cfg_temp["num_attention_heads"],
        model_cfg_temp["n_layers_BERT"],
        model_cfg_temp["loss_function"],
        1e-4,
        model_cfg_temp["weight_decay"],
        model_cfg_temp,
        model_cfg_temp["use_lr_warmup"],
        model_cfg_temp["use_reduce_on_plateau"],
        track_gradient_histogram=model_cfg_temp["track_gradient_histogram"],
        register_forw_hook=model_cfg_temp["track_activations_via_hook"],
        char_dims=model_cfg_temp["char_dims"],
    )
    # Checkpoints saved from a torch.compile()d model carry "_orig_mod."
    # style prefixes; strip them from both module and state dict before
    # loading. NOTE(review): exact prefix handling lives in the helpers.
    model = remove_compile_from_model(model)
    model_state_dict = remove_compile_from_dict(model_state_dict)
    with t.no_grad():
        # strict=False tolerates key mismatches left over after prefix
        # stripping rather than raising.
        model.load_state_dict(model_state_dict, strict=False)
    model.eval()
    model.freeze()
    return model
1216
+
1217
+
1218
def set_up_models(dist_models_folder):
    """Load all DIST models from disk and build the averaging ensemble.

    Scans ``dist_models_folder`` for checkpoints trained with and without
    line-height/width normalisation (encoded in the filenames) and returns a
    dict containing the ensemble, one representative config per flavour, and
    a single DIST model for standalone use.
    """
    out_dict = {}
    if "logger" in st.session_state:
        st.session_state["logger"].info("Loading Ensemble")
    # Checkpoint filenames encode whether coordinate normalisation was used
    # during training.
    dist_models_with_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_True*.ckpt"))
    dist_models_without_norm = list(dist_models_folder.glob("*normalize_by_line_height_and_width_False*.ckpt"))
    # The model date tag is the second "_"-separated token of the stem.
    DIST_MODEL_DATE_WITH_NORM = dist_models_with_norm[0].stem.split("_")[1]

    # Each element is a (model, cfg) tuple; either half may be None on failure.
    models_without_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_without_norm]
    models_with_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_with_norm]

    # Keep the first successfully loaded config of each flavour.
    model_cfg_without_norm_df = [x[1] for x in models_without_norm_df if x[1] is not None][0]
    model_cfg_with_norm_df = [x[1] for x in models_with_norm_df if x[1] is not None][0]

    # Drop failed loads, keeping only the model objects.
    models_without_norm_df = [x[0] for x in models_without_norm_df if x[0] is not None]
    models_with_norm_df = [x[0] for x in models_with_norm_df if x[0] is not None]

    ensemble_model_avg = EnsembleModel(
        models_without_norm_df, models_with_norm_df, learning_rate=0.0058, use_simple_average=True
    )
    out_dict["ensemble_model_avg"] = ensemble_model_avg

    out_dict["model_cfg_without_norm_df"] = model_cfg_without_norm_df
    out_dict["model_cfg_with_norm_df"] = model_cfg_with_norm_df

    single_DIST_model, single_DIST_model_cfg = find_and_load_model(model_date=DIST_MODEL_DATE_WITH_NORM)
    out_dict["DIST_MODEL_DATE_WITH_NORM"] = DIST_MODEL_DATE_WITH_NORM
    out_dict["single_DIST_model"] = single_DIST_model
    out_dict["single_DIST_model_cfg"] = single_DIST_model_cfg
    return out_dict
1248
+
1249
+
1250
def prep_data_for_dist(model_cfg, dffix, trial=None):
    """Build a single-sample DataLoader for the DIST model from fixation data.

    Converts the fixation dataframe into a float32 feature tensor, applies the
    normalisations requested by ``model_cfg``, renders the stimulus/fixation
    image consumed by the model's conv branch, and wraps everything in a
    DSet/DataLoader pair.

    Args:
        model_cfg: Model configuration dict (sample columns, normalisation
            flags, training-set means/stds, image settings).
        dffix: Fixation dataframe, or a ``{"value": df}`` wrapper.
        trial: Trial dict; defaults to the one in streamlit session state.

    Returns:
        (val_loader, val_set) where the loader has batch_size 1.
    """
    if "logger" in st.session_state:
        st.session_state["logger"].debug("prep_data_for_dist entered")
    if trial is None:
        trial = st.session_state["trial"]
    # UI components may hand the dataframe over wrapped in {"value": df}.
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    sample_tensor = t.tensor(dffix.loc[:, model_cfg["sample_cols"]].to_numpy(), dtype=t.float32)

    if model_cfg["add_line_overlap_feature"]:
        sample_tensor = add_line_overlaps_to_sample(trial, sample_tensor)

    has_nans = t.any(t.isnan(sample_tensor))
    assert not has_nans, "NaNs found in sample tensor"
    samplelist_eval = [sample_tensor]
    trialslist_eval = [trial]
    chars_center_coords_list_eval = None
    # Optional normalisation 1: shift coordinates by the minimum letter x/y.
    if model_cfg["norm_coords_by_letter_min_x_y"]:
        for sample_idx, _ in enumerate(samplelist_eval):
            trialslist_eval, samplelist_eval, chars_center_coords_list_eval = norm_coords_by_letter_min_x_y(
                sample_idx,
                trialslist_eval,
                samplelist_eval,
                chars_center_coords_list=chars_center_coords_list_eval,
            )

    # Optional normalisation 2: scale by line height/width statistics.
    if model_cfg["normalize_by_line_height_and_width"]:
        meanlist_eval, stdlist_eval = [], []
        for sample_idx, _ in enumerate(samplelist_eval):
            (
                trialslist_eval,
                samplelist_eval,
                meanlist_eval,
                stdlist_eval,
                chars_center_coords_list_eval,
            ) = norm_coords_by_letter_positions(
                sample_idx,
                trialslist_eval,
                samplelist_eval,
                meanlist_eval,
                stdlist_eval,
                return_mean_std_lists=True,
                norm_by_char_averages=model_cfg["norm_by_char_averages"],
                chars_center_coords_list=chars_center_coords_list_eval,
                add_normalised_values_as_features=model_cfg["add_normalised_values_as_features"],
            )
    sample_tensor = samplelist_eval[0]
    # Standardise with the training-set statistics stored in the config.
    sample_means = t.tensor(model_cfg["sample_means"], dtype=t.float32)
    sample_std = t.tensor(model_cfg["sample_std"], dtype=t.float32)
    sample_tensor = (sample_tensor - sample_means) / sample_std
    sample_tensor = sample_tensor.unsqueeze(0)  # add batch dim

    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Using path {trial['plot_file']} for plotting")
    # Render the character/fixation image consumed by the conv input branch.
    plot_text_boxes_fixations(
        fpath=trial["plot_file"],
        dpi=250,
        screen_res=(1024, 768),
        data_dir_sub=None,
        set_font_size=True,
        font_size=4,
        use_words=False,
        save_channel_repeats=False,
        save_combo_grey_and_rgb=False,
        dffix=dffix,
        trial=trial,
    )

    val_set = DSet(
        sample_tensor,
        None,  # no targets at inference time
        t.zeros((1, sample_tensor.shape[1])),
        trialslist_eval,
        padding_list=[0],
        padding_at_end=model_cfg["padding_at_end"],
        return_images_for_conv=True,
        im_partial_string=model_cfg["im_partial_string"],
        input_im_shape=model_cfg["char_plot_shape"],
    )
    val_loader = dl(val_set, batch_size=1, shuffle=False, num_workers=0)
    return val_loader, val_set
1331
+
1332
+
1333
def fold_in_seq_dim(out, y=None):
    """Collapse the batch and sequence dims of ``out`` (and optionally ``y``).

    Args:
        out: Tensor of shape (batch, seq, classes).
        y: Optional target tensor of shape (batch, seq) or (batch, seq, c).

    Returns:
        (out, y) with ``out`` reshaped to (batch*seq, classes) and ``y``
        reshaped to (batch*seq,) or (batch*seq, c); ``y`` is None if omitted.
    """
    batch_size, seq_len, num_classes = out.shape

    # Native reshape replaces the einops rearrange calls: the flattening
    # "b s c -> (b s) c" is exactly a row-major reshape.
    out = out.reshape(batch_size * seq_len, num_classes)
    if y is None:
        return out, None
    if len(y.shape) > 2:
        y = y.reshape(batch_size * seq_len, y.shape[-1])
    else:
        y = y.reshape(batch_size * seq_len)
    return out, y
1344
+
1345
+
1346
def logits_to_pred(out, y=None):
    """Convert CORN ordinal logits to per-fixation line-number predictions.

    Args:
        out: Logits of shape (batch, seq, classes).
        y: Optional targets, reshaped alongside the predictions.

    Returns:
        (preds, y) where preds has shape (batch, seq); y is None if omitted.
    """
    seq_len = out.shape[1]
    out, y = fold_in_seq_dim(out, y)
    preds = corn_label_from_logits(out)
    # Unfold back to (batch, seq); reshape is equivalent to the previous
    # einops rearrange "(b s) -> b s".
    preds = preds.reshape(-1, seq_len)
    if y is not None:
        # The no-op "y = y" from the original is dropped.
        y = y.squeeze().reshape(-1, seq_len)
    return preds, y
1355
+
1356
+
1357
def get_DIST_preds(dffix, trial, models_dict=None):
    """Run the single DIST model on a trial and add its predictions to dffix.

    Adds three columns: ``line_num_DIST`` (predicted line indices),
    ``y_DIST`` (predicted y coordinate, snapped to line centers) and
    ``y_DIST_correction`` (y_DIST - raw y). On model failure the dataframe
    is returned unchanged (a warning is logged).

    Args:
        dffix: Fixation dataframe with at least x/y columns.
        trial: Trial dict (character info, line centers, trial_id, ...).
        models_dict: Optional dict holding the model and its config; falls
            back to streamlit session state when None.
    """
    algo_choice = "DIST"

    if models_dict is None:
        # Lazily (re)load the model into session state if it is missing.
        if st.session_state["single_DIST_model"] is None or st.session_state["single_DIST_model_cfg"] is None:
            st.session_state["single_DIST_model"], st.session_state["single_DIST_model_cfg"] = find_and_load_model(
                model_date=st.session_state["DIST_MODEL_DATE_WITH_NORM"]
            )

            if "logger" in st.session_state:
                st.session_state["logger"].info("Model is None, reiniting model")
        else:
            model = st.session_state["single_DIST_model"]
            loader, dset = prep_data_for_dist(st.session_state["single_DIST_model_cfg"], dffix, trial)
    else:
        model = models_dict["single_DIST_model"]
        loader, dset = prep_data_for_dist(models_dict["single_DIST_model_cfg"], dffix, trial)
    batch = next(iter(loader))  # single-sample loader: one batch only

    if "cpu" not in str(model.device):
        batch = [x.cuda() for x in batch]
    try:
        out = model(batch)
        preds, y = logits_to_pred(out, y=None)
        if "logger" in st.session_state:
            st.session_state["logger"].debug(
                f"y_char_unique are {trial['y_char_unique']} for trial {trial['trial_id']}"
            )
        if "logger" in st.session_state:
            st.session_state["logger"].debug(f"trial keys are {trial.keys()} for trial {trial['trial_id']}")
        if "logger" in st.session_state:
            st.session_state["logger"].debug(
                f"chars_list has len {len(trial['chars_list'])} for trial {trial['trial_id']}"
            )
        if "logger" in st.session_state:
            st.session_state["logger"].debug(f"y_char_unique {trial['y_char_unique']} for trial {trial['trial_id']}")
        # Fall back to recomputing the line centers from the char boxes if
        # the trial does not carry them.
        if len(trial["y_char_unique"]) < 1:
            y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
        else:
            y_char_unique = trial["y_char_unique"]
        num_lines = trial["num_char_lines"] - 1
        # Clamp line indices into the valid range before mapping to y values.
        preds = t.clamp(preds, 0, num_lines).squeeze().cpu().numpy()
        y_pred_DIST = [y_char_unique[idx] for idx in preds]

        dffix[f"line_num_{algo_choice}"] = preds
        dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
        dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
    except Exception as e:
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"Exception on model(batch) for DIST \n{e}")
    return dffix
1408
+
1409
+
1410
def get_DIST_ensemble_preds(
    dffix,
    trial,
    model_cfg_without_norm_df,
    model_cfg_with_norm_df,
    ensemble_model_avg,
):
    """Run the DIST ensemble (avg of norm/non-norm models) on one trial.

    Prepares one batch per normalisation flavour, feeds both to the ensemble,
    and writes ``line_num_DIST-Ensemble``, ``y_DIST-Ensemble`` and
    ``y_DIST-Ensemble_correction`` columns into ``dffix``.

    Returns:
        The dataframe with prediction columns added.
    """
    algo_choice = "DIST-Ensemble"
    # The two model flavours expect differently normalised inputs.
    loader_without_norm, dset_without_norm = prep_data_for_dist(model_cfg_without_norm_df, dffix, trial)
    loader_with_norm, dset_with_norm = prep_data_for_dist(model_cfg_with_norm_df, dffix, trial)
    batch_without_norm = next(iter(loader_without_norm))
    batch_with_norm = next(iter(loader_with_norm))
    out = ensemble_model_avg((batch_without_norm, batch_with_norm))
    preds, y = logits_to_pred(out[0]["out_avg"], y=None)
    # Fall back to recomputing line centers from char boxes if missing.
    if len(trial["y_char_unique"]) < 1:
        y_char_unique = pd.DataFrame(trial["chars_list"]).char_y_center.sort_values().unique()
    else:
        y_char_unique = trial["y_char_unique"]
    num_lines = trial["num_char_lines"] - 1
    # Clamp line indices into the valid range before mapping to y values.
    preds = t.clamp(preds, 0, num_lines).squeeze().cpu().numpy()
    if "logger" in st.session_state:
        st.session_state["logger"].debug(f"preds are {preds} for trial {trial['trial_id']}")
    y_pred_DIST = [y_char_unique[idx] for idx in preds]

    dffix[f"line_num_{algo_choice}"] = preds
    dffix[f"y_{algo_choice}"] = np.round(y_pred_DIST, decimals=1)
    dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)
    return dffix
1438
+
1439
+
1440
def get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg=None, models_dict=None):
    """Run ensemble DIST predictions, (re)building the ensemble if needed.

    When ``models_dict`` is provided its models/configs are used directly.
    Otherwise the ensemble is taken from streamlit session state, or rebuilt
    from the checkpoints in DIST_MODELS_FOLDER if absent.

    NOTE(review): if ``ensemble_model_avg`` is passed but session state holds
    no cached configs, the else-branch still reads them from session state —
    confirm callers always populate session state in that case.
    """
    if models_dict is None:
        if ensemble_model_avg is None and "ensemble_model_avg" not in st.session_state:
            if "logger" in st.session_state:
                st.session_state["logger"].info("Ensemble Model is None, reiniting model")
            dist_models_with_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_True*.ckpt")
            dist_models_without_norm = DIST_MODELS_FOLDER.glob("*normalize_by_line_height_and_width_False*.ckpt")

            # Each element is a (model, cfg) tuple; either half may be None.
            models_without_norm_df = [
                find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_without_norm
            ]
            models_with_norm_df = [find_and_load_model(m_file.stem.split("_")[1]) for m_file in dist_models_with_norm]

            model_cfg_without_norm_df = [x[1] for x in models_without_norm_df if x[1] is not None][0]
            model_cfg_with_norm_df = [x[1] for x in models_with_norm_df if x[1] is not None][0]

            models_without_norm_df = [x[0] for x in models_without_norm_df if x[0] is not None]
            models_with_norm_df = [x[0] for x in models_with_norm_df if x[0] is not None]

            ensemble_model_avg = EnsembleModel(
                models_without_norm_df, models_with_norm_df, learning_rate=0.0, use_simple_average=True
            )
            # Cache for subsequent calls in this session.
            st.session_state["ensemble_model_avg"] = ensemble_model_avg
            st.session_state["model_cfg_without_norm_df"] = model_cfg_without_norm_df
            st.session_state["model_cfg_with_norm_df"] = model_cfg_with_norm_df
        else:
            model_cfg_without_norm_df = st.session_state["model_cfg_without_norm_df"]
            model_cfg_with_norm_df = st.session_state["model_cfg_with_norm_df"]
            ensemble_model_avg = st.session_state["ensemble_model_avg"]
        dffix = get_DIST_ensemble_preds(
            dffix,
            trial,
            st.session_state["model_cfg_without_norm_df"],
            st.session_state["model_cfg_with_norm_df"],
            st.session_state["ensemble_model_avg"],
        )
    else:
        dffix = get_DIST_ensemble_preds(
            dffix,
            trial,
            models_dict["model_cfg_without_norm_df"],
            models_dict["model_cfg_with_norm_df"],
            models_dict["ensemble_model_avg"],
        )
    return dffix
1486
+
1487
+
1488
def correct_df(
    dffix,
    algo_choice,
    trial=None,
    for_multi=False,
    ensemble_model_avg=None,
    is_outside_of_streamlit=False,
    classic_algos_cfg=None,
    models_dict=None,
):
    """Apply one or several line-assignment correction algorithms to dffix.

    Args:
        dffix: Fixation dataframe or ``{"value": df}`` wrapper; needs x/y columns.
        algo_choice: Algorithm name or list of names ("DIST", "DIST-Ensemble",
            Wisdom-of-Crowds variants, or any classic algorithm key).
        trial: Trial dict; taken from session state when None (single-trial mode).
        for_multi: When True, return only the dataframe (multi-trial pipeline).
        ensemble_model_avg: Optional pre-built ensemble model.
        is_outside_of_streamlit: Use plain tqdm instead of stqdm.
        classic_algos_cfg: Configs for the classic algorithms; session state
            fallback when None.
        models_dict: Optional dict of preloaded DIST models/configs.

    Returns:
        ``dffix`` if ``for_multi`` else ``(dffix, csv_export)``.
    """
    if is_outside_of_streamlit:
        stqdm = tqdm
    else:
        from stqdm import stqdm
    if classic_algos_cfg is None:
        classic_algos_cfg = st.session_state["classic_algos_cfg"]
    if trial is None and not for_multi:
        trial = st.session_state["trial"]
    if "logger" in st.session_state:
        st.session_state["logger"].info(f"Applying {algo_choice} to fixations for trial {trial['trial_id']}")

    if isinstance(dffix, dict):
        dffix = dffix["value"]
    # BUG FIX: the original tested "x" twice and never checked for "y".
    if "x" not in dffix.keys() or "y" not in dffix.keys():
        if "logger" in st.session_state:
            st.session_state["logger"].warning(f"x or y not in dffix")
        if "logger" in st.session_state:
            st.session_state["logger"].warning(dffix.columns)
        return dffix
    if isinstance(algo_choice, list):
        algo_choices = algo_choice
        repeats = range(len(algo_choice))
    else:
        algo_choices = [algo_choice]
        repeats = range(1)
    for algoIdx in stqdm(repeats, desc="Applying correction algorithms"):
        algo_choice = algo_choices[algoIdx]
        # Track both CPU and wall time per algorithm for the log.
        st_proc = time.process_time()
        st_wall = time.time()

        if algo_choice == "DIST":
            dffix = get_DIST_preds(dffix, trial, models_dict=models_dict)

        elif algo_choice == "DIST-Ensemble":
            dffix = get_EDIST_preds_with_model_check(dffix, trial, models_dict=models_dict)
        elif algo_choice == "Wisdom_of_Crowds_with_DIST":
            dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
            dffix = get_DIST_preds(dffix, trial, models_dict=models_dict)
            # Triple-weight the DIST prediction in the vote.
            for _ in range(3):
                corrections.append(np.asarray(dffix.loc[:, "y_DIST"]))
            dffix = apply_woc(dffix, trial, corrections, algo_choice)
        elif algo_choice == "Wisdom_of_Crowds_with_DIST_Ensemble":
            dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
            dffix = get_EDIST_preds_with_model_check(dffix, trial, ensemble_model_avg, models_dict=models_dict)
            # Triple-weight the ensemble prediction in the vote.
            for _ in range(3):
                corrections.append(np.asarray(dffix.loc[:, "y_DIST-Ensemble"]))
            dffix = apply_woc(dffix, trial, corrections, algo_choice)
        elif algo_choice == "Wisdom_of_Crowds":
            dffix, corrections = get_all_classic_preds(dffix, trial, classic_algos_cfg)
            dffix = apply_woc(dffix, trial, corrections, algo_choice)

        else:
            # Any other key is a classic algorithm configured in the cfg dict.
            algo_cfg = classic_algos_cfg[algo_choice]
            dffix = calgo.apply_classic_algo(dffix, trial, algo_choice, algo_cfg)
            dffix[f"y_{algo_choice}_correction"] = (dffix.loc[:, f"y_{algo_choice}"] - dffix.loc[:, "y"]).round(1)

        et_proc = time.process_time()
        time_proc = et_proc - st_proc
        et_wall = time.time()
        time_wall = et_wall - st_wall
        if "logger" in st.session_state:
            st.session_state["logger"].info(f"time_proc {algo_choice} {time_proc}")
        if "logger" in st.session_state:
            st.session_state["logger"].info(f"time_wall {algo_choice} {time_wall}")
    if for_multi:
        return dffix
    else:
        if "start_time" in dffix.columns:
            dffix = dffix.drop(axis=1, labels=["start_time", "end_time"])
        return dffix, export_csv(dffix, trial)
1568
+
1569
def set_font_from_chars_list(trial):
    """Estimate the stimulus font size (points) from character box spacing.

    Uses the smallest gap between distinct character-line y-centers as the
    line height, converts it to points, and rounds to the nearest 0.25 pt.

    Args:
        trial: Trial dict; uses ``trial["chars_list"]`` when present.

    Returns:
        Font size in points, rounded to a quarter point.
    """
    if "chars_list" in trial:
        chars_df = pd.DataFrame(trial["chars_list"])
        line_diffs = np.diff(chars_df.char_y_center.unique())
        y_diffs = np.unique(line_diffs)
        if len(y_diffs) == 1:
            y_diff = y_diffs[0]
        elif len(y_diffs) == 0:
            # BUG FIX: single-line stimuli produce an empty diff array and
            # np.min([]) raised ValueError; fall back to the default height.
            y_diff = 1 / 0.333 * 18
        else:
            y_diff = np.min(y_diffs)
        y_diff = round(y_diff * 2) / 2  # snap to half pixels

    else:
        y_diff = 1 / 0.333 * 18  # default line height when no chars available
    font_size = y_diff * 0.333  # pixel to point conversion
    return round((font_size) * 4, ndigits=0) / 4
1585
+
1586
def get_font_and_font_size_from_trial(trial):
    """Resolve the font face and size for a trial.

    Asks ``get_plot_props`` first; when it cannot determine a size, falls
    back to the size stored on the trial, then to an estimate derived from
    the character boxes.
    """
    font_face, font_size, _dpi, _screen_res = get_plot_props(trial, AVAILABLE_FONTS)

    if font_size is None:
        if "font_size" in trial:
            font_size = trial["font_size"]
        else:
            font_size = set_font_from_chars_list(trial)
    return font_face, font_size
1594
+
1595
+
1596
def sigmoid(x):
    """Logistic function: maps any real input (scalar or array) into (0, 1)."""
    negative_exp = np.exp(-x)
    return 1.0 / (1.0 + negative_exp)
1598
+
1599
+
1600
def matplotlib_plot_df(
    dffix,
    trial,
    algo_choice,
    stimulus_prefix="word",
    desired_dpi=300,
    fix_to_plot=None,
    stim_info_to_plot=None,
    box_annotations=None,
):
    """Render stimulus text/boxes and (optionally) fixation traces with matplotlib.

    Args:
        dffix: Fixation dataframe with x/y and optional ``y_{algo}`` columns.
        trial: Trial dict (chars_list, words_list, display_coords, font info).
        algo_choice: Algorithm name or list of names whose corrected y to draw.
        stimulus_prefix: Which stimulus element list to draw boxes for.
        desired_dpi: Figure DPI; figure size is derived from the screen pixels.
        fix_to_plot: Subset of {"Uncorrected Fixations", "Corrected Fixations"}.
        stim_info_to_plot: Subset of {"Words", "Word boxes"}.
        box_annotations: Optional per-box annotation strings.

    Returns:
        (fig, width_px, height_px) with the y axis inverted to screen coords.
    """
    # Avoid mutable default arguments; None means "use the defaults".
    if fix_to_plot is None:
        fix_to_plot = []
    if stim_info_to_plot is None:
        stim_info_to_plot = ["Words", "Word boxes"]
    chars_df = pd.DataFrame(trial["chars_list"]) if "chars_list" in trial else None

    if chars_df is not None:
        font_face, font_size = get_font_and_font_size_from_trial(trial)
        font_size = font_size * 0.65
    else:
        st.warning("No character or word information available to plot")

    if "display_coords" in trial:
        desired_width_in_pixels = trial["display_coords"][2] + 1
        desired_height_in_pixels = trial["display_coords"][3] + 1
    else:
        desired_width_in_pixels = 1920
        desired_height_in_pixels = 1080

    figure_width = desired_width_in_pixels / desired_dpi
    figure_height = desired_height_in_pixels / desired_dpi

    fig = plt.figure(figsize=(figure_width, figure_height), dpi=desired_dpi)
    ax = fig.add_subplot(1, 1, 1)
    # Make the axes fill the whole figure so pixels map 1:1 onto data coords.
    fig.subplots_adjust(bottom=0)
    fig.subplots_adjust(top=1)
    fig.subplots_adjust(right=1)
    fig.subplots_adjust(left=0)
    if "font" in trial and trial["font"] in AVAILABLE_FONTS:
        font_to_use = trial["font"]
    else:
        font_to_use = "DejaVu Sans Mono"
    if "font_size" in trial:
        font_size = trial["font_size"]
    else:
        font_size = 20

    if f"{stimulus_prefix}s_list" in trial:
        add_text_to_ax(
            trial[f"{stimulus_prefix}s_list"],
            ax,
            font_to_use,
            prefix=stimulus_prefix,
            fontsize=font_size / 3.89,
            plot_text=False,
            plot_boxes=True if "Word boxes" in stim_info_to_plot else False,
            box_annotations=box_annotations,
        )

    if "chars_list" in trial:
        add_text_to_ax(
            trial["chars_list"],
            ax,
            font_to_use,
            prefix="char",
            fontsize=font_size / 3.89,
            plot_text=True if "Words" in stim_info_to_plot else False,
            plot_boxes=False,
            box_annotations=None,
        )

    if "Uncorrected Fixations" in fix_to_plot:
        ax.plot(dffix.x, dffix.y, label="Raw fixations", color="blue", alpha=0.6, linewidth=0.6)

        # Arrow heads from each fixation towards its successor.
        x0 = dffix.x.iloc[range(len(dffix.x) - 1)].values
        x1 = dffix.x.iloc[range(1, len(dffix.x))].values
        y0 = dffix.y.iloc[range(len(dffix.y) - 1)].values
        y1 = dffix.y.iloc[range(1, len(dffix.y))].values
        xpos = x0
        ypos = y0
        xdir = x1 - x0
        ydir = y1 - y0
        for X, Y, dX, dY in zip(xpos, ypos, xdir, ydir):
            ax.annotate(
                "",
                xytext=(X, Y),
                xy=(X + 0.001 * dX, Y + 0.001 * dY),
                arrowprops=dict(arrowstyle="fancy", color="blue"),
                size=8,
                alpha=0.3,
            )
    if "Corrected Fixations" in fix_to_plot:
        if isinstance(algo_choice, list):
            algo_choices = algo_choice
            repeats = range(len(algo_choice))
        else:
            algo_choices = [algo_choice]
            repeats = range(1)
        for algoIdx in repeats:
            algo_choice = algo_choices[algoIdx]
            if f"y_{algo_choice}" in dffix.columns:
                ax.plot(
                    dffix.x,
                    dffix.loc[:, f"y_{algo_choice}"],
                    # BUG FIX: this trace was mislabelled "Raw fixations";
                    # label it by the algorithm like the plotly version does.
                    label=f"{algo_choice} corrected",
                    color=COLORS[algoIdx],
                    alpha=0.6,
                    linewidth=0.6,
                )

                x0 = dffix.x.iloc[range(len(dffix.x) - 1)].values
                x1 = dffix.x.iloc[range(1, len(dffix.x))].values
                y0 = dffix.loc[:, f"y_{algo_choice}"].iloc[range(len(dffix.loc[:, f"y_{algo_choice}"]) - 1)].values
                y1 = dffix.loc[:, f"y_{algo_choice}"].iloc[range(1, len(dffix.loc[:, f"y_{algo_choice}"]))].values
                xpos = x0
                ypos = y0
                xdir = x1 - x0
                ydir = y1 - y0
                for X, Y, dX, dY in zip(xpos, ypos, xdir, ydir):
                    ax.annotate(
                        "",
                        xytext=(X, Y),
                        xy=(X + 0.001 * dX, Y + 0.001 * dY),
                        arrowprops=dict(arrowstyle="fancy", color=COLORS[algoIdx]),
                        size=8,
                        alpha=0.3,
                    )

    ax.set_xlim((0, desired_width_in_pixels))
    ax.set_ylim((0, desired_height_in_pixels))
    ax.invert_yaxis()  # screen coordinates: y grows downwards

    return fig, desired_width_in_pixels, desired_height_in_pixels
1729
+
1730
+
1731
def plotly_plot_with_image(
    dffix,
    trial,
    algo_choice,
    to_plot_list=["Uncorrected Fixations", "Words", "corrected fixations", "Word boxes"],
    scale_factor=0.5,
):
    """Build an interactive plotly figure with the stimulus as a background image.

    Renders the stimulus via ``matplotlib_plot_df`` to a temp PNG, embeds it
    as a layout image, then overlays raw and/or corrected fixation traces.

    Args:
        dffix: Fixation dataframe (x, y, duration, optional y_{algo} columns).
        trial: Trial dict for the stimulus rendering.
        algo_choice: Algorithm name or list of names for corrected overlays.
        to_plot_list: Which elements to include.
        scale_factor: Uniform scale applied to pixel coordinates and sizes.

    Returns:
        The plotly Figure.
    """
    # Render the stimulus only (no fixations) to an image for the background.
    fig, img_width, img_height = matplotlib_plot_df(
        dffix, trial, algo_choice, desired_dpi=300, fix_to_plot=[], stim_info_to_plot=to_plot_list
    )
    fig.savefig(TEMP_FIGURE_STIMULUS_PATH)
    fig = go.Figure()
    # Invisible scatter pins the axis ranges so the image is not distorted.
    fig.add_trace(
        go.Scatter(
            x=[0, img_width * scale_factor],
            y=[img_height * scale_factor, 0],
            mode="markers",
            marker_opacity=0,
            name="scale_helper",
        )
    )

    fig.update_xaxes(visible=False, range=[0, img_width * scale_factor])

    fig.update_yaxes(
        visible=False,
        range=[img_height * scale_factor, 0],  # inverted: screen coordinates
        scaleanchor="x",
    )
    if "Words" in to_plot_list or "Word boxes" in to_plot_list:
        imsource = Image.open(str(TEMP_FIGURE_STIMULUS_PATH))
        fig.add_layout_image(
            dict(
                x=0,
                sizex=img_width * scale_factor,
                y=0,
                sizey=img_height * scale_factor,
                xref="x",
                yref="y",
                opacity=1.0,
                layer="below",
                sizing="stretch",
                source=imsource,
            )
        )

    if "Uncorrected Fixations" in to_plot_list:
        # Marker size encodes fixation duration, squashed through a sigmoid
        # so outliers do not dominate.
        duration_scaled = dffix.duration - dffix.duration.min()
        duration_scaled = ((duration_scaled / duration_scaled.max()) - 0.5) * 3
        duration = sigmoid(duration_scaled) * 50 * scale_factor
        fig.add_trace(
            go.Scatter(
                x=dffix.x * scale_factor,
                y=dffix.y * scale_factor,
                mode="markers+lines+text",
                name="Raw fixations",
                marker=dict(
                    color=COLORS[-1],
                    symbol="arrow",
                    size=duration.values,
                    angleref="previous",
                    line=dict(color="black", width=duration.values / 10),
                ),
                line_width=2 * scale_factor,
                text=np.arange(len(dffix.x)),  # fixation index labels
                textposition="middle right",
                textfont=dict(
                    family="sans serif",
                    size=18 * scale_factor,
                ),
                hoverinfo="text+x+y",
                opacity=0.9,
            )
        )

    if "Corrected Fixations" in to_plot_list:
        if isinstance(algo_choice, list):
            algo_choices = algo_choice
            repeats = range(len(algo_choice))
        else:
            algo_choices = [algo_choice]
            repeats = range(1)
        for algoIdx in repeats:
            algo_choice = algo_choices[algoIdx]
            if f"y_{algo_choice}" in dffix.columns:
                fig.add_trace(
                    go.Scatter(
                        x=dffix.x * scale_factor,
                        y=dffix.loc[:, f"y_{algo_choice}"] * scale_factor,
                        mode="markers",
                        name=f"{algo_choice} corrected",
                        marker_color=COLORS[algoIdx],
                        marker_size=10 * scale_factor,
                        hoverinfo="text+x+y",
                        opacity=0.75,
                    )
                )

    fig.update_layout(
        plot_bgcolor=None,
        width=img_width * scale_factor,
        height=img_height * scale_factor,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=0.8),
    )

    # Hide the helper trace from the legend.
    for trace in fig["data"]:
        if trace["name"] == "scale_helper":
            trace["showlegend"] = False
    return fig
1841
+
1842
+
1843
def plot_y_corr(dffix, algo_choice, margin=dict(t=40, l=10, r=10, b=1)):
    """Plot the per-fixation y-correction applied by each algorithm.

    Args:
        dffix: Fixation dataframe (or ``{"value": df}`` wrapper) with
            ``y_{algo}_correction`` columns.
        algo_choice: Algorithm name or list of names to plot.
        margin: Plotly layout margin (read-only; safe as a default).

    Returns:
        A plotly Figure; empty (layout only) when no correction column exists.
    """
    # BUG FIX: the original accessed dffix.x before unwrapping a possible
    # {"value": df} dict (AttributeError) and unwrapped twice; unwrap first.
    if isinstance(dffix, dict):
        dffix = dffix["value"]
    num_datapoints = len(dffix.x)

    layout = dict(
        plot_bgcolor="white",
        autosize=True,
        margin=margin,
        xaxis=dict(
            title="Fixation Index",
            linecolor="black",
            range=[-1, num_datapoints + 1],
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        yaxis=dict(
            title="y correction",
            side="left",
            linecolor="black",
            showgrid=False,
            mirror="all",
            showline=True,
        ),
        legend=dict(orientation="v", yanchor="middle", y=0.95, xanchor="left", x=1.05),
    )
    algo_string = algo_choice[0] if isinstance(algo_choice, list) else algo_choice
    if f"y_{algo_string}_correction" not in dffix.columns:
        st.session_state["logger"].warning("No correction column found in dataframe")
        return go.Figure(layout=layout)

    fig = go.Figure(layout=layout)

    if isinstance(algo_choice, list):
        algo_choices = algo_choice
        repeats = range(len(algo_choice))
    else:
        algo_choices = [algo_choice]
        repeats = range(1)
    for algoIdx in repeats:
        algo_choice = algo_choices[algoIdx]
        fig.add_trace(
            go.Scatter(
                x=np.arange(num_datapoints),
                y=dffix.loc[:, f"y_{algo_choice}_correction"],
                mode="markers",
                name=f"{algo_choice} y correction",
                marker_color=COLORS[algoIdx],
                marker_size=3,
                showlegend=True,
            )
        )
    fig.update_yaxes(zeroline=True, zerolinewidth=1, zerolinecolor="black")

    return fig
1901
+
1902
+
1903
def download_example_ascs(EXAMPLES_FOLDER, EXAMPLES_ASC_ZIP_FILENAME, OSF_DOWNLAOD_LINK, EXAMPLES_FOLDER_PATH):
    """Fetch and unpack the bundled example .asc files.

    Downloads the zip from OSF when it is not present, extracts it when the
    expected four .asc files are missing, and returns the list of .asc paths
    found afterwards.
    """
    if not os.path.isdir(EXAMPLES_FOLDER):
        os.mkdir(EXAMPLES_FOLDER)

    if not os.path.exists(EXAMPLES_ASC_ZIP_FILENAME):
        download_url(OSF_DOWNLAOD_LINK, EXAMPLES_ASC_ZIP_FILENAME)
        # os.system(f'''wget -O {EXAMPLES_ASC_ZIP_FILENAME} -c --read-timeout=5 --tries=0 "{OSF_DOWNLAOD_LINK}"''')

    # Only re-extract when the zip exists and the folder lacks the 4 files.
    if os.path.exists(EXAMPLES_ASC_ZIP_FILENAME) and EXAMPLES_FOLDER_PATH.exists():
        found_asc_files = list(EXAMPLES_FOLDER_PATH.glob("*.asc"))
        if len(found_asc_files) != 4:
            try:
                with zipfile.ZipFile(EXAMPLES_ASC_ZIP_FILENAME, "r") as zip_ref:
                    zip_ref.extractall(EXAMPLES_FOLDER)
            except Exception as e:
                st.session_state["logger"].warning(e)
                st.session_state["logger"].warning(f"Extracting {EXAMPLES_ASC_ZIP_FILENAME} failed")

    return list(EXAMPLES_FOLDER_PATH.glob("*.asc"))
1924
+
1925
+
1926
def process_trial_choice_single_csv(trial, algo_choice, file=None):
    """Prepare a single-CSV trial for correction and optionally correct it.

    Ensures the trial has its fixation dataframe, plot-file path and character
    metadata attached, then runs ``correct_df`` when an algorithm is given.

    Args:
        trial: Trial dict (mutated in place: plot_file, fname, chars_df, ...).
        algo_choice: Algorithm name(s) for ``correct_df``; None skips correction.
        file: Uploaded file object; falls back to session state when None.

    Returns:
        (dffix, trial, dpi, screen_res, font, font_size).
    """
    trial_id = trial["trial_id"]
    if "dffix" in trial:
        dffix = trial["dffix"]
    else:
        if file is None:
            file = st.session_state["single_csv_file"]
        # Path where the channel-separated stimulus image will be stored.
        trial["plot_file"] = str(PLOTS_FOLDER.joinpath(f"{file.name}_{trial_id}_2ndInput_chars_channel_sep.png"))
        trial["fname"] = str(file.name)
        dffix = trial["dffix"] = st.session_state["trials_by_ids_single_csv"][trial_id]["dffix"]

    font, font_size, dpi, screen_res = get_plot_props(trial, AVAILABLE_FONTS)
    chars_df = pd.DataFrame(trial["chars_list"])
    trial["chars_df"] = chars_df.to_dict()
    # Sorted unique line centers; used to snap predictions to lines.
    trial["y_char_unique"] = list(chars_df.char_y_center.sort_values().unique())
    if algo_choice is not None:
        dffix, _ = correct_df(dffix, algo_choice, trial)
    return dffix, trial, dpi, screen_res, font, font_size
1944
+
1945
+
1946
def add_default_font_and_character_props_to_state(trial):
    """Derive default layout properties (line height, text origin, font) for a trial.

    Args:
        trial: Trial dict with a non-empty ``chars_list``.

    Returns:
        (y_diff, x_txt_start, y_txt_start, font_face, font_size, line_height)
        where ``line_height`` equals ``y_diff``, the smallest gap between
        distinct character-line y-centers.
    """
    chars_list = trial["chars_list"]
    chars_df = pd.DataFrame(trial["chars_list"])
    line_diffs = np.diff(chars_df.char_y_center.unique())
    y_diffs = np.unique(line_diffs)
    if len(y_diffs) == 1:
        y_diff = y_diffs[0]
    elif len(y_diffs) == 0:
        # BUG FIX: single-line stimuli yield an empty diff array and
        # np.min([]) raised ValueError; fall back to the same default line
        # height used by set_font_from_chars_list.
        y_diff = 1 / 0.333 * 18
    else:
        y_diff = np.min(y_diffs)
    y_diff = round(y_diff * 2) / 2  # snap to half pixels
    x_txt_start = chars_list[0]["char_xmin"]
    y_txt_start = chars_list[0]["char_y_center"]

    font_face, font_size = get_font_and_font_size_from_trial(trial)

    line_height = y_diff
    return y_diff, x_txt_start, y_txt_start, font_face, font_size, line_height
1963
+
1964
def get_all_measures(trial, dffix, prefix, use_corrected_fixations=True, correction_algo="warp"):
    """Compute all word/character-level reading measures for one trial.

    Runs every analysis function from ``anf`` (fixation counts, durations,
    landing positions/distances, regressions, ...) and joins them into one
    dataframe indexed by stimulus element.

    Args:
        trial: Trial dict passed through to the measure functions.
        dffix: Fixation dataframe.
        prefix: Stimulus element type the measures are computed over
            (e.g. "word" or "char"); also names the index column.
        use_corrected_fixations: When True, replace y with the corrected
            ``y_{correction_algo}`` column on a copy before measuring.
        correction_algo: Which algorithm's corrected y values to use.

    Returns:
        DataFrame with one row per stimulus element: ``{prefix}_num``,
        ``{prefix}`` (the element text) and one column per measure.
    """
    if use_corrected_fixations:
        # Deep copy so the caller's dataframe keeps its raw y values.
        dffix_copy = copy.deepcopy(dffix)
        dffix_copy["y"] = dffix_copy[f"y_{correction_algo}"]
    else:
        dffix_copy = dffix
    # Each helper returns a dataframe carrying f"{prefix}_index" plus its
    # measure column; set the shared index so they can be concatenated.
    initial_landing_position_own_vals = anf.initial_landing_position_own(trial, dffix_copy, prefix).set_index(
        f"{prefix}_index"
    )
    second_pass_duration_own_vals = anf.second_pass_duration_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index")
    number_of_fixations_own_vals = anf.number_of_fixations_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index")
    initial_fixation_duration_own_vals = anf.initial_fixation_duration_own(trial, dffix_copy, prefix).set_index(
        f"{prefix}_index"
    )
    first_of_many_duration_own_vals = anf.first_of_many_duration_own(trial, dffix_copy, prefix).set_index(
        f"{prefix}_index"
    )
    total_fixation_duration_own_vals = anf.total_fixation_duration_own(trial, dffix_copy, prefix).set_index(
        f"{prefix}_index"
    )
    gaze_duration_own_vals = anf.gaze_duration_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index")
    go_past_duration_own_vals = anf.go_past_duration_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index")
    initial_landing_distance_own_vals = anf.initial_landing_distance_own(trial, dffix_copy, prefix).set_index(
        f"{prefix}_index"
    )
    landing_distances_own_vals = anf.landing_distances_own(trial, dffix_copy, prefix).set_index(f"{prefix}_index")
    number_of_regressions_in_own_vals = anf.number_of_regressions_in_own(trial, dffix_copy, prefix).set_index(
        f"{prefix}_index"
    )
    # Drop the duplicated element-text column from each frame before the
    # column-wise concat; it is re-attached once below.
    own_measure_df = pd.concat(
        [
            df.drop(prefix, axis=1)
            for df in [
                number_of_fixations_own_vals,
                initial_fixation_duration_own_vals,
                first_of_many_duration_own_vals,
                total_fixation_duration_own_vals,
                gaze_duration_own_vals,
                go_past_duration_own_vals,
                second_pass_duration_own_vals,
                initial_landing_position_own_vals,
                initial_landing_distance_own_vals,
                landing_distances_own_vals,
                number_of_regressions_in_own_vals,
            ]
        ],
        axis=1,
    )
    own_measure_df[prefix] = number_of_fixations_own_vals[prefix]
    # Move the element text to the front and prepend a running element number.
    first_column = own_measure_df.pop(prefix)
    own_measure_df.insert(0, prefix, first_column)
    own_measure_df.insert(0, f"{prefix}_num", np.arange((own_measure_df.shape[0])))
    return own_measure_df