Conversation

@jbarrow
Contributor

@jbarrow jbarrow commented Dec 14, 2023

Currently this is a draft -- the generate function does not work and has dropped all caching. I'll get back to this tomorrow before/after work, but feel free to modify and make any necessary changes!

The Phi-2 transformer model is really interesting; it required a few modifications:

  • NewGELUActivation function
  • ParallelBlock, which uses a single LayerNorm and combines the residual, attention outputs, and feed-forward outputs at the end (see the sketch after this list)
  • RoPE positional embeddings
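
As a rough sketch of the ParallelBlock structure (names, dims, and the use of nn.MultiHeadAttention here are my own illustrative assumptions, not the PR's code):

import mlx.nn as nn

class ParallelBlock(nn.Module):
    # One pre-LayerNorm feeds both branches; the residual is added once at the end.
    def __init__(self, dims: int, num_heads: int, mlp_dims: int):
        super().__init__()
        self.ln = nn.LayerNorm(dims)
        self.attn = nn.MultiHeadAttention(dims, num_heads)
        self.fc1 = nn.Linear(dims, mlp_dims)
        self.fc2 = nn.Linear(mlp_dims, dims)

    def __call__(self, x, mask=None):
        h = self.ln(x)                                  # single shared LayerNorm
        attn_out = self.attn(h, h, h, mask)             # attention branch
        ff_out = self.fc2(nn.gelu_approx(self.fc1(h)))  # feed-forward branch
        return x + attn_out + ff_out                    # combined at the end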

Loading the model requires breaking the fused Wqkv matrices into separate Wq, Wk, and Wv matrices (though it's presumably possible to reimplement the attention without doing that). Loading the weights also requires einops to be installed:

pip install einops
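
For the Wqkv split, something like the following works (a hypothetical sketch: it assumes the fused weight is simply [q; k; v] stacked along the output dimension; if the checkpoint interleaves per head instead, that's where an einops rearrange comes in):

import numpy as np

def split_wqkv(wqkv: np.ndarray):
    # Fused weight of shape (3 * dims, dims) -> three (dims, dims) matrices.
    dims = wqkv.shape[0] // 3
    return wqkv[:dims], wqkv[dims : 2 * dims], wqkv[2 * dims :]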

I've tested that the forward pass of the model lines up with the 🤗 implementation. But for generate, I need to add the kv-caching (which might mean removing the TransformerDecoder implementation altogether and, in turn, updating the convert.py script).
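
The kind of per-layer cache I have in mind is sketched below (single head, no causal mask, with my own illustrative names rather than the eventual API):

import mlx.core as mx

def cached_attention_step(x, wq, wk, wv, cache=None):
    # x: (batch, new_tokens, dims); wq/wk/wv: (dims, dims)
    q, k, v = x @ wq.T, x @ wk.T, x @ wv.T
    if cache is not None:
        k_prev, v_prev = cache
        k = mx.concatenate([k_prev, k], axis=1)  # append along the time axis
        v = mx.concatenate([v_prev, v], axis=1)
    scores = (q * q.shape[-1] ** -0.5) @ k.transpose(0, 2, 1)
    out = mx.softmax(scores, axis=-1) @ v
    return out, (k, v)  # hand the updated cache to the next decoding step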

I think it's close to being there, but it will require (a) some care to get generation working right, and (b) some care to get fast inference on a MacBook. Very excited to run a good model locally on a MacBook Air. 😄

@altaic altaic mentioned this pull request Dec 14, 2023
@awni
Member

awni commented Dec 14, 2023

@jbarrow If you set your permissions, I can push some updates in a commit to your fork.

Basically I added the cache and cleaned up the code a bit (used our built-in "new" GELU)

I'm still not having success getting reasonable outputs from the model, even in fp32. Also, I wasn't able to reproduce the final-layer outputs you are getting in the .txt files. Maybe you could share more details on how you got those...

Still seems like there is a bug somewhere...

@jbarrow
Contributor Author

jbarrow commented Dec 14, 2023

I have "Allow edits by maintainers" checked -- is there any other permission I'm missing?

To get the outputs, I just ran the __call__ (rather than generate):

print(model(**tokens))

As for the "new" GELU, is it the gelu_approx in mlx.nn? I just tried swapping that in, and I get slightly different outputs from the original, but still comparable. Going to see if I can continue to debug. 😄
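
For a quick check that the two are the same function up to numerics (my own sketch, assuming HF's NewGELUActivation is the tanh formula below and that gelu_approx is the corresponding mlx.nn function):

import math
import mlx.core as mx
import mlx.nn as nn

def new_gelu(x):
    # HF NewGELUActivation: tanh-based GELU approximation
    return 0.5 * x * (1.0 + mx.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

x = mx.linspace(-4, 4, 9)
print(new_gelu(x) - nn.gelu_approx(x))  # expect small but nonzero differences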

@awni
Member

awni commented Dec 14, 2023

I have "Allow edits by maintainers" checked -- is there any other permission I'm missing?

That's what I meant. I must have been typing the wrong thing last night 😪

Hey, I just pushed the commit; it switches to GELU(approx="precise"), which is the same function up to numerics. Sorry if the commit disturbs your setup, and let me know how the debugging is going! I can take another look in about half an hour.

@jbarrow
Contributor Author

jbarrow commented Dec 14, 2023

So, I believe I've identified the source of the error, but the correction will require a bit more time than I have this morning (I'll get back to it this evening). I believe the rotary embedding implementations are different between MLX and the Phi-2 repo. The Phi-2 implementation is here: https://huggingface.co/microsoft/phi-2/blob/main/modeling_phi.py#L171

I was looking at the query values from the attention heads in the first attention layer, and the differences pop up between the pre- and post-rotary-embedding outputs. The outputs below are the first 5 values at each of the 23 token positions for the input prompt.

-- BEFORE ROTARY EMBEDDING --

MLX: Attention Layer 0, Head 0, Query Values 0..5:

[[-0.8564 -2.682  -2.521   3.49    1.593 ]
 [-0.2073 -0.986  -1.157   1.926   1.606 ]
 [-2.932  -4.746  -3.71    4.227   2.584 ]
 [-0.3882 -3.2    -2.791   2.688   2.28  ]
 [-0.6265 -2.521  -3.512   4.066   2.36  ]
 [ 0.7383 -4.484  -3.68    4.203   2.75  ]
 [ 0.2036 -0.5347 -1.796   1.586   1.143 ]
 [ 1.985  -2.527  -2.01    4.617   2.102 ]
 [-0.2627 -0.568  -1.546   2.377   1.574 ]
 [-0.1577 -1.187  -1.31    2.607   1.235 ]
 [ 1.985  -2.527  -2.01    4.617   2.102 ]
 [-0.2627 -0.568  -1.546   2.377   1.574 ]
 [-0.3096 -1.409  -1.444   2.58    1.867 ]
 [-0.5576 -3.98   -1.962   2.695   2.016 ]
 [ 1.015  -1.193  -1.305   1.803   1.536 ]
 [ 0.3665 -2.875  -1.51    2.816   2.21  ]
 [ 0.3376 -1.208  -1.44    2.412   1.972 ]
 [-2.096  -4.07   -2.666   3.996   2.795 ]
 [ 0.8643 -1.999  -1.686   1.97    1.604 ]
 [ 0.3447 -2.65   -1.917   2.197   2.031 ]
 [ 1.985  -2.527  -2.01    4.617   2.102 ]
 [-0.2627 -0.568  -1.546   2.377   1.574 ]
 [-0.1577 -1.187  -1.31    2.607   1.235 ]]

HF: Attention Layer 0, Head 0, Query Values 0..5:

[[-0.856  -2.682  -2.521   3.49    1.593 ]
 [-0.2074 -0.9854 -1.157   1.925   1.606 ]
 [-2.934  -4.746  -3.71    4.227   2.584 ]
 [-0.3882 -3.2    -2.791   2.688   2.28  ]
 [-0.6265 -2.521  -3.512   4.066   2.36  ]
 [ 0.7393 -4.49   -3.68    4.203   2.75  ]
 [ 0.2037 -0.5347 -1.796   1.586   1.144 ]
 [ 1.983  -2.525  -2.01    4.617   2.102 ]
 [-0.2625 -0.567  -1.545   2.377   1.574 ]
 [-0.1578 -1.1875 -1.311   2.607   1.235 ]
 [ 1.983  -2.525  -2.01    4.617   2.102 ]
 [-0.2625 -0.567  -1.545   2.377   1.574 ]
 [-0.31   -1.409  -1.443   2.58    1.866 ]
 [-0.5576 -3.98   -1.963   2.697   2.016 ]
 [ 1.015  -1.193  -1.304   1.802   1.536 ]
 [ 0.3665 -2.873  -1.509   2.816   2.213 ]
 [ 0.3376 -1.208  -1.439   2.412   1.972 ]
 [-2.094  -4.07   -2.666   3.996   2.795 ]
 [ 0.8643 -1.999  -1.685   1.969   1.604 ]
 [ 0.3442 -2.65   -1.917   2.197   2.031 ]
 [ 1.983  -2.525  -2.01    4.617   2.102 ]
 [-0.2625 -0.567  -1.545   2.377   1.574 ]
 [-0.1578 -1.1875 -1.311   2.607   1.235 ]]

-- AFTER ROTARY EMBEDDING --

MLX: Attention Layer 0, Head 0, Query Values 0..5:

[[-0.8564  -2.682   -2.521    3.49     1.593  ]
 [ 0.718   -0.707   -2.186    0.524    1.181  ]
 [ 5.54    -0.6904  -4.16    -3.787   -0.1343 ]
 [ 0.836    3.113    0.1768  -3.871   -1.481  ]
 [-1.498    2.123    3.656   -3.936   -2.201  ]
 [-4.09    -1.98     5.586   -0.11523 -2.727  ]
 [ 0.04614 -0.57     1.488    1.877   -0.8906 ]
 [ 3.158   -0.6006   1.534    4.797   -2.129  ]
 [ 0.6     -0.1771  -1.714    2.26    -0.5825 ]
 [ 0.633    1.016   -2.834    0.6934   0.961  ]
 [-3.04     1.039   -4.414   -2.422    2.14   ]
 [-0.569    0.2603  -0.3008  -2.818    1.957  ]
 [-1.018   -1.023    1.71    -2.41     1.175  ]
 [ 1.167   -3.846    3.332   -0.1279  -3.393  ]
 [ 1.321    0.842    1.62     1.526   -2.125  ]
 [ 1.592    2.422    0.509    3.154   -2.191  ]
 [-0.671    1.059   -1.773    2.18    -0.293  ]
 [-3.338    3.137   -4.797    0.1953   1.09   ]
 [-0.9307  -1.969   -1.676   -1.978    1.661  ]
 [ 0.7383  -2.57     0.3018  -2.9      2.7    ]
 [ 3.12     0.7812   2.816   -4.176    2.178  ]
 [ 0.619    0.0913   2.834   -0.042    1.956  ]
 [ 0.1472   1.1875   2.285    1.814    0.4153 ]]

HF: Attention Layer 0, Head 0, Query Values 0..5:

[[-0.856   -2.682   -2.521    3.49     1.593  ]
 [ 0.5913  -0.793   -0.6685   2.035    1.59   ]
 [ 4.906   -3.598   -1.9375   4.543    2.436  ]
 [ 0.67    -0.661   -1.068    2.78     1.979  ]
 [-2.324    0.2937   1.132    4.168    2.354  ]
 [-4.27     4.168    2.84     3.617    2.467  ]
 [ 0.1172   0.574    1.363    2.162    1.314  ]
 [ 3.46     3.234    5.234    3.672    3.703  ]
 [ 0.6626   0.9077   2.41     1.902    3.158  ]
 [ 0.287    1.382    1.627    0.5327   1.244  ]
 [-3.291   -0.7373   1.905    1.313    3.873  ]
 [-0.6323  -0.4856   0.799    0.5605   3.275  ]
 [-1.244   -1.572    1.026   -1.9      0.2148 ]
 [ 0.917   -2.596   -1.608   -0.1299   0.0722 ]
 [ 1.44     0.888   -1.873   -0.9224   0.04068]
 [ 0.544    0.5786   0.3843  -2.37    -0.5903 ]
 [-0.61     1.149   -1.464   -2.025    0.135  ]
 [-4.1      4.04    -3.69    -3.98    -0.5938 ]
 [ 0.4224   1.35    -2.064   -2.045   -0.0759 ]
 [ 0.5938   0.1451  -2.652   -2.49    -0.207  ]
 [ 3.54     1.36    -1.8     -5.156    2.084  ]
 [ 0.672    0.142   -0.7485  -2.848    1.6875 ]
 [ 0.1548  -0.797   -0.225   -2.295   -0.2349 ]]
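
Worth noting in the dumps: the pre-RoPE values agree to float precision, and the first row still matches after RoPE (position 0 is an identity rotation), which points at a pairing-convention mismatch rather than wrong frequencies. A hedged illustration of the two common conventions (my own code, not taken from either implementation):

import mlx.core as mx

def rope_interleaved(x, cos, sin):
    # Pairs adjacent dims: (x0, x1), (x2, x3), ...
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = mx.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(x.shape)

def rope_half_split(x, cos, sin):
    # Pairs dim i with dim i + d/2: (x0, x_{d/2}), (x1, x_{d/2+1}), ...
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return mx.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

At position 0 (cos = 1, sin = 0) both are the identity, which matches the first rows above; at later positions they diverge unless the pairing matches the one the weights were trained with.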

@jbarrow jbarrow marked this pull request as draft December 14, 2023 14:41
@jbarrow jbarrow marked this pull request as ready for review December 14, 2023 16:27
@awni awni changed the title phi-2 draft Phi-2 Dec 14, 2023
Member

@awni awni left a comment


This is looking really good, and works really well for me, merging!!

@awni awni merged commit 92efa32 into ml-explore:main Dec 14, 2023
@lostmygithubaccount

how long did the conversion of the weights take? it's going slowly for me (the estimate keeps fluctuating, but it's up to 1 hr)

@awni
Member

awni commented Dec 14, 2023

Maybe what's slow is your download time? You might need a faster internet connection :)

The conversion itself should be fast once the model is downloaded (<<1 minute). But let me know if you run into trouble there.

@jbarrow
Contributor Author

jbarrow commented Dec 14, 2023

That might be the download time? It's a 5 GB download from the Hugging Face Hub. Once the weights are cached, conversion takes maybe 20s for me?

@lostmygithubaccount

I have phi-2 downloaded from Hugging Face, and I'm just running python convert.py pointed at the weights. Weird.

my internet is very slow; I'm on Starlink and currently on a Google Meet call that's hammering my network 😂 it's getting there, not a big deal

@lostmygithubaccount

lostmygithubaccount commented Dec 14, 2023

first safetensor partition took 10 minutes, everything else was fast.

it's working! thanks for getting this PR in :)

Blaizzy pushed a commit to Blaizzy/mlx-examples that referenced this pull request Mar 13, 2024