From 5d79fbcc4d6eb5364f532360d71db0b1924973fd Mon Sep 17 00:00:00 2001 From: jhen Date: Sat, 12 Aug 2023 10:02:30 +0800 Subject: [PATCH 1/8] server : implement json-schema-to-grammar.mjs by follow python impl --- .../server/public/json-schema-to-grammar.mjs | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 examples/server/public/json-schema-to-grammar.mjs diff --git a/examples/server/public/json-schema-to-grammar.mjs b/examples/server/public/json-schema-to-grammar.mjs new file mode 100644 index 00000000000..bb0b6fc7d11 --- /dev/null +++ b/examples/server/public/json-schema-to-grammar.mjs @@ -0,0 +1,113 @@ +const SPACE_RULE = '" "?'; + +const PRIMITIVE_RULES = { + boolean: '("true" | "false") space', + number: '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', + integer: '("-"? ([0-9] | [1-9] [0-9]*)) space', + string: ` "\\"" ( + [^"\\\\] | + "\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) + )* "\\"" space`, + null: '"null" space', +}; + +const INVALID_RULE_CHARS_RE = /[^a-zA-Z0-9-]+/g; +const GRAMMAR_LITERAL_ESCAPE_RE = /[\r\n"]/g; +const GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"'}; + +export class SchemaConverter { + constructor(propOrder) { + this._propOrder = propOrder || {}; + this._rules = new Map(); + this._rules.set('space', SPACE_RULE); + } + + _formatLiteral(literal) { + const escaped = JSON.stringify(literal).replace( + GRAMMAR_LITERAL_ESCAPE_RE, + m => GRAMMAR_LITERAL_ESCAPES[m] + ); + return `"${escaped}"`; + } + + _addRule(name, rule) { + let escName = name.replace(INVALID_RULE_CHARS_RE, '-'); + let key = escName; + + if (this._rules.has(escName)) { + if (this._rules.get(escName) === rule) { + return key; + } + + let i = 0; + while (this._rules.has(`${escName}${i}`)) { + i += 1; + } + key = `${escName}${i}`; + } + + this._rules.set(key, rule); + return key; + } + + visit(schema, name) { + const schemaType = schema.type; + const ruleName = name || 'root'; + + if (schema.oneOf || schema.anyOf) { + const rule = (schema.oneOf || schema.anyOf).map((altSchema, i) => + this.visit(altSchema, `${name}${name ? "-" : ""}${i}`) + ).join(' | '); + + return this._addRule(ruleName, rule); + } else if ('const' in schema) { + return this._addRule(ruleName, this._formatLiteral(schema.const)); + } else if ('enum' in schema) { + const rule = schema.enum.map(v => this._formatLiteral(v)).join(' | '); + return this._addRule(ruleName, rule); + } else if (schemaType === 'object' && 'properties' in schema) { + // TODO: `required` keyword (from python implementation) + const propOrder = this._propOrder; + const propPairs = Object.entries(schema.properties).sort((a, b) => { + // sort by position in prop_order (if specified) then by key + const orderA = propOrder.hasOwnProperty(a[0]) ? propOrder[a[0]] : propOrder.length; + const orderB = propOrder.hasOwnProperty(b[0]) ? propOrder[b[0]] : propOrder.length; + return orderA - orderB || a[0].localeCompare(b[0]); + }); + + let rule = '"{" space'; + for (let i = 0; i < propPairs.length; i++) { + const [propName, propSchema] = propPairs[i]; + const propRuleName = this.visit(propSchema, `${name}${name ? "-" : ""}${propName}`); + if (i > 0) { + rule += ' "," space'; + } + rule += ` ${this._formatLiteral(propName)} space ":" space ${propRuleName}`; + } + rule += ' "}" space'; + + return this._addRule(ruleName, rule); + } else if (schemaType === 'array' && 'items' in schema) { + // TODO `prefixItems` keyword (from python implementation) + const itemRuleName = this.visit(schema.items, `${name}${name ? "-" : ""}item`); + const rule = `"[" space (${itemRuleName} ("," space ${itemRuleName})*)? "]" space`; + return this._addRule(ruleName, rule); + } else { + if (!PRIMITIVE_RULES[schemaType]) { + throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`); + } + return this._addRule( + ruleName === 'root' ? 'root' : schemaType, + PRIMITIVE_RULES[schemaType] + ); + } + } + + formatGrammar() { + let grammar = ''; + this._rules.forEach((rule, name) => { + grammar += `${name} ::= ${rule}\n`; + }); + return grammar; + } +} From c58ff992dcdba25931a9b38bbc4f59c7138d272f Mon Sep 17 00:00:00 2001 From: jhen Date: Sat, 12 Aug 2023 10:15:31 +0800 Subject: [PATCH 2/8] server : add grammar support in chat.mjs --- examples/server/chat.mjs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/examples/server/chat.mjs b/examples/server/chat.mjs index 8269e259273..87f4d2926ca 100644 --- a/examples/server/chat.mjs +++ b/examples/server/chat.mjs @@ -1,5 +1,34 @@ import * as readline from 'node:readline' import { stdin, stdout } from 'node:process' +import { readFileSync } from 'node:fs' +import { SchemaConverter } from './public/json-schema-to-grammar.mjs' + +const args = process.argv.slice(2); +const grammarJsonSchemaFile = args.find( + (_, index) => args[index - 1] === "--grammar-json-schema" +); +const grammarFile = args.find((_, index) => args[index - 1] === "--grammar"); + +// Example usage: function,arguments +const grammarJsonSchemaPropOrder = args.find( + (_, index) => args[index - 1] === "--grammar-json-schema-prop-order" +); +const propOrder = grammarJsonSchemaPropOrder + ? grammarJsonSchemaPropOrder + .split(",") + .reduce((acc, cur, index) => ({ ...acc, [cur]: index }), {}) + : {}; + +let grammar = null +if (grammarJsonSchemaFile) { + const schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8')) + const converter = new SchemaConverter(propOrder) + converter.visit(schema, '') + grammar = converter.formatGrammar() +} +if (grammarFile) { + grammar = readFileSync(grammarFile, 'utf-8') +} const API_URL = 'http://127.0.0.1:8080' @@ -48,6 +77,7 @@ async function chat_completion(question) { n_keep: n_keep, n_predict: 256, stop: ["\n### Human:"], // stop completion after generating this + grammar, stream: true, }) }) From 55a86adc2426b5cf84ad4442ee9b7f073c703fd6 Mon Sep 17 00:00:00 2001 From: jhen Date: Sat, 12 Aug 2023 10:41:56 +0800 Subject: [PATCH 3/8] server : implement grammer param in the UI --- examples/server/public/index.html | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/examples/server/public/index.html b/examples/server/public/index.html index de41da187a4..c43becc42e8 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -141,6 +141,7 @@ } from '/index.js'; import { llama } from '/completion.js'; + import { SchemaConverter } from '/json-schema-to-grammar.mjs'; const session = signal({ prompt: "This is a conversation between user and llama, a friendly chatbot. respond in simple markdown.", @@ -166,6 +167,7 @@ mirostat: 0, // 0/1/2 mirostat_tau: 5, // target entropy mirostat_eta: 0.1, // learning rate + grammar: null, }) const llamaStats = signal(null) @@ -304,6 +306,26 @@ const updateParamsFloat = (el) => params.value = { ...params.value, [el.target.name]: parseFloat(el.target.value) } const updateParamsInt = (el) => params.value = { ...params.value, [el.target.name]: Math.floor(parseFloat(el.target.value)) } + const grammarJsonSchemaPropOrder = signal('') + const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value + const convertJSONSchemaGrammar = () => { + try { + const schema = JSON.parse(params.value.grammar) + const converter = new SchemaConverter( + grammarJsonSchemaPropOrder.value + .split(',') + .reduce((acc, cur, i) => ({...acc, [cur.trim()]: i}), {}) + ) + converter.visit(schema, '') + params.value = { + ...params.value, + grammar: converter.formatGrammar(), + } + } catch (e) { + alert(`Convert failed: ${e.message}`) + } + } + const FloatField = ({label, max, min, name, step, value}) => { return html`
@@ -355,6 +377,13 @@