-
-
Notifications
You must be signed in to change notification settings - Fork 19.6k
/
Copy pathcore.ts
145 lines (122 loc) · 4.73 KB
/
core.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import { LLM, type BaseLLMParams } from '@langchain/core/language_models/llms'
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'
import { GenerationChunk } from '@langchain/core/outputs'
import type ReplicateInstance from 'replicate'
/**
 * Constructor options for the {@link Replicate} LLM wrapper.
 */
export interface ReplicateInput {
    /** Model identifier in `owner/name` or `owner/name:version` form. */
    model: `${string}/${string}` | `${string}/${string}:${string}`
    /** Extra model-specific inputs merged into every request. */
    input?: {
        // different models accept different inputs
        [key: string]: string | number | boolean
    }
    /** Replicate API token. Optional in the type, but the constructor throws when absent. */
    apiKey?: string
    /** Name of the model's input field that receives the prompt; auto-detected from the model's OpenAPI schema when omitted. */
    promptKey?: string
}
/**
 * LangChain LLM wrapper for models hosted on Replicate.
 *
 * Supports both one-shot calls (`_call`) and server-sent-event streaming
 * (`_streamResponseChunks`). The `replicate` package is imported lazily so
 * this module can be loaded without it installed.
 */
export class Replicate extends LLM implements ReplicateInput {
    lc_serializable = true

    /** Model identifier, e.g. `owner/name` or `owner/name:version`. */
    model: ReplicateInput['model']

    /** Extra model-specific inputs merged into every request. */
    input: ReplicateInput['input']

    /** Replicate API token used to authenticate requests. */
    apiKey: string

    /** Input field that receives the prompt; lazily auto-detected when undefined. */
    promptKey?: string

    /**
     * @param fields - Replicate options plus base LLM params.
     * @throws Error when no API key is provided.
     */
    constructor(fields: ReplicateInput & BaseLLMParams) {
        super(fields)
        const apiKey = fields?.apiKey
        if (!apiKey) {
            throw new Error('Please set the REPLICATE_API_TOKEN')
        }
        this.apiKey = apiKey
        this.model = fields.model
        this.input = fields.input ?? {}
        this.promptKey = fields.promptKey
    }

    _llmType() {
        return 'replicate'
    }

    /** @ignore */
    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
        const replicate = await this._prepareReplicate()
        const input = await this._getReplicateInput(replicate, prompt)
        const output = await this.caller.callWithOptions({ signal: options.signal }, () =>
            replicate.run(this.model, {
                input
            })
        )
        if (typeof output === 'string') {
            return output
        } else if (Array.isArray(output)) {
            // Many models stream back an array of string fragments.
            return output.join('')
        } else {
            // Note this is a little odd, but the output format is not consistent
            // across models, so it makes some amount of sense.
            return String(output)
        }
    }

    /**
     * Streams generation chunks via Replicate's SSE endpoint.
     * Emits one chunk per `output` event and a final empty chunk on `done`.
     */
    async *_streamResponseChunks(
        prompt: string,
        options: this['ParsedCallOptions'],
        runManager?: CallbackManagerForLLMRun
    ): AsyncGenerator<GenerationChunk> {
        const replicate = await this._prepareReplicate()
        const input = await this._getReplicateInput(replicate, prompt)
        const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () =>
            replicate.stream(this.model, {
                input
            })
        )
        for await (const chunk of stream) {
            if (chunk.event === 'output') {
                // chunk.data may be undefined on some events; GenerationChunk.text
                // expects a string, so default to '' (matches the token callback).
                yield new GenerationChunk({ text: chunk.data ?? '', generationInfo: chunk })
                await runManager?.handleLLMNewToken(chunk.data ?? '')
            }
            // stream is done
            if (chunk.event === 'done')
                yield new GenerationChunk({
                    text: '',
                    generationInfo: { finished: true }
                })
        }
    }

    /** @ignore */
    static async imports(): Promise<{
        Replicate: typeof ReplicateInstance
    }> {
        try {
            const { default: Replicate } = await import('replicate')
            return { Replicate }
        } catch (e) {
            throw new Error('Please install replicate as a dependency with, e.g. `yarn add replicate`')
        }
    }

    /** Builds an authenticated Replicate client (lazy import). */
    private async _prepareReplicate(): Promise<ReplicateInstance> {
        const imports = await Replicate.imports()
        return new imports.Replicate({
            userAgent: 'flowise',
            auth: this.apiKey
        })
    }

    /**
     * Resolves the prompt field name (caching it on `this.promptKey`) and
     * builds the request input. When no `promptKey` was configured and the
     * model reference pins a version, the model's OpenAPI schema is queried
     * and the lowest `x-order` property is assumed to be the prompt field;
     * otherwise 'prompt' is used.
     */
    private async _getReplicateInput(replicate: ReplicateInstance, prompt: string) {
        if (this.promptKey === undefined) {
            const [modelString, versionString] = this.model.split(':')
            if (versionString) {
                const [owner, name] = modelString.split('/')
                const version = await replicate.models.versions.get(owner, name, versionString)
                const openapiSchema = version.openapi_schema
                // Keyed record of input properties, not an array — consumed via Object.entries.
                const inputProperties: Record<string, { 'x-order'?: number }> | undefined = (openapiSchema as any)?.components?.schemas
                    ?.Input?.properties
                if (inputProperties === undefined) {
                    this.promptKey = 'prompt'
                } else {
                    const sortedInputProperties = Object.entries(inputProperties).sort(([_keyA, valueA], [_keyB, valueB]) => {
                        const orderA = valueA['x-order'] || 0
                        const orderB = valueB['x-order'] || 0
                        return orderA - orderB
                    })
                    // Guard against an empty properties object: [0] would be undefined
                    // and the original [0][0] access threw a TypeError here.
                    this.promptKey = sortedInputProperties[0]?.[0] ?? 'prompt'
                }
            } else {
                this.promptKey = 'prompt'
            }
        }
        return {
            [this.promptKey!]: prompt,
            ...this.input
        }
    }
}