diff --git a/samples/flutter_app/assets/images/cat.jpg b/samples/flutter_app/assets/images/cat.jpg new file mode 100644 index 0000000..8d2069e Binary files /dev/null and b/samples/flutter_app/assets/images/cat.jpg differ diff --git a/samples/flutter_app/assets/images/scones.jpg b/samples/flutter_app/assets/images/scones.jpg new file mode 100644 index 0000000..ce68958 Binary files /dev/null and b/samples/flutter_app/assets/images/scones.jpg differ diff --git a/samples/flutter_app/lib/main.dart b/samples/flutter_app/lib/main.dart index 469dc70..ec7043a 100644 --- a/samples/flutter_app/lib/main.dart +++ b/samples/flutter_app/lib/main.dart @@ -13,6 +13,7 @@ // limitations under the License. import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; import 'package:flutter_markdown/flutter_markdown.dart'; import 'package:google_generative_ai/google_generative_ai.dart'; @@ -69,10 +70,13 @@ class ChatWidget extends StatefulWidget { class _ChatWidgetState extends State { late final GenerativeModel _model; + late final GenerativeModel _visionModel; late final ChatSession _chat; final ScrollController _scrollController = ScrollController(); final TextEditingController _textController = TextEditingController(); final FocusNode _textFieldFocus = FocusNode(); + final List<({Image? image, String? text, bool fromUser})> _generatedContent = + <({Image? image, String? text, bool fromUser})>[]; bool _loading = false; static const _apiKey = String.fromEnvironment('API_KEY'); @@ -83,6 +87,10 @@ class _ChatWidgetState extends State { model: 'gemini-pro', apiKey: _apiKey, ); + _visionModel = GenerativeModel( + model: 'gemini-pro-vision', + apiKey: _apiKey, + ); _chat = _model.startChat(); } @@ -132,17 +140,14 @@ class _ChatWidgetState extends State { ? ListView.builder( controller: _scrollController, itemBuilder: (context, idx) { - var content = _chat.history.toList()[idx]; - var text = content.parts - .whereType() - .map((e) => e.text) - .join(''); + var content = _generatedContent[idx]; return MessageWidget( - text: text, - isFromUser: content.role == 'user', + text: content.text, + image: content.image, + isFromUser: content.fromUser, ); }, - itemCount: _chat.history.length, + itemCount: _generatedContent.length, ) : ListView( children: const [ @@ -171,6 +176,19 @@ class _ChatWidgetState extends State { const SizedBox.square( dimension: 15, ), + IconButton( + onPressed: !_loading + ? () async { + _sendImagePrompt(_textController.text); + } + : null, + icon: Icon( + Icons.image, + color: _loading + ? Theme.of(context).colorScheme.secondary + : Theme.of(context).colorScheme.primary, + ), + ), if (!_loading) IconButton( onPressed: () async { @@ -191,16 +209,71 @@ class _ChatWidgetState extends State { ); } + Future _sendImagePrompt(String message) async { + setState(() { + _loading = true; + }); + try { + ByteData catBytes = await rootBundle.load('assets/images/cat.jpg'); + ByteData sconeBytes = await rootBundle.load('assets/images/scones.jpg'); + final content = [ + Content.multi([ + TextPart(message), + // The only accepted mime types are image/*. + DataPart('image/jpeg', catBytes.buffer.asUint8List()), + DataPart('image/jpeg', sconeBytes.buffer.asUint8List()), + ]) + ]; + _generatedContent.add(( + image: Image.asset("assets/images/cat.jpg"), + text: message, + fromUser: true + )); + _generatedContent.add(( + image: Image.asset("assets/images/scones.jpg"), + text: null, + fromUser: true + )); + + var response = await _visionModel.generateContent(content); + var text = response.text; + _generatedContent.add((image: null, text: text, fromUser: false)); + + if (text == null) { + _showError('No response from API.'); + return; + } else { + setState(() { + _loading = false; + _scrollDown(); + }); + } + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + Future _sendChatMessage(String message) async { setState(() { _loading = true; }); try { + _generatedContent.add((image: null, text: message, fromUser: true)); var response = await _chat.sendMessage( Content.text(message), ); var text = response.text; + _generatedContent.add((image: null, text: text, fromUser: false)); if (text == null) { _showError('No response from API.'); @@ -249,12 +322,14 @@ class _ChatWidgetState extends State { } class MessageWidget extends StatelessWidget { - final String text; + final Image? image; + final String? text; final bool isFromUser; const MessageWidget({ super.key, - required this.text, + this.image, + this.text, required this.isFromUser, }); @@ -265,25 +340,23 @@ class MessageWidget extends StatelessWidget { isFromUser ? MainAxisAlignment.end : MainAxisAlignment.start, children: [ Flexible( - child: Container( - constraints: const BoxConstraints(maxWidth: 600), - decoration: BoxDecoration( - color: isFromUser - ? Theme.of(context).colorScheme.primaryContainer - : Theme.of(context).colorScheme.surfaceVariant, - borderRadius: BorderRadius.circular(18), - ), - padding: const EdgeInsets.symmetric( - vertical: 15, - horizontal: 20, - ), - margin: const EdgeInsets.only(bottom: 8), - child: MarkdownBody( - selectable: true, - data: text, - ), - ), - ), + child: Container( + constraints: const BoxConstraints(maxWidth: 600), + decoration: BoxDecoration( + color: isFromUser + ? Theme.of(context).colorScheme.primaryContainer + : Theme.of(context).colorScheme.surfaceVariant, + borderRadius: BorderRadius.circular(18), + ), + padding: const EdgeInsets.symmetric( + vertical: 15, + horizontal: 20, + ), + margin: const EdgeInsets.only(bottom: 8), + child: Column(children: [ + if (text case final text?) MarkdownBody(data: text), + if (image case final image?) image, + ]))), ], ); } diff --git a/samples/flutter_app/macos/Runner/DebugProfile.entitlements b/samples/flutter_app/macos/Runner/DebugProfile.entitlements index 8cb7022..c946719 100644 --- a/samples/flutter_app/macos/Runner/DebugProfile.entitlements +++ b/samples/flutter_app/macos/Runner/DebugProfile.entitlements @@ -9,6 +9,6 @@ com.apple.security.network.server com.apple.security.network.client - + diff --git a/samples/flutter_app/pubspec.yaml b/samples/flutter_app/pubspec.yaml index 1cf6b31..f473e97 100644 --- a/samples/flutter_app/pubspec.yaml +++ b/samples/flutter_app/pubspec.yaml @@ -22,6 +22,8 @@ dev_dependencies: flutter: uses-material-design: true + assets: + - assets/images/ # Note: this section is only used in order to resolve google_generative_ai to # the same repo as the sample.