Add recurrent gemma #856

Merged: awni merged 2 commits into main from recurrent_gemma on Jul 7, 2024
Conversation

@awni (Member) commented Jun 28, 2024

A port of recurrent Gemma to spice things up a little. Runs pretty well on an M1 Max.

Convert:

python -m mlx_lm.convert --hf-path google/recurrentgemma-2b-it -q

Run:

python -m mlx_lm.generate --model mlx_model --prompt "Write a very long story about Einstein" --max-tokens 1000 --temp 0.0

Generates:

==========
Prompt: <bos><start_of_turn>user
Write a very long story about Einstein<end_of_turn>
<start_of_turn>model

Albert Einstein, ...
Einstein's life and work continue to shape the world we live in today. His theories of relativity have changed the way we understand the universe, and his ideas have inspired generations of scientists and thinkers. He was a man of contradictions, a brilliant mind with a mischievous twinkle in his eye, and a deep love for G-d. And he will always be remembered as one of the greatest minds the world has ever known.
==========
Prompt: 144.822 tokens-per-sec
Generation: 99.815 tokens-per-sec
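
For reference, the same convert-then-generate flow through the Python API (a sketch; exact generate() keyword names vary across mlx_lm versions, so treat the arguments below as assumptions):

```python
# Sketch of the CLI flow above via mlx_lm's Python API
# (keyword names may differ between mlx_lm versions).
from mlx_lm import load, generate

model, tokenizer = load("mlx_model")  # the quantized model written by convert
text = generate(
    model,
    tokenizer,
    prompt="Write a very long story about Einstein",
    max_tokens=1000,
    verbose=True,  # prints the tokens-per-sec stats shown above
)
```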

awni requested a review from angeloskath on June 28, 2024
@angeloskath (Member) left a comment:
This is very nice and refreshing!

    # TODO consider using rotating buffer here
    # especially for very long generations
    def _update(x, v):
        t = x.shape[2] - self.window_size
@angeloskath (Member) commented on this snippet:
Should that also have `- v.shape[2]`, i.e. `x.shape[2] - self.window_size - v.shape[2]`, so that the returned value has size `self.window_size` at axis 2?

@awni (Member, Author) replied:
A couple of things: the context size is actually `window_size + 1`, which is why there is no `+ 1` in there. Also, `v.shape[2]` is always 1, so it doesn't matter too much. But if `v.shape[2] > 1`, then my understanding is that we don't want to truncate the cache (which would drop some tokens we are still allowed to attend to) but rather rely on the attention mask. That said, I haven't tested this code with `v.shape[2] > 1`, and I think it's probably broken elsewhere (likely in the conv cache).
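
Concretely, the sizing works out like this (a standalone sketch with hypothetical names; the final concatenate step is inferred from the discussion, not copied from the PR):

```python
# Windowed cache update: keep the trailing window_size entries of the
# old cache along the token axis, then append the new token, giving an
# effective context of window_size + 1 entries.
import mlx.core as mx

window_size = 4

def _update(x, v):
    t = x.shape[2] - window_size
    return mx.concatenate([x[..., t:, :], v], axis=2)

cache = mx.zeros((1, 1, 4, 8))    # (batch, heads, tokens, head_dim)
new = mx.ones((1, 1, 1, 8))       # one freshly generated token
print(_update(cache, new).shape)  # (1, 1, 5, 8), i.e. window_size + 1
```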

awni merged commit 20e221f into main on Jul 7, 2024
awni deleted the recurrent_gemma branch on July 7, 2024
@lin72h commented Jul 8, 2024

Awesome work! Just curious, does the newly released 9B recurrent gemma (recurrentgemma-9b) also work?

@awni (Member, Author) commented Jul 8, 2024

Yes! It should work, but I don't think I tried it, so it would be good to confirm.

@chahn commented Jul 8, 2024

Thanks a lot @awni for all your work!

This...

python -m mlx_lm.convert --hf-path google/recurrentgemma-2b-it -q
python -m mlx_lm.generate --model mlx_model --prompt "Write a very long story about Einstein" --max-tokens 1000 --temp 0.0

... is working great, as expected.

Am I missing something using google/recurrentgemma-9b-it?
Unfortunately

python -m mlx_lm.convert --hf-path google/recurrentgemma-9b-it -q

gives me the following error:

[INFO] Loading
Fetching 11 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 141092.80it/s]
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/chahn/.venv/lib/python3.12/site-packages/mlx_lm/convert.py", line 62, in <module>
    main()
  File "/Users/chahn/.venv/lib/python3.12/site-packages/mlx_lm/convert.py", line 58, in main
    convert(**vars(args))
  File "/Users/chahn/.venv/lib/python3.12/site-packages/mlx_lm/utils.py", line 661, in convert
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/chahn/.venv/lib/python3.12/site-packages/mlx_lm/utils.py", line 467, in fetch_from_hub
    model = load_model(model_path, lazy)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/chahn/.venv/lib/python3.12/site-packages/mlx_lm/utils.py", line 397, in load_model
    model_args = model_args_class.from_dict(config)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/chahn/.venv/lib/python3.12/site-packages/mlx_lm/models/base.py", line 50, in from_dict
    return cls(
           ^^^^
TypeError: ModelArgs.__init__() missing 2 required positional arguments: 'embeddings_scale_by_sqrt_dim' and '_block_types'

@awni (Member, Author) commented Jul 8, 2024

Indeed, you're not missing anything. The 9B config is slightly different, so we'll need a patch for that.
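
For illustration, the missing-argument error above suggests a patch along these lines (a hypothetical sketch: the field names come from the traceback, but the default values are assumptions and the real fix may differ):

```python
# Hypothetical sketch: give the two fields the 9B config omits default
# values so ModelArgs.from_dict() tolerates both config layouts.
# The defaults below are assumptions, not the actual patch.
from dataclasses import dataclass, field
from typing import List

@dataclass
class ModelArgs:
    embeddings_scale_by_sqrt_dim: bool = True
    _block_types: List[str] = field(
        # Griffin-style pattern: two recurrent blocks, then one attention block.
        default_factory=lambda: ["recurrent", "recurrent", "attention"]
    )

print(ModelArgs())  # constructs even when a config omits these keys
```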

@awni (Member, Author) commented Jul 8, 2024

OK, I added some fixes in #877.

@lin72h commented Jul 8, 2024

Nice job! 🍻

@chahn commented Jul 9, 2024

Thanks a lot for the fix @awni!
