Conversation
angeloskath left a comment:
This is very nice and refreshing!
    # TODO consider using rotating buffer here
    # especially for very long generations
    def _update(x, v):
        t = x.shape[2] - self.window_size
Should that also have - v.shape[2], i.e. x.shape[2] - self.window_size - v.shape[2], so that the returned value is of size self.window_size at axis 2?
A couple of things: the context size is actually window_size + 1, which is why there is no +1 in there. Also, v.shape[2] is always 1, so it doesn't matter too much. But if v.shape[2] > 1, then my understanding is that we don't want to truncate the cache, which would drop some tokens that we are allowed to attend to, but rather rely on the attention mask. That said, I haven't tested this code with v.shape[2] > 1 and I think it's probably broken elsewhere (likely in the conv cache).
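For readers following along, here is a minimal sketch of the sliding-window cache update being discussed, assuming single-token decoding (v.shape[2] == 1); it is an illustration of the idea, not the PR's exact code:

    import mlx.core as mx

    def sliding_window_update(x, v, window_size):
        # x: cached keys or values, shape (B, n_heads, T, head_dim)
        # v: the new entry to append, shape (B, n_heads, 1, head_dim)
        # Drop all but the most recent `window_size` cached positions, then
        # append v, so the cache never holds more than window_size + 1
        # entries (the effective context size mentioned above).
        t = x.shape[2] - window_size
        if t > 0:
            x = x[:, :, t:, :]
        return mx.concatenate([x, v], axis=2)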
Awesome work! Just curious, does the newly released 9b recurrent Gemma also work?
Yes! It should work, but I don't think I tried it so it would be good to confirm.
Thanks a lot @awni for all your work! This ... is working great and as expected. Am I missing something? Using the 9B gives me the following error:
Indeed you're not missing anything. The 9B config is slightly different, so we'll need a patch for that.
OK, I added some fixes to #877.
Nice job! 🍻
Thanks a lot for the fix @awni! |
A port of recurrent Gemma to spice things up a little. Runs pretty well on an M1 Max.
Convert:
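As a rough guide, conversion with mlx_lm typically looks something like the following; the Hugging Face repo name and the quantization flag here are assumptions, not necessarily the exact command from the PR:

    python -m mlx_lm.convert --hf-path google/recurrentgemma-2b-it -q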
Run:
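Generation can then be invoked along these lines, where the model path and prompt are placeholders:

    python -m mlx_lm.generate --model mlx_model --prompt "Write a story about Einstein" --max-tokens 256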
Generates: