diff --git a/app/api/route.js b/app/api/route.js index 8f225dc..e8a7890 100644 --- a/app/api/route.js +++ b/app/api/route.js @@ -1,6 +1,8 @@ import Replicate from "replicate"; import { ReplicateStream, StreamingTextResponse } from "ai"; +export const runtime = "edge"; + const replicate = new Replicate({ auth: process.env.REPLICATE_API_TOKEN, }); @@ -11,7 +13,10 @@ if (!process.env.REPLICATE_API_TOKEN) { ); } -export const runtime = "edge"; +const VERSIONS = { + "yorickvp/llava-13b": "e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358", + "nateraw/salmonn": "ad1d3f9d2bd683628242b68d890bef7f7bd97f738a7c2ccbf1743a594c723d83", +}; export async function POST(req) { const params = await req.json(); @@ -19,8 +24,8 @@ export async function POST(req) { const response = params.image ? await runLlava(params) : params.audio - ? await runSalmonn(params) - : await runLlama(params); + ? await runSalmonn(params) + : await runLlama(params); // Convert the response into a friendly text-stream const stream = await ReplicateStream(response); @@ -29,16 +34,18 @@ export async function POST(req) { } async function runLlama({ + model, prompt, systemPrompt, maxTokens, temperature, topP, - version, }) { console.log("running llama"); - return await replicate.predictions.create({ - // IMPORTANT! You must enable streaming. + + const [owner, name] = model.split("/"); + + return await replicate.models.predictions.create(owner, name, { stream: true, input: { prompt: `${prompt}`, @@ -48,8 +55,6 @@ async function runLlama({ repetition_penalty: 1, top_p: topP, }, - // IMPORTANT! The model must support streaming. See https://replicate.com/docs/streaming - version: version, }); } @@ -57,7 +62,6 @@ async function runLlava({ prompt, maxTokens, temperature, topP, image }) { console.log("running llava"); return await replicate.predictions.create({ - // IMPORTANT! You must enable streaming. 
stream: true, input: { prompt: `${prompt}`, @@ -66,8 +70,7 @@ async function runLlava({ prompt, maxTokens, temperature, topP, image }) { max_tokens: maxTokens, image: image, }, - // IMPORTANT! The model must support streaming. See https://replicate.com/docs/streaming - version: "6bc1c7bb0d2a34e413301fee8f7cc728d2d4e75bfab186aa995f63292bda92fc", // hardcoded https://replicate.com/yorickvp/llava-13b/versions + version: VERSIONS["yorickvp/llava-13b"] }); } @@ -75,7 +78,6 @@ async function runSalmonn({ prompt, maxTokens, temperature, topP, audio }) { console.log("running salmonn"); return await replicate.predictions.create({ - // IMPORTANT! You must enable streaming. stream: true, input: { prompt: `${prompt}`, @@ -84,7 +86,6 @@ async function runSalmonn({ prompt, maxTokens, temperature, topP, audio }) { max_length: maxTokens, wav_path: audio, }, - // IMPORTANT! The model must support streaming. See https://replicate.com/docs/streaming - version: "ad1d3f9d2bd683628242b68d890bef7f7bd97f738a7c2ccbf1743a594c723d83", // hardcoded https://replicate.com/yorickvp/llava-13b/versions + version: VERSIONS["nateraw/salmonn"] }); } diff --git a/app/components/ChatForm.js b/app/components/ChatForm.js index 3b13ea8..1658dbe 100644 --- a/app/components/ChatForm.js +++ b/app/components/ChatForm.js @@ -1,6 +1,6 @@ import { Uploader } from "uploader"; import { UploadButton } from "react-uploader"; - +import Metrics from "./Metrics"; const uploader = Uploader({ apiKey: "public_kW15biSARCJN7FAz6rANdRg3pNkh", }); @@ -38,7 +38,7 @@ const options = { }, }; -const ChatForm = ({ prompt, setPrompt, onSubmit, handleFileUpload }) => { +const ChatForm = ({ prompt, setPrompt, onSubmit, handleFileUpload, metrics, completion }) => { const handleSubmit = async (event) => { event.preventDefault(); onSubmit(prompt); @@ -56,6 +56,12 @@ const ChatForm = ({ prompt, setPrompt, onSubmit, handleFileUpload }) => { return (