No support for num_beams >1?

#15
by midev0 - opened

I get an IndexError whenever I pass num_beams greater than 1 to model.generate(). Otherwise, the model works great; I have narrowed the problem down to that one parameter.

```
IndexError                                Traceback (most recent call last)

in <cell line: 92>()
     97 conversation_history.append({"role": "user", "content": user_input})
     98
---> 99 response = generate_response(conversation_history)
    100
    101

5 frames

/usr/local/lib/python3.10/dist-packages/transformers/cache_utils.py in reorder_cache(self, beam_idx)
     82 for layer_idx in range(len(self.key_cache)):
     83 device = self.key_cache[layer_idx].device
---> 84 self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
     85 device = self.value_cache[layer_idx].device
     86 self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

IndexError: index out of range in self
```
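
One possible reading of this traceback, offered as an assumption and not confirmed in this thread: beam search tracks batch_size * num_beams rows, and `reorder_cache` gathers cache rows with `index_select(0, beam_idx)`. If a layer's key/value cache has fewer rows along dim 0 than the largest value in `beam_idx` (for example, if the model's code builds or keeps a cache for the un-expanded batch), `index_select` fails with exactly this message. The pure-Python sketch below mimics that loop with lists. The toy `reorder_cache` and the helper `make_cache` are hypothetical stand-ins, not the transformers implementation, and the sizes (1 sequence, 4 beams) and `beam_idx` values are made up for illustration.

```python
# Toy model of the reorder step from the traceback (hypothetical, pure Python).
# Each layer's cache is a list whose entries are the rows along dim 0
# (batch_size * num_beams rows when beam search is active).

def reorder_cache(key_cache, beam_idx):
    """Mimic tensor.index_select(0, beam_idx) for every layer's cache."""
    for layer_idx in range(len(key_cache)):
        layer = key_cache[layer_idx]
        for i in beam_idx:
            # index_select rejects any index outside [0, num_rows).
            if not 0 <= i < len(layer):
                raise IndexError("index out of range in self")
        key_cache[layer_idx] = [layer[i] for i in beam_idx]
    return key_cache


def make_cache(num_layers, num_rows):
    # Hypothetical helper: rows are labeled strings so the reordering is visible.
    return [[f"L{l}-row{r}" for r in range(num_rows)] for l in range(num_layers)]


batch_size, num_beams = 1, 4
beam_idx = [0, 0, 2, 3]  # parent beam chosen for each of the 4 expanded rows

# Cache sized for the expanded batch (1 * 4 rows): reordering succeeds.
ok = reorder_cache(make_cache(2, batch_size * num_beams), beam_idx)
print(ok[0])  # ['L0-row0', 'L0-row0', 'L0-row2', 'L0-row3']

# Cache sized for the un-expanded batch (1 row): index 2 is out of range.
try:
    reorder_cache(make_cache(2, batch_size), beam_idx)
except IndexError as err:
    print("IndexError:", err)  # IndexError: index out of range in self
```

If this reading is right, num_beams=1 never reaches this loop with an out-of-range index, which would match the observation that only num_beams > 1 fails; checking the cache's dim-0 size against batch_size * num_beams before the failing call is one way to test the hypothesis.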
