From 88e34a3d49e3ca441f5f10841a4b2b884c265e3f Mon Sep 17 00:00:00 2001 From: Devin Ha Date: Thu, 17 Jul 2025 09:59:01 +0200 Subject: [PATCH] Fix passing of embedding pooling type Pooling type was passed as an enum value (an integer), which will result in an error during init of the server. This fix passes the value as a compatible string instead. --- .../java/de/kherud/llama/ModelParameters.java | 7 +-- .../de/kherud/llama/LlamaEmbeddingsTest.java | 43 +++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) create mode 100644 src/test/java/de/kherud/llama/LlamaEmbeddingsTest.java diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java index 7999295..eb2a841 100644 --- a/src/main/java/de/kherud/llama/ModelParameters.java +++ b/src/main/java/de/kherud/llama/ModelParameters.java @@ -459,7 +459,10 @@ public ModelParameters setJsonSchema(String schema) { * Set pooling type for embeddings (default: model default if unspecified). */ public ModelParameters setPoolingType(PoolingType type) { - parameters.put("--pooling", type.getArgValue()); + if (type != PoolingType.UNSPECIFIED) { + // Don't set if unspecified, as it will use the model's default pooling type + parameters.put("--pooling", type.name().toLowerCase()); + } return this; } @@ -960,5 +963,3 @@ public ModelParameters enableJinja() { } } - - diff --git a/src/test/java/de/kherud/llama/LlamaEmbeddingsTest.java b/src/test/java/de/kherud/llama/LlamaEmbeddingsTest.java new file mode 100644 index 0000000..3a5a89f --- /dev/null +++ b/src/test/java/de/kherud/llama/LlamaEmbeddingsTest.java @@ -0,0 +1,43 @@ +package de.kherud.llama; + +import de.kherud.llama.args.PoolingType; +import org.junit.*; + +import java.lang.management.ManagementFactory; +import java.lang.management.RuntimeMXBean; + +public class LlamaEmbeddingsTest { + + private static final String modelPath = "models/codellama-7b.Q2_K.gguf"; + private static LlamaModel model; + + @BeforeClass + public static void setup() { + // Print PID of the current process to attach with GDB + // Remember to set 'echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope' to attach. + RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean(); + System.out.println("PID: " + runtime.getName().split("@")[0]); + } + + @After + public void tearDownTest() { + if (model != null) { + model.close(); + } + } + + @Test + public void testEmbeddingTypes() { + for (PoolingType type : PoolingType.values()) { + System.out.println("Testing embedding with pooling type: " + type); + if (type == PoolingType.RANK) { + continue; // Only supported by reranking models + } + model = new LlamaModel(new ModelParameters().setModel(modelPath).setGpuLayers(99).enableEmbedding().setPoolingType(type)); + String text = "This is a test sentence for embedding."; + float[] embedding = model.embed(text); + Assert.assertEquals(4096, embedding.length); + model.close(); + } + } +}