Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions app/api/route.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import Replicate from "replicate";
import { ReplicateStream, StreamingTextResponse } from "ai";

export const runtime = "edge";

const replicate = new Replicate({
auth: process.env.REPLICATE_API_TOKEN,
});
Expand All @@ -11,16 +13,19 @@ if (!process.env.REPLICATE_API_TOKEN) {
);
}

export const runtime = "edge";
const VERSIONS = {
"yorickvp/llava-13b": "e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358",
"nateraw/salmonn": "ad1d3f9d2bd683628242b68d890bef7f7bd97f738a7c2ccbf1743a594c723d83",
};

export async function POST(req) {
const params = await req.json();

const response = params.image
? await runLlava(params)
: params.audio
? await runSalmonn(params)
: await runLlama(params);
? await runSalmonn(params)
: await runLlama(params);

// Convert the response into a friendly text-stream
const stream = await ReplicateStream(response);
Expand All @@ -29,16 +34,18 @@ export async function POST(req) {
}

async function runLlama({
model,
prompt,
systemPrompt,
maxTokens,
temperature,
topP,
version,
}) {
console.log("running llama");
return await replicate.predictions.create({
// IMPORTANT! You must enable streaming.

const [owner, name] = model.split("/");

return await replicate.models.predictions.create(owner, name, {
stream: true,
input: {
prompt: `${prompt}`,
Expand All @@ -48,16 +55,13 @@ async function runLlama({
repetition_penalty: 1,
top_p: topP,
},
// IMPORTANT! The model must support streaming. See https://replicate.com/docs/streaming
version: version,
});
}

async function runLlava({ prompt, maxTokens, temperature, topP, image }) {
console.log("running llava");

return await replicate.predictions.create({
// IMPORTANT! You must enable streaming.
stream: true,
input: {
prompt: `${prompt}`,
Expand All @@ -66,16 +70,14 @@ async function runLlava({ prompt, maxTokens, temperature, topP, image }) {
max_tokens: maxTokens,
image: image,
},
// IMPORTANT! The model must support streaming. See https://replicate.com/docs/streaming
version: "6bc1c7bb0d2a34e413301fee8f7cc728d2d4e75bfab186aa995f63292bda92fc", // hardcoded https://replicate.com/yorickvp/llava-13b/versions
version: models["yorickvp/llava-13b"]
});
}

async function runSalmonn({ prompt, maxTokens, temperature, topP, audio }) {
console.log("running salmonn");

return await replicate.predictions.create({
// IMPORTANT! You must enable streaming.
stream: true,
input: {
prompt: `${prompt}`,
Expand All @@ -84,7 +86,6 @@ async function runSalmonn({ prompt, maxTokens, temperature, topP, audio }) {
max_length: maxTokens,
wav_path: audio,
},
// IMPORTANT! The model must support streaming. See https://replicate.com/docs/streaming
version: "ad1d3f9d2bd683628242b68d890bef7f7bd97f738a7c2ccbf1743a594c723d83", // hardcoded https://replicate.com/yorickvp/llava-13b/versions
version: models["nateraw/salmonn"]
});
}
10 changes: 8 additions & 2 deletions app/components/ChatForm.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Uploader } from "uploader";
import { UploadButton } from "react-uploader";

import Metrics from "./Metrics";
const uploader = Uploader({
apiKey: "public_kW15biSARCJN7FAz6rANdRg3pNkh",
});
Expand Down Expand Up @@ -38,7 +38,7 @@ const options = {
},
};

const ChatForm = ({ prompt, setPrompt, onSubmit, handleFileUpload }) => {
const ChatForm = ({ prompt, setPrompt, onSubmit, handleFileUpload, metrics, completion }) => {
const handleSubmit = async (event) => {
event.preventDefault();
onSubmit(prompt);
Expand All @@ -56,6 +56,12 @@ const ChatForm = ({ prompt, setPrompt, onSubmit, handleFileUpload }) => {
return (
<footer className="z-10 fixed bottom-0 left-0 right-0 bg-slate-100 border-t-2">
<div className="container max-w-2xl mx-auto p-5 pb-8">
<Metrics
startedAt={metrics.startedAt}
firstMessageAt={metrics.firstMessageAt}
completedAt={metrics.completedAt}
completion={completion} />

<form className="w-full flex" onSubmit={handleSubmit}>
<UploadButton
uploader={uploader}
Expand Down
34 changes: 34 additions & 0 deletions app/components/Metrics.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import { countTokens } from "../src/tokenizer";

export default function Metrics({ startedAt, firstMessageAt, completedAt, completion }) {
const timeToFirstToken = firstMessageAt && startedAt ? (new Date(firstMessageAt) - new Date(startedAt)) / 1000.0 : null;
const tokenCount = completion && countTokens(completion);
const runningDuration = firstMessageAt ? ((completedAt ? new Date(completedAt) : new Date()) - new Date(firstMessageAt)) / 1000.0 : null;
const tokensPerSecond = tokenCount > 0 && runningDuration > 0 && tokenCount / runningDuration;

return (
<dl className="tabular-nums pb-6" style={{
display: 'grid',
gridTemplateColumns: 'repeat(8, auto)',
gridTemplateAreas:
'"v1 k1 v2 k2 v3 k3 v4 k4"'
}}>
{<>
<dt title="Time to first token" className="text-gray-500" style={{ gridArea: 'k1' }}>sec to first token</dt>
<dd className="text-right pr-4" style={{ gridArea: 'v1' }}>{timeToFirstToken ? timeToFirstToken.toFixed(2) : "—"}</dd>
</>}
{<>
<dt title="Throughput" className="text-gray-500" style={{ gridArea: 'k2' }}>tokens / sec</dt>
<dd className="text-right pr-4" style={{ gridArea: 'v2' }}>{tokensPerSecond ? tokensPerSecond.toFixed(2) : "—"}</dd>
</>}
{<>
<dt title="Token count" className="text-gray-500" style={{ gridArea: 'k3' }}>tokens</dt>
<dd className="text-right pr-4" style={{ gridArea: 'v3' }}>{tokenCount || "—"}</dd>
</>}
{<>
<dt title="Run time" className="text-gray-500" style={{ gridArea: 'k4' }}>sec</dt>
<dd className="text-right pr-4" style={{ gridArea: 'v4' }}>{Math.max(runningDuration, 0).toFixed(2)}</dd>
</>}
</dl>
);
};
70 changes: 34 additions & 36 deletions app/components/SlideOver.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export default function SlideOver({
setTopP,
maxTokens,
setMaxTokens,
versions,
models,
size,
setSize,
handleSubmit,
Expand Down Expand Up @@ -108,44 +108,42 @@ export default function SlideOver({
leaveTo="opacity-0"
>
<Listbox.Options className="absolute mt-1 max-h-60 w-full shadow-md overflow-auto border-gray-700 rounded-md bg-white py-1 text-base ring-1 ring-black ring-opacity-5 focus:outline-none sm:text-sm">
{versions
? versions.map(
(version, versionIdx) => (
<Listbox.Option
key={versionIdx}
className={({ active }) =>
`relative cursor-default select-none py-2 pl-10 pr-4 ${
active
? "bg-gray-100 text-gray-900"
: "text-gray-900"
}`
}
value={version}
>
{({ selected }) => (
<>
<span
className={`block truncate ${
selected
? "font-medium"
: "font-normal"
{models
? models.map(
(model, modelIdx) => (
<Listbox.Option
key={modelIdx}
className={({ active }) =>
`relative cursor-default select-none py-2 pl-10 pr-4 ${active
? "bg-gray-100 text-gray-900"
: "text-gray-900"
}`
}
value={model}
>
{({ selected }) => (
<>
<span
className={`block truncate ${selected
? "font-medium"
: "font-normal"
}`}
>
{version.name}
>
{model.name}
</span>
{selected ? (
<span className="absolute inset-y-0 left-0 flex items-center pl-3 text-gray-600">
<CheckIcon
className="h-5 w-5"
aria-hidden="true"
/>
</span>
{selected ? (
<span className="absolute inset-y-0 left-0 flex items-center pl-3 text-gray-600">
<CheckIcon
className="h-5 w-5"
aria-hidden="true"
/>
</span>
) : null}
</>
)}
</Listbox.Option>
)
) : null}
</>
)}
</Listbox.Option>
)
)
: null}
</Listbox.Options>
</Transition>
Expand Down
Loading