-
-
Notifications
You must be signed in to change notification settings - Fork 111
/
Copy pathFunctionCalling.cs
102 lines (87 loc) · 3.21 KB
/
FunctionCalling.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
using UnityEngine;
using LLMUnity;
using UnityEngine.UI;
using System.Collections.Generic;
using System.Reflection;
namespace LLMUnitySamples
{
/// <summary>
/// Toy "tools" the language model can pick between. Every public static method
/// declared on this class (property getters included) is discovered by reflection
/// in <c>FunctionCalling</c> and offered to the model as a choice, so keep
/// everything that is not a callable tool private.
/// </summary>
public static class Functions
{
    // Single generator shared by all tools.
    private static System.Random rng = new System.Random();

    /// <summary>Describes a randomly chosen weather condition.</summary>
    /// <returns>A sentence of the form "The weather is &lt;condition&gt;".</returns>
    public static string Weather()
    {
        string[] conditions = { "sunny", "rainy", "cloudy", "snowy" };
        string picked = conditions[rng.Next(conditions.Length)];
        return $"The weather is {picked}";
    }

    /// <summary>Reports a random clock time, hour 00-23 and minute 00-59.</summary>
    /// <returns>A sentence of the form "The time is HH:MM".</returns>
    public static string Time()
    {
        // Draw the hour first, then the minute, so the random sequence is consumed in a fixed order.
        int hour = rng.Next(24);
        int minute = rng.Next(60);
        return $"The time is {hour:D2}:{minute:D2}";
    }

    /// <summary>Describes a randomly chosen mood.</summary>
    /// <returns>A sentence of the form "I am feeling &lt;mood&gt;".</returns>
    public static string Emotion()
    {
        string[] moods = { "happy", "sad", "exhilarated", "ok" };
        string picked = moods[rng.Next(moods.Length)];
        return $"I am feeling {picked}";
    }
}
/// <summary>
/// Sample that uses an <see cref="LLMCharacter"/> as a router. The player's text is
/// wrapped in a multiple-choice prompt whose choices are the public static methods
/// of <see cref="Functions"/>. The method the model names is then invoked through
/// reflection, and its result is shown in <see cref="AIText"/>.
/// </summary>
public class FunctionCalling : MonoBehaviour
{
    // The set of callable "tools": public static methods declared on Functions itself.
    // This is shared by name discovery and invocation so both agree on the same set.
    // Type.GetMethod(string) on its own would also resolve inherited public instance
    // methods such as ToString, which cannot be invoked without a target.
    const BindingFlags FunctionBindingFlags = BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly;

    public LLMCharacter llmCharacter;
    public InputField playerText;
    public Text AIText;

    /// <summary>
    /// Hooks up the submit handler, focuses the input field and installs the
    /// multiple-choice grammar on the character.
    /// </summary>
    void Start()
    {
        playerText.onSubmit.AddListener(onInputFieldSubmit);
        playerText.Select();
        // NOTE(review): presumably LLMCharacter uses this grammar to restrict replies
        // to exactly one function name; confirm against LLMCharacter.
        llmCharacter.grammarString = MultipleChoiceGrammar();
    }

    /// <summary>Names of the methods the model may choose from, in reflection order.</summary>
    string[] GetFunctionNames()
    {
        return System.Array.ConvertAll(typeof(Functions).GetMethods(FunctionBindingFlags), method => method.Name);
    }

    /// <summary>
    /// Builds a single-rule grammar whose root is an alternation of the quoted
    /// function names, e.g. <c>root ::= ("Weather" | "Time" | "Emotion")</c>.
    /// </summary>
    string MultipleChoiceGrammar()
    {
        return "root ::= (\"" + string.Join("\" | \"", GetFunctionNames()) + "\")";
    }

    /// <summary>
    /// Wraps the player's <paramref name="message"/> in a prompt that lists every
    /// function name as a choice and asks for the choice alone.
    /// </summary>
    string ConstructPrompt(string message)
    {
        string prompt = "Which of the following choices matches best the input?\n\n";
        prompt += "Input:" + message + "\n\n";
        prompt += "Choices:\n";
        foreach (string functionName in GetFunctionNames()) prompt += $"- {functionName}\n";
        prompt += "\nAnswer directly with the choice";
        return prompt;
    }

    /// <summary>
    /// Invokes the parameterless method of <see cref="Functions"/> called
    /// <paramref name="functionName"/> and returns its string result.
    /// </summary>
    /// <exception cref="System.ArgumentException">
    /// <paramref name="functionName"/> is null, empty, or not one of the available
    /// functions (for example, an unexpected model reply).
    /// </exception>
    string CallFunction(string functionName)
    {
        MethodInfo method = string.IsNullOrEmpty(functionName)
            ? null
            : typeof(Functions).GetMethod(functionName, FunctionBindingFlags);
        if (method == null)
        {
            throw new System.ArgumentException($"'{functionName}' is not one of the available functions", nameof(functionName));
        }
        return (string)method.Invoke(null, null);
    }

    /// <summary>
    /// Submit handler: asks the model which function matches <paramref name="message"/>,
    /// calls it and displays the result. The input field is disabled while the request
    /// is in flight. It is always re-enabled afterwards, even if the request or the
    /// function call fails.
    /// </summary>
    // async void because UnityEvent listeners must return void; every failure is caught
    // and logged here so it can never leave the input field disabled.
    async void onInputFieldSubmit(string message)
    {
        playerText.interactable = false;
        try
        {
            string reply = await llmCharacter.Chat(ConstructPrompt(message));
            // Tolerate stray surrounding whitespace in the reply before the reflection lookup.
            string functionName = reply?.Trim();
            string result = CallFunction(functionName);
            AIText.text = $"Calling {functionName}\n{result}";
        }
        catch (System.Exception e)
        {
            Debug.LogException(e, this);
        }
        finally
        {
            // The input field may have been destroyed while awaiting; Unity's overloaded
            // == reports destroyed objects as null.
            if (playerText != null) playerText.interactable = true;
        }
    }

    /// <summary>Forwards to <see cref="LLMCharacter.CancelRequests"/>.</summary>
    public void CancelRequests()
    {
        llmCharacter.CancelRequests();
    }

    /// <summary>Logs the click and asks the application to quit.</summary>
    public void ExitGame()
    {
        Debug.Log("Exit button clicked");
        Application.Quit();
    }

    // Ensures the "select a model" warning below is emitted at most once per instance.
    bool onValidateWarning = true;

    /// <summary>
    /// Editor-time check: warns once when the character uses a local (non-remote)
    /// LLM that has no model selected.
    /// </summary>
    void OnValidate()
    {
        // llmCharacter may still be unassigned (null) in the Inspector when OnValidate runs.
        if (!onValidateWarning || llmCharacter == null) return;
        if (!llmCharacter.remote && llmCharacter.llm != null && string.IsNullOrEmpty(llmCharacter.llm.model))
        {
            Debug.LogWarning($"Please select a model in the {llmCharacter.llm.gameObject.name} GameObject!");
            onValidateWarning = false;
        }
    }
}
}