Skip to content

Commit 7c5d18f

Browse files
authored
Add vision model sample to flutter_app (#72)
* Add vision model sample to flutter_app * fix the analyze error * tweak after merge * address review comment
1 parent 632bc74 commit 7c5d18f

File tree

5 files changed

+105
-30
lines changed

5 files changed

+105
-30
lines changed
17.4 KB
Loading
213 KB
Loading

samples/flutter_app/lib/main.dart

+102-29
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
import 'package:flutter/material.dart';
16+
import 'package:flutter/services.dart';
1617
import 'package:flutter_markdown/flutter_markdown.dart';
1718
import 'package:google_generative_ai/google_generative_ai.dart';
1819

@@ -69,10 +70,13 @@ class ChatWidget extends StatefulWidget {
6970

7071
class _ChatWidgetState extends State<ChatWidget> {
7172
late final GenerativeModel _model;
73+
late final GenerativeModel _visionModel;
7274
late final ChatSession _chat;
7375
final ScrollController _scrollController = ScrollController();
7476
final TextEditingController _textController = TextEditingController();
7577
final FocusNode _textFieldFocus = FocusNode();
78+
final List<({Image? image, String? text, bool fromUser})> _generatedContent =
79+
<({Image? image, String? text, bool fromUser})>[];
7680
bool _loading = false;
7781
static const _apiKey = String.fromEnvironment('API_KEY');
7882

@@ -83,6 +87,10 @@ class _ChatWidgetState extends State<ChatWidget> {
8387
model: 'gemini-pro',
8488
apiKey: _apiKey,
8589
);
90+
_visionModel = GenerativeModel(
91+
model: 'gemini-pro-vision',
92+
apiKey: _apiKey,
93+
);
8694
_chat = _model.startChat();
8795
}
8896

@@ -132,17 +140,14 @@ class _ChatWidgetState extends State<ChatWidget> {
132140
? ListView.builder(
133141
controller: _scrollController,
134142
itemBuilder: (context, idx) {
135-
var content = _chat.history.toList()[idx];
136-
var text = content.parts
137-
.whereType<TextPart>()
138-
.map<String>((e) => e.text)
139-
.join('');
143+
var content = _generatedContent[idx];
140144
return MessageWidget(
141-
text: text,
142-
isFromUser: content.role == 'user',
145+
text: content.text,
146+
image: content.image,
147+
isFromUser: content.fromUser,
143148
);
144149
},
145-
itemCount: _chat.history.length,
150+
itemCount: _generatedContent.length,
146151
)
147152
: ListView(
148153
children: const [
@@ -171,6 +176,19 @@ class _ChatWidgetState extends State<ChatWidget> {
171176
const SizedBox.square(
172177
dimension: 15,
173178
),
179+
IconButton(
180+
onPressed: !_loading
181+
? () async {
182+
_sendImagePrompt(_textController.text);
183+
}
184+
: null,
185+
icon: Icon(
186+
Icons.image,
187+
color: _loading
188+
? Theme.of(context).colorScheme.secondary
189+
: Theme.of(context).colorScheme.primary,
190+
),
191+
),
174192
if (!_loading)
175193
IconButton(
176194
onPressed: () async {
@@ -191,16 +209,71 @@ class _ChatWidgetState extends State<ChatWidget> {
191209
);
192210
}
193211

212+
Future<void> _sendImagePrompt(String message) async {
213+
setState(() {
214+
_loading = true;
215+
});
216+
try {
217+
ByteData catBytes = await rootBundle.load('assets/images/cat.jpg');
218+
ByteData sconeBytes = await rootBundle.load('assets/images/scones.jpg');
219+
final content = [
220+
Content.multi([
221+
TextPart(message),
222+
// The only accepted mime types are image/*.
223+
DataPart('image/jpeg', catBytes.buffer.asUint8List()),
224+
DataPart('image/jpeg', sconeBytes.buffer.asUint8List()),
225+
])
226+
];
227+
_generatedContent.add((
228+
image: Image.asset("assets/images/cat.jpg"),
229+
text: message,
230+
fromUser: true
231+
));
232+
_generatedContent.add((
233+
image: Image.asset("assets/images/scones.jpg"),
234+
text: null,
235+
fromUser: true
236+
));
237+
238+
var response = await _visionModel.generateContent(content);
239+
var text = response.text;
240+
_generatedContent.add((image: null, text: text, fromUser: false));
241+
242+
if (text == null) {
243+
_showError('No response from API.');
244+
return;
245+
} else {
246+
setState(() {
247+
_loading = false;
248+
_scrollDown();
249+
});
250+
}
251+
} catch (e) {
252+
_showError(e.toString());
253+
setState(() {
254+
_loading = false;
255+
});
256+
} finally {
257+
_textController.clear();
258+
setState(() {
259+
_loading = false;
260+
});
261+
_textFieldFocus.requestFocus();
262+
}
263+
}
264+
194265
Future<void> _sendChatMessage(String message) async {
195266
setState(() {
196267
_loading = true;
197268
});
198269

199270
try {
271+
_generatedContent.add((image: null, text: message, fromUser: true));
200272
var response = await _chat.sendMessage(
201273
Content.text(message),
202274
);
203275
var text = response.text;
276+
_generatedContent.add((image: null, text: text, fromUser: false));
204277

205278
if (text == null) {
206279
_showError('No response from API.');
@@ -249,12 +322,14 @@ class _ChatWidgetState extends State<ChatWidget> {
249322
}
250323

251324
class MessageWidget extends StatelessWidget {
252-
final String text;
325+
final Image? image;
326+
final String? text;
253327
final bool isFromUser;
254328

255329
const MessageWidget({
256330
super.key,
257-
required this.text,
331+
this.image,
332+
this.text,
258333
required this.isFromUser,
259334
});
260335

@@ -265,25 +340,23 @@ class MessageWidget extends StatelessWidget {
265340
isFromUser ? MainAxisAlignment.end : MainAxisAlignment.start,
266341
children: [
267342
Flexible(
268-
child: Container(
269-
constraints: const BoxConstraints(maxWidth: 600),
270-
decoration: BoxDecoration(
271-
color: isFromUser
272-
? Theme.of(context).colorScheme.primaryContainer
273-
: Theme.of(context).colorScheme.surfaceVariant,
274-
borderRadius: BorderRadius.circular(18),
275-
),
276-
padding: const EdgeInsets.symmetric(
277-
vertical: 15,
278-
horizontal: 20,
279-
),
280-
margin: const EdgeInsets.only(bottom: 8),
281-
child: MarkdownBody(
282-
selectable: true,
283-
data: text,
284-
),
285-
),
286-
),
343+
child: Container(
344+
constraints: const BoxConstraints(maxWidth: 600),
345+
decoration: BoxDecoration(
346+
color: isFromUser
347+
? Theme.of(context).colorScheme.primaryContainer
348+
: Theme.of(context).colorScheme.surfaceVariant,
349+
borderRadius: BorderRadius.circular(18),
350+
),
351+
padding: const EdgeInsets.symmetric(
352+
vertical: 15,
353+
horizontal: 20,
354+
),
355+
margin: const EdgeInsets.only(bottom: 8),
356+
child: Column(children: [
357+
if (text case final text?) MarkdownBody(data: text),
358+
if (image case final image?) image,
359+
]))),
287360
],
288361
);
289362
}

samples/flutter_app/macos/Runner/DebugProfile.entitlements

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99
<key>com.apple.security.network.server</key>
1010
<true/>
1111
<key>com.apple.security.network.client</key>
12-
<true/>
12+
<true/>
1313
</dict>
1414
</plist>

samples/flutter_app/pubspec.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ dev_dependencies:
2222

2323
flutter:
2424
uses-material-design: true
25+
assets:
26+
- assets/images/
2527

2628
# Note: this section is only used in order to resolve google_generative_ai to
2729
# the same repo as the sample.

0 commit comments

Comments
 (0)