jpohhhh commited on
Commit
a17e571
1 Parent(s): a5e5b1d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +15 -12
handler.py CHANGED
@@ -8,18 +8,21 @@ import time
8
  import os
9
  import torch
10
 
11
- def max_pooling(model_output):
12
- # Get dimensions
13
- _, Z, Y = model_output.shape
14
- # Initialize an empty list with length Y (384 in your case)
15
- output_array = [0] * Y
16
- # Loop over secondary arrays (Z)
17
- for i in range(Z):
18
- # Loop over values in innermost arrays (Y)
19
- for j in range(Y):
20
- # If value is greater than current max, update max
21
- if model_output[0][i][j] > output_array[j]:
22
- output_array[j] = model_output[0][i][j]
 
 
 
23
  return output_array
24
 
25
  #Mean Pooling - Take attention mask into account for correct averaging
 
8
  import os
9
  import torch
10
 
11
+ def max_pooling(model_output):
12
+ # Get dimensions
13
+ Z, Y = len(model_output[0]), len(model_output[0][0])
14
+
15
+ # Initialize an empty list with length Y (384 in your case)
16
+ output_array = [0] * Y
17
+
18
+ # Loop over secondary arrays (Z)
19
+ for i in range(Z):
20
+ # Loop over values in innermost arrays (Y)
21
+ for j in range(Y):
22
+ # If value is greater than current max, update max
23
+ if model_output[0][i][j] > output_array[j]:
24
+ output_array[j] = model_output[0][i][j]
25
+
26
  return output_array
27
 
28
  #Mean Pooling - Take attention mask into account for correct averaging