-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathCreateChatToolCompletionWithFeedback.scala
112 lines (91 loc) · 3.65 KB
/
CreateChatToolCompletionWithFeedback.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package io.cequence.openaiscala.examples
import io.cequence.openaiscala.domain.AssistantTool.FunctionTool
import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import play.api.libs.json.Json
import scala.concurrent.Future
// based on: https://platform.openai.com/docs/guides/function-calling
/** Example of a tool-calling round trip with feedback.
  *
  *   1. Ask the model a question while advertising the `get_current_weather` function tool.
  *   2. Execute each function call the model requested, using the local stub [[getCurrentWeather]].
  *   3. Send the tool results back as `ToolMessage`s and print the model's final answer.
  */
object CreateChatToolCompletionWithFeedback extends Example {

  private val modelId = ModelId.gpt_4_turbo_preview

  /** Conversation seed: a system prompt followed by the user's question. */
  val introMessages = Seq(
    SystemMessage("You are a helpful assistant."),
    UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")
  )

  // as a param type we can use "number", "string", "boolean", "object", "array", and "null"
  /** Tools advertised to the model. `parameters` describes the function's arguments as a
    * JSON-schema-style nested map; only `location` is required.
    */
  val tools = Seq(
    FunctionTool(
      name = "get_current_weather",
      description = Some("Get the current weather in a given location"),
      parameters = Map(
        "type" -> "object",
        "properties" -> Map(
          "location" -> Map(
            "type" -> "string",
            "description" -> "The city and state, e.g. San Francisco, CA"
          ),
          "unit" -> Map(
            "type" -> "string",
            "enum" -> Seq("celsius", "fahrenheit")
          )
        ),
        "required" -> Seq("location")
      )
    )
  )

  /** Runs the two-request exchange and prints the content of the final assistant message.
    *
    * The returned Future fails in these cases:
    *   - either response contains no choices;
    *   - the model calls a function that is not in `availableFunctions`;
    *   - the function arguments are not JSON containing a string `location`.
    */
  override protected def run: Future[_] =
    for {
      assistantToolResponse <- service.createChatToolCompletion(
        messages = introMessages,
        tools = tools,
        responseToolChoice = None, // means "auto"
        settings = CreateChatCompletionSettings(modelId)
      )

      // Throwing inside a for-over-Future binding fails the Future; do it with a clear
      // message instead of the bare NoSuchElementException that `.head` would raise.
      assistantToolMessage = assistantToolResponse.choices.headOption
        .getOrElse(throw new IllegalStateException("Tool completion returned no choices"))
        .message

      toolCalls = assistantToolMessage.tool_calls

      // we can handle only function calls (that will change in future)
      functionCalls = toolCalls.collect { case (toolCallId, x: FunctionCallSpec) =>
        (toolCallId, x)
      }

      // dispatch table from the function name the model uses to the local implementation
      availableFunctions = Map("get_current_weather" -> getCurrentWeather _)

      // execute each requested function call and wrap its result as a tool message
      toolMessages = functionCalls.map { case (toolCallId, functionCallSpec) =>
        val functionName = functionCallSpec.name
        val functionArgsJson = Json.parse(functionCallSpec.arguments)

        // this is not very generic, but it's ok for a demo
        val functionResponse = availableFunctions.get(functionName) match {
          case Some(functionToCall) =>
            functionToCall(
              (functionArgsJson \ "location").as[String],
              (functionArgsJson \ "unit").asOpt[String]
            )
          case None =>
            throw new IllegalArgumentException(s"Unknown function: $functionName")
        }

        ToolMessage(
          tool_call_id = toolCallId,
          content = Some(functionResponse.toString),
          name = functionName
        )
      }

      // The history must contain the assistant's tool-call message before the tool
      // results that answer it.
      messages = introMessages ++ Seq(assistantToolMessage) ++ toolMessages

      finalAssistantResponse <- service.createChatCompletion(
        messages = messages,
        settings = CreateChatCompletionSettings(modelId)
      )
    } yield {
      val finalChoice = finalAssistantResponse.choices.headOption
        .getOrElse(throw new IllegalStateException("Chat completion returned no choices"))
      println(finalChoice.message.content)
    }

  /** Stub weather lookup that returns canned JSON for a few known cities.
    *
    * Cities are matched case-insensitively by substring of `location`. Any other location
    * yields a result whose temperature is `"unknown"`.
    *
    * @param location
    *   free-form location text supplied by the model, e.g. "Tokyo, Japan"
    * @param unit
    *   requested temperature unit; ignored here (each city reports its own fixed unit)
    * @return
    *   a JSON object with `location`, `temperature`, and (for known cities) `unit`
    */
  private def getCurrentWeather(
    location: String,
    unit: Option[String]
  ) =
    location.toLowerCase() match {
      case loc if loc.contains("tokyo") =>
        Json.obj("location" -> "Tokyo", "temperature" -> "10", "unit" -> "celsius")
      case loc if loc.contains("san francisco") =>
        Json.obj("location" -> "San Francisco", "temperature" -> "72", "unit" -> "fahrenheit")
      case loc if loc.contains("paris") =>
        Json.obj("location" -> "Paris", "temperature" -> "22", "unit" -> "celsius")
      case _ =>
        Json.obj("location" -> location, "temperature" -> "unknown")
    }
}