-
-
Notifications
You must be signed in to change notification settings - Fork 111
/
Copy pathRAG_Sample.cs
110 lines (99 loc) · 3.18 KB
/
RAG_Sample.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
using System.Collections.Generic;
using UnityEngine;
using System.IO;
using System.Diagnostics;
using Debug = UnityEngine.Debug;
using UnityEngine.UI;
using LLMUnity;
using System.Threading.Tasks;
using System.Text.RegularExpressions;
namespace LLMUnitySamples
{
    /// <summary>
    /// Retrieval-only RAG sample: embeds the phrases stored under the "HAMLET" key of
    /// <see cref="HamletText"/> and, for each submitted query, shows the single most
    /// similar phrase returned by <see cref="RAG.Search"/>.
    /// </summary>
    public class RAGSample : MonoBehaviour
    {
        /// <summary>RAG component used to load, build, save and search the embeddings.</summary>
        public RAG rag;
        /// <summary>Input field where the player types a query.</summary>
        public InputField playerText;
        /// <summary>Text element that shows the retrieved phrase.</summary>
        public Text AIText;
        /// <summary>Source text parsed with <c>RAGUtils.ReadGutenbergFile</c>.</summary>
        public TextAsset HamletText;

        List<string> phrases;
        string ragPath = "RAGSample.zip";

        /// <summary>
        /// Unity entry point. Validates the LLM setup, loads the phrases, loads or builds the
        /// embeddings, then registers the submit handler and enables the input field.
        /// </summary>
        /// <remarks>
        /// Declared <c>async void</c> because Unity invokes it as a message and nothing awaits it.
        /// If <see cref="CreateEmbeddings"/> throws, the input field stays disabled.
        /// </remarks>
        async void Start()
        {
            CheckLLMs(false);
            playerText.interactable = false;
            LoadPhrases();
            await CreateEmbeddings();
            // Only accept queries once the embeddings are available.
            playerText.onSubmit.AddListener(onInputFieldSubmit);
            AIReplyComplete();
        }

        /// <summary>Parses <see cref="HamletText"/> and keeps the phrases under the "HAMLET" key.</summary>
        public void LoadPhrases()
        {
            phrases = RAGUtils.ReadGutenbergFile(HamletText.text)["HAMLET"];
        }

        /// <summary>
        /// Loads the embeddings saved at <c>ragPath</c>. If they cannot be loaded, then in the
        /// Unity Editor they are built from <c>phrases</c> and saved to <c>ragPath</c>.
        /// Outside the Editor this is an error.
        /// </summary>
        /// <exception cref="System.InvalidOperationException">
        /// Outside the Unity Editor, when the embeddings cannot be loaded.
        /// </exception>
        public async Task CreateEmbeddings()
        {
            bool loaded = await rag.Load(ragPath);
            if (loaded) return;
#if UNITY_EDITOR
            // Progress message goes into the (still disabled) input field; Start clears it
            // afterwards through AIReplyComplete.
            playerText.text += "Creating Embeddings (only once)...\n";
            Stopwatch stopwatch = Stopwatch.StartNew();
            foreach (string phrase in phrases) await rag.Add(phrase);
            stopwatch.Stop();
            Debug.Log($"embedded {rag.Count()} phrases in {stopwatch.Elapsed.TotalSeconds} secs");
            // Persist so that later runs take the rag.Load path instead of re-embedding.
            rag.Save(ragPath);
#else
            // Embeddings are only built in the Editor; a player without them cannot work.
            throw new System.InvalidOperationException("The embeddings could not be found!");
#endif
        }

        /// <summary>
        /// Submit handler: shows the most similar stored phrase for <paramref name="message"/>,
        /// then re-enables and clears the input field so the player can ask again.
        /// </summary>
        /// <param name="message">Query typed by the player; blank queries are ignored.</param>
        protected async virtual void onInputFieldSubmit(string message)
        {
            // Nothing meaningful to search for; avoid a pointless embedding request.
            if (string.IsNullOrWhiteSpace(message)) return;

            playerText.interactable = false;
            AIText.text = "...";
            try
            {
                var (similarPhrases, _) = await rag.Search(message, 1);
                AIText.text = similarPhrases != null && similarPhrases.Length > 0
                    ? similarPhrases[0]
                    : "No similar phrase found.";
            }
            finally
            {
                // Always give control back to the player, even if the search failed.
                AIReplyComplete();
            }
        }

        /// <summary>Sets the text shown in <see cref="AIText"/>.</summary>
        /// <param name="text">Text to display.</param>
        public void SetAIText(string text)
        {
            AIText.text = text;
        }

        /// <summary>Re-enables, focuses and clears the input field so a new query can be typed.</summary>
        public void AIReplyComplete()
        {
            playerText.interactable = true;
            playerText.Select();
            playerText.text = "";
        }

        /// <summary>Logs the request and calls <c>Application.Quit</c>.</summary>
        public void ExitGame()
        {
            Debug.Log("Exit button clicked");
            Application.Quit();
        }

        /// <summary>
        /// Checks that a local (non-remote) LLM referenced by <paramref name="llmCaller"/> has a
        /// model selected.
        /// </summary>
        /// <param name="llmCaller">Caller to check; nothing is checked when it is null.</param>
        /// <param name="debug">When true, log a warning instead of throwing.</param>
        /// <exception cref="System.InvalidOperationException">
        /// No model is selected and <paramref name="debug"/> is false.
        /// </exception>
        protected void CheckLLM(LLMCaller llmCaller, bool debug)
        {
            // Use == (not `is null`) so Unity's overloaded null check applies to
            // unassigned/destroyed UnityEngine.Objects.
            if (llmCaller == null || llmCaller.remote || llmCaller.llm == null) return;
            if (!string.IsNullOrEmpty(llmCaller.llm.model)) return;

            string error = $"Please select a llm model in the {llmCaller.llm.gameObject.name} GameObject!";
            if (debug) Debug.LogWarning(error);
            else throw new System.InvalidOperationException(error);
        }

        /// <summary>Checks every LLM this sample depends on (here: the RAG embedder).</summary>
        /// <param name="debug">When true, log warnings instead of throwing.</param>
        protected virtual void CheckLLMs(bool debug)
        {
            // OnValidate can run before the RAG reference is assigned in the Inspector.
            if (rag == null || rag.search == null) return;
            CheckLLM(rag.search.llmEmbedder, debug);
        }

        // Log the missing-model warning only on the first validation of this instance.
        bool onValidateWarning = true;

        /// <summary>Editor validation hook: warns (without throwing) about a missing LLM model.</summary>
        void OnValidate()
        {
            if (onValidateWarning)
            {
                CheckLLMs(true);
                onValidateWarning = false;
            }
        }
    }
}