13
13
// limitations under the License.
14
14
15
15
import 'package:flutter/material.dart' ;
16
+ import 'package:flutter/services.dart' ;
16
17
import 'package:flutter_markdown/flutter_markdown.dart' ;
17
18
import 'package:google_generative_ai/google_generative_ai.dart' ;
18
19
@@ -69,10 +70,13 @@ class ChatWidget extends StatefulWidget {
69
70
70
71
class _ChatWidgetState extends State <ChatWidget > {
71
72
late final GenerativeModel _model;
73
+ late final GenerativeModel _visionModel;
72
74
late final ChatSession _chat;
73
75
final ScrollController _scrollController = ScrollController ();
74
76
final TextEditingController _textController = TextEditingController ();
75
77
final FocusNode _textFieldFocus = FocusNode ();
78
+ final List <({Image ? image, String ? text, bool fromUser})> _generatedContent =
79
+ < ({Image ? image, String ? text, bool fromUser})> [];
76
80
bool _loading = false ;
77
81
static const _apiKey = String .fromEnvironment ('API_KEY' );
78
82
@@ -83,6 +87,10 @@ class _ChatWidgetState extends State<ChatWidget> {
83
87
model: 'gemini-pro' ,
84
88
apiKey: _apiKey,
85
89
);
90
+ _visionModel = GenerativeModel (
91
+ model: 'gemini-pro-vision' ,
92
+ apiKey: _apiKey,
93
+ );
86
94
_chat = _model.startChat ();
87
95
}
88
96
@@ -132,17 +140,14 @@ class _ChatWidgetState extends State<ChatWidget> {
132
140
? ListView .builder (
133
141
controller: _scrollController,
134
142
itemBuilder: (context, idx) {
135
- var content = _chat.history.toList ()[idx];
136
- var text = content.parts
137
- .whereType <TextPart >()
138
- .map <String >((e) => e.text)
139
- .join ('' );
143
+ var content = _generatedContent[idx];
140
144
return MessageWidget (
141
- text: text,
142
- isFromUser: content.role == 'user' ,
145
+ text: content.text,
146
+ image: content.image,
147
+ isFromUser: content.fromUser,
143
148
);
144
149
},
145
- itemCount: _chat.history .length,
150
+ itemCount: _generatedContent .length,
146
151
)
147
152
: ListView (
148
153
children: const [
@@ -171,6 +176,19 @@ class _ChatWidgetState extends State<ChatWidget> {
171
176
const SizedBox .square (
172
177
dimension: 15 ,
173
178
),
179
+ IconButton (
180
+ onPressed: ! _loading
181
+ ? () async {
182
+ _sendImagePrompt (_textController.text);
183
+ }
184
+ : null ,
185
+ icon: Icon (
186
+ Icons .image,
187
+ color: _loading
188
+ ? Theme .of (context).colorScheme.secondary
189
+ : Theme .of (context).colorScheme.primary,
190
+ ),
191
+ ),
174
192
if (! _loading)
175
193
IconButton (
176
194
onPressed: () async {
@@ -191,16 +209,71 @@ class _ChatWidgetState extends State<ChatWidget> {
191
209
);
192
210
}
193
211
212
+ Future <void > _sendImagePrompt (String message) async {
213
+ setState (() {
214
+ _loading = true ;
215
+ });
216
+ try {
217
+ ByteData catBytes = await rootBundle.load ('assets/images/cat.jpg' );
218
+ ByteData sconeBytes = await rootBundle.load ('assets/images/scones.jpg' );
219
+ final content = [
220
+ Content .multi ([
221
+ TextPart (message),
222
+ // The only accepted mime types are image/*.
223
+ DataPart ('image/jpeg' , catBytes.buffer.asUint8List ()),
224
+ DataPart ('image/jpeg' , sconeBytes.buffer.asUint8List ()),
225
+ ])
226
+ ];
227
+ _generatedContent.add ((
228
+ image: Image .asset ("assets/images/cat.jpg" ),
229
+ text: message,
230
+ fromUser: true
231
+ ));
232
+ _generatedContent.add ((
233
+ image: Image .asset ("assets/images/scones.jpg" ),
234
+ text: null ,
235
+ fromUser: true
236
+ ));
237
+
238
+ var response = await _visionModel.generateContent (content);
239
+ var text = response.text;
240
+ _generatedContent.add ((image: null , text: text, fromUser: false ));
241
+
242
+ if (text == null ) {
243
+ _showError ('No response from API.' );
244
+ return ;
245
+ } else {
246
+ setState (() {
247
+ _loading = false ;
248
+ _scrollDown ();
249
+ });
250
+ }
251
+ } catch (e) {
252
+ _showError (e.toString ());
253
+ setState (() {
254
+ _loading = false ;
255
+ });
256
+ } finally {
257
+ _textController.clear ();
258
+ setState (() {
259
+ _loading = false ;
260
+ });
261
+ _textFieldFocus.requestFocus ();
262
+ }
263
+ }
264
+
194
265
Future <void > _sendChatMessage (String message) async {
195
266
setState (() {
196
267
_loading = true ;
197
268
});
198
269
199
270
try {
271
+ _generatedContent.add ((image: null , text: message, fromUser: true ));
200
272
var response = await _chat.sendMessage (
201
273
Content .text (message),
202
274
);
203
275
var text = response.text;
276
+ _generatedContent.add ((image: null , text: text, fromUser: false ));
204
277
205
278
if (text == null ) {
206
279
_showError ('No response from API.' );
@@ -249,12 +322,14 @@ class _ChatWidgetState extends State<ChatWidget> {
249
322
}
250
323
251
324
class MessageWidget extends StatelessWidget {
252
- final String text;
325
+ final Image ? image;
326
+ final String ? text;
253
327
final bool isFromUser;
254
328
255
329
const MessageWidget ({
256
330
super .key,
257
- required this .text,
331
+ this .image,
332
+ this .text,
258
333
required this .isFromUser,
259
334
});
260
335
@@ -265,25 +340,23 @@ class MessageWidget extends StatelessWidget {
265
340
isFromUser ? MainAxisAlignment .end : MainAxisAlignment .start,
266
341
children: [
267
342
Flexible (
268
- child: Container (
269
- constraints: const BoxConstraints (maxWidth: 600 ),
270
- decoration: BoxDecoration (
271
- color: isFromUser
272
- ? Theme .of (context).colorScheme.primaryContainer
273
- : Theme .of (context).colorScheme.surfaceVariant,
274
- borderRadius: BorderRadius .circular (18 ),
275
- ),
276
- padding: const EdgeInsets .symmetric (
277
- vertical: 15 ,
278
- horizontal: 20 ,
279
- ),
280
- margin: const EdgeInsets .only (bottom: 8 ),
281
- child: MarkdownBody (
282
- selectable: true ,
283
- data: text,
284
- ),
285
- ),
286
- ),
343
+ child: Container (
344
+ constraints: const BoxConstraints (maxWidth: 600 ),
345
+ decoration: BoxDecoration (
346
+ color: isFromUser
347
+ ? Theme .of (context).colorScheme.primaryContainer
348
+ : Theme .of (context).colorScheme.surfaceVariant,
349
+ borderRadius: BorderRadius .circular (18 ),
350
+ ),
351
+ padding: const EdgeInsets .symmetric (
352
+ vertical: 15 ,
353
+ horizontal: 20 ,
354
+ ),
355
+ margin: const EdgeInsets .only (bottom: 8 ),
356
+ child: Column (children: [
357
+ if (text case final text? ) MarkdownBody (data: text),
358
+ if (image case final image? ) image,
359
+ ]))),
287
360
],
288
361
);
289
362
}
0 commit comments