11'use strict'
22
3- const BaseAwsSdkPlugin = require ( '../base' )
4- const log = require ( '../../../dd-trace/src/log' )
3+ const log = require ( '../../../../dd-trace/src/log' )
4+
5+ const MODEL_TYPE_IDENTIFIERS = [
6+ 'foundation-model/' ,
7+ 'custom-model/' ,
8+ 'provisioned-model/' ,
9+ 'imported-module/' ,
10+ 'prompt/' ,
11+ 'endpoint/' ,
12+ 'inference-profile/' ,
13+ 'default-prompt-router/'
14+ ]
515
616const PROVIDER = {
717 AI21 : 'AI21' ,
@@ -13,44 +23,6 @@ const PROVIDER = {
1323 MISTRAL : 'MISTRAL'
1424}
1525
16- const enabledOperations = [ 'invokeModel' ]
17-
18- class BedrockRuntime extends BaseAwsSdkPlugin {
19- static get id ( ) { return 'bedrock runtime' }
20-
21- isEnabled ( request ) {
22- const operation = request . operation
23- if ( ! enabledOperations . includes ( operation ) ) {
24- return false
25- }
26-
27- return super . isEnabled ( request )
28- }
29-
30- generateTags ( params , operation , response ) {
31- let tags = { }
32- let modelName = ''
33- let modelProvider = ''
34- const modelMeta = params . modelId . split ( '.' )
35- if ( modelMeta . length === 2 ) {
36- [ modelProvider , modelName ] = modelMeta
37- modelProvider = modelProvider . toUpperCase ( )
38- } else {
39- [ , modelProvider , modelName ] = modelMeta
40- modelProvider = modelProvider . toUpperCase ( )
41- }
42-
43- const shouldSetChoiceIds = modelProvider === PROVIDER . COHERE && ! modelName . includes ( 'embed' )
44-
45- const requestParams = extractRequestParams ( params , modelProvider )
46- const textAndResponseReason = extractTextAndResponseReason ( response , modelProvider , modelName , shouldSetChoiceIds )
47-
48- tags = buildTagsFromParams ( requestParams , textAndResponseReason , modelProvider , modelName , operation )
49-
50- return tags
51- }
52- }
53-
5426class Generation {
5527 constructor ( { message = '' , finishReason = '' , choiceId = '' } = { } ) {
5628 // stringify message as it could be a single generated message as well as a list of embeddings
@@ -65,18 +37,19 @@ class RequestParams {
6537 prompt = '' ,
6638 temperature = undefined ,
6739 topP = undefined ,
40+ topK = undefined ,
6841 maxTokens = undefined ,
6942 stopSequences = [ ] ,
7043 inputType = '' ,
7144 truncate = '' ,
7245 stream = '' ,
7346 n = undefined
7447 } = { } ) {
75- // TODO: set a truncation limit to prompt
7648 // stringify prompt as it could be a single prompt as well as a list of message objects
7749 this . prompt = typeof prompt === 'string' ? prompt : JSON . stringify ( prompt ) || ''
7850 this . temperature = temperature !== undefined ? temperature : undefined
7951 this . topP = topP !== undefined ? topP : undefined
52+ this . topK = topK !== undefined ? topK : undefined
8053 this . maxTokens = maxTokens !== undefined ? maxTokens : undefined
8154 this . stopSequences = stopSequences || [ ]
8255 this . inputType = inputType || ''
@@ -86,11 +59,53 @@ class RequestParams {
8659 }
8760}
8861
62+ function parseModelId ( modelId ) {
63+ // Best effort to extract the model provider and model name from the bedrock model ID.
64+ // modelId can be a 1/2 period-separated string or a full AWS ARN, based on the following formats:
65+ // 1. Base model: "{model_provider}.{model_name}"
66+ // 2. Cross-region model: "{region}.{model_provider}.{model_name}"
67+ // 3. Other: Prefixed by AWS ARN "arn:aws{+region?}:bedrock:{region}:{account-id}:"
68+ // a. Foundation model: ARN prefix + "foundation-model/{region?}.{model_provider}.{model_name}"
69+ // b. Custom model: ARN prefix + "custom-model/{model_provider}.{model_name}"
70+ // c. Provisioned model: ARN prefix + "provisioned-model/{model-id}"
71+ // d. Imported model: ARN prefix + "imported-module/{model-id}"
72+ // e. Prompt management: ARN prefix + "prompt/{prompt-id}"
73+ // f. Sagemaker: ARN prefix + "endpoint/{model-id}"
74+ // g. Inference profile: ARN prefix + "{application-?}inference-profile/{model-id}"
75+ // h. Default prompt router: ARN prefix + "default-prompt-router/{prompt-id}"
76+ // If model provider cannot be inferred from the modelId formatting, then default to "custom"
77+ modelId = modelId . toLowerCase ( )
78+ if ( ! modelId . startsWith ( 'arn:aws' ) ) {
79+ const modelMeta = modelId . split ( '.' )
80+ if ( modelMeta . length < 2 ) {
81+ return { modelProvider : 'custom' , modelName : modelMeta [ 0 ] }
82+ }
83+ return { modelProvider : modelMeta [ modelMeta . length - 2 ] , modelName : modelMeta [ modelMeta . length - 1 ] }
84+ }
85+
86+ for ( const identifier of MODEL_TYPE_IDENTIFIERS ) {
87+ if ( ! modelId . includes ( identifier ) ) {
88+ continue
89+ }
90+ modelId = modelId . split ( identifier ) . pop ( )
91+ if ( [ 'foundation-model/' , 'custom-model/' ] . includes ( identifier ) ) {
92+ const modelMeta = modelId . split ( '.' )
93+ if ( modelMeta . length < 2 ) {
94+ return { modelProvider : 'custom' , modelName : modelId }
95+ }
96+ return { modelProvider : modelMeta [ modelMeta . length - 2 ] , modelName : modelMeta [ modelMeta . length - 1 ] }
97+ }
98+ return { modelProvider : 'custom' , modelName : modelId }
99+ }
100+
101+ return { modelProvider : 'custom' , modelName : 'custom' }
102+ }
103+
89104function extractRequestParams ( params , provider ) {
90105 const requestBody = JSON . parse ( params . body )
91106 const modelId = params . modelId
92107
93- switch ( provider ) {
108+ switch ( provider . toUpperCase ( ) ) {
94109 case PROVIDER . AI21 : {
95110 let userPrompt = requestBody . prompt
96111 if ( modelId . includes ( 'jamba' ) ) {
@@ -176,11 +191,11 @@ function extractRequestParams (params, provider) {
176191 }
177192}
178193
179- function extractTextAndResponseReason ( response , provider , modelName , shouldSetChoiceIds ) {
194+ function extractTextAndResponseReason ( response , provider , modelName ) {
180195 const body = JSON . parse ( Buffer . from ( response . body ) . toString ( 'utf8' ) )
181-
196+ const shouldSetChoiceIds = provider . toUpperCase ( ) === PROVIDER . COHERE && ! modelName . includes ( 'embed' )
182197 try {
183- switch ( provider ) {
198+ switch ( provider . toUpperCase ( ) ) {
184199 case PROVIDER . AI21 : {
185200 if ( modelName . includes ( 'jamba' ) ) {
186201 const generations = body . choices || [ ]
@@ -262,34 +277,11 @@ function extractTextAndResponseReason (response, provider, modelName, shouldSetC
262277 return new Generation ( )
263278}
264279
265- function buildTagsFromParams ( requestParams , textAndResponseReason , modelProvider , modelName , operation ) {
266- const tags = { }
267-
268- // add request tags
269- tags [ 'resource.name' ] = operation
270- tags [ 'aws.bedrock.request.model' ] = modelName
271- tags [ 'aws.bedrock.request.model_provider' ] = modelProvider
272- tags [ 'aws.bedrock.request.prompt' ] = requestParams . prompt
273- tags [ 'aws.bedrock.request.temperature' ] = requestParams . temperature
274- tags [ 'aws.bedrock.request.top_p' ] = requestParams . topP
275- tags [ 'aws.bedrock.request.max_tokens' ] = requestParams . maxTokens
276- tags [ 'aws.bedrock.request.stop_sequences' ] = requestParams . stopSequences
277- tags [ 'aws.bedrock.request.input_type' ] = requestParams . inputType
278- tags [ 'aws.bedrock.request.truncate' ] = requestParams . truncate
279- tags [ 'aws.bedrock.request.stream' ] = requestParams . stream
280- tags [ 'aws.bedrock.request.n' ] = requestParams . n
281-
282- // add response tags
283- if ( modelName . includes ( 'embed' ) ) {
284- tags [ 'aws.bedrock.response.embedding_length' ] = textAndResponseReason . message . length
285- }
286- if ( textAndResponseReason . choiceId ) {
287- tags [ 'aws.bedrock.response.choices.id' ] = textAndResponseReason . choiceId
288- }
289- tags [ 'aws.bedrock.response.choices.text' ] = textAndResponseReason . message
290- tags [ 'aws.bedrock.response.choices.finish_reason' ] = textAndResponseReason . finishReason
291-
292- return tags
280+ module . exports = {
281+ Generation,
282+ RequestParams,
283+ parseModelId,
284+ extractRequestParams,
285+ extractTextAndResponseReason,
286+ PROVIDER
293287}
294-
295- module . exports = BedrockRuntime
0 commit comments