diff --git a/midscene-core/src/main/java/com/midscene/core/agent/Agent.java b/midscene-core/src/main/java/com/midscene/core/agent/Agent.java index 6777ed1e..c0c5844e 100644 --- a/midscene-core/src/main/java/com/midscene/core/agent/Agent.java +++ b/midscene-core/src/main/java/com/midscene/core/agent/Agent.java @@ -10,6 +10,7 @@ import com.midscene.core.model.MistralModel; import com.midscene.core.model.OllamaModel; import com.midscene.core.model.OpenAIModel; +import com.midscene.core.model.QwenModel; import com.midscene.core.pojo.options.InputOptions; import com.midscene.core.pojo.options.LocateOptions; import com.midscene.core.pojo.options.ScrollOptions; @@ -30,15 +31,17 @@ public class Agent { private TaskCache cache; public Agent(PageDriver driver, AIModel aiModel) { - this.driver = driver; - this.cache = TaskCache.disabled(); - this.orchestrator = new Orchestrator(driver, aiModel, this.cache); + this(driver, aiModel, TaskCache.disabled(), 3); } public Agent(PageDriver driver, AIModel aiModel, TaskCache cache) { + this(driver, aiModel, cache, 3); + } + + public Agent(PageDriver driver, AIModel aiModel, TaskCache cache, int maxRetries) { this.driver = driver; this.cache = cache != null ? cache : TaskCache.disabled(); - this.orchestrator = new Orchestrator(driver, aiModel, this.cache); + this.orchestrator = new Orchestrator(driver, aiModel, this.cache, maxRetries); } /** @@ -56,9 +59,10 @@ public static Agent create(MidsceneConfig config, PageDriver driver) { case MISTRAL -> new MistralModel(config.getApiKey(), config.getModelName(), config.getBaseUrl()); case AZURE_OPEN_AI -> new AzureOpenAiModel(config.getApiKey(), config.getBaseUrl()); case OLLAMA -> new OllamaModel(config.getBaseUrl(), config.getModelName()); + case QWEN, THOUSAND_QUESTIONS -> new QwenModel(config.getApiKey(), config.getModelName(), config.getBaseUrl()); }; - return new Agent(driver, model); + return new Agent(driver, model, TaskCache.disabled(), config.getMaxRetries()); } /** @@ -77,9 +81,10 @@ public static Agent create(MidsceneConfig config, PageDriver driver, TaskCache c case MISTRAL -> new MistralModel(config.getApiKey(), config.getModelName(), config.getBaseUrl()); case AZURE_OPEN_AI -> new AzureOpenAiModel(config.getApiKey(), config.getBaseUrl()); case OLLAMA -> new OllamaModel(config.getBaseUrl(), config.getModelName()); + case QWEN, THOUSAND_QUESTIONS -> new QwenModel(config.getApiKey(), config.getModelName(), config.getBaseUrl()); }; - return new Agent(driver, model, cache); + return new Agent(driver, model, cache, config.getMaxRetries()); } /** diff --git a/midscene-core/src/main/java/com/midscene/core/agent/Orchestrator.java b/midscene-core/src/main/java/com/midscene/core/agent/Orchestrator.java index 44cc7736..e7603c88 100644 --- a/midscene-core/src/main/java/com/midscene/core/agent/Orchestrator.java +++ b/midscene-core/src/main/java/com/midscene/core/agent/Orchestrator.java @@ -20,15 +20,20 @@ public class Orchestrator { private final PageDriver driver; private final Planner planner; private final Executor executor; + private final int maxRetries; @Getter private final Context context; public Orchestrator(PageDriver driver, AIModel aiModel) { - this(driver, new Planner(aiModel, TaskCache.disabled()), new Executor(driver)); + this(driver, new Planner(aiModel, TaskCache.disabled()), new Executor(driver), 3); } public Orchestrator(PageDriver driver, AIModel aiModel, TaskCache cache) { - this(driver, new Planner(aiModel, cache), new Executor(driver)); + this(driver, new Planner(aiModel, cache), new Executor(driver), 3); + } + + public Orchestrator(PageDriver driver, AIModel aiModel, TaskCache cache, int maxRetries) { + this(driver, new Planner(aiModel, cache), new Executor(driver), maxRetries); } /** @@ -39,9 +44,14 @@ public Orchestrator(PageDriver driver, AIModel aiModel, TaskCache cache) { * @param executor The executor */ protected Orchestrator(PageDriver driver, Planner planner, Executor executor) { + this(driver, planner, executor, 3); + } + + protected Orchestrator(PageDriver driver, Planner planner, Executor executor, int maxRetries) { this.driver = driver; this.planner = planner; this.executor = executor; + this.maxRetries = maxRetries; this.context = new Context(); } @@ -75,7 +85,6 @@ public void execute(String instruction) { context.logInstruction(instruction); List history = new ArrayList<>(); - int maxRetries = 3; boolean finished = false; boolean cacheInvalidated = false; diff --git a/midscene-core/src/main/java/com/midscene/core/config/MidsceneConfig.java b/midscene-core/src/main/java/com/midscene/core/config/MidsceneConfig.java index 0d1e76bf..55b16f5d 100644 --- a/midscene-core/src/main/java/com/midscene/core/config/MidsceneConfig.java +++ b/midscene-core/src/main/java/com/midscene/core/config/MidsceneConfig.java @@ -7,6 +7,7 @@ public class MidsceneConfig { private final String modelName; private final String baseUrl; private final long timeoutMs; + private final int maxRetries; private MidsceneConfig(Builder builder) { this.provider = builder.provider; @@ -14,6 +15,7 @@ private MidsceneConfig(Builder builder) { this.modelName = builder.modelName; this.baseUrl = builder.baseUrl; this.timeoutMs = builder.timeoutMs; + this.maxRetries = builder.maxRetries; } public static Builder builder() { @@ -40,6 +42,10 @@ public String getBaseUrl() { return baseUrl; } + public int getMaxRetries() { + return maxRetries; + } + public static class Builder { private ModelProvider provider = ModelProvider.OPENAI; @@ -47,6 +53,7 @@ public static class Builder { private String modelName; private String baseUrl; private long timeoutMs = 30000; // Default 30s + private int maxRetries = 3; public Builder provider(ModelProvider provider) { this.provider = provider; @@ -73,6 +80,11 @@ public Builder baseUrl(String baseUrl) { return this; } + public Builder maxRetries(int maxRetries) { + this.maxRetries = maxRetries; + return this; + } + public MidsceneConfig build() { if (apiKey == null || apiKey.isEmpty()) { throw new IllegalArgumentException("API Key must be provided"); diff --git a/midscene-core/src/main/java/com/midscene/core/config/ModelProvider.java b/midscene-core/src/main/java/com/midscene/core/config/ModelProvider.java index f2836670..7602c0ef 100644 --- a/midscene-core/src/main/java/com/midscene/core/config/ModelProvider.java +++ b/midscene-core/src/main/java/com/midscene/core/config/ModelProvider.java @@ -6,7 +6,9 @@ public enum ModelProvider { ANTHROPIC("claude-3-5-sonnet-20240620", "https://api.anthropic.com/v1/"), MISTRAL("small-latest", "https://api.mistral.ai/v1"), AZURE_OPEN_AI("gpt-4o", "https://openai.azure.com/"), - OLLAMA("llama3.1", "http://localhost:11434/"); + OLLAMA("llama3.1", "http://localhost:11434/"), + QWEN("qwen-plus", "https://dashscope.aliyuncs.com/compatible-mode/v1"), + THOUSAND_QUESTIONS("qwen-plus", "https://dashscope.aliyuncs.com/compatible-mode/v1"); private final String modelName; private final String baseUrl; diff --git a/midscene-core/src/main/java/com/midscene/core/model/QwenModel.java b/midscene-core/src/main/java/com/midscene/core/model/QwenModel.java new file mode 100644 index 00000000..c6898dac --- /dev/null +++ b/midscene-core/src/main/java/com/midscene/core/model/QwenModel.java @@ -0,0 +1,25 @@ +package com.midscene.core.model; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.openai.OpenAiChatModel; +import java.util.List; + +public class QwenModel implements AIModel { + + private final ChatModel model; + + public QwenModel(String apiKey, String modelName, String baseUrl) { + this.model = OpenAiChatModel.builder() + .apiKey(apiKey) + .modelName(modelName) + .baseUrl(baseUrl) + .build(); + } + + @Override + public ChatResponse chat(List messages) { + return model.chat(messages); + } +}