Skip to content

Custom trainer editor analytics #5511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Sep 24, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ml-agents-envs/mlagents_envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

from mlagents_envs.logging_util import get_logger
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.default_training_analytics_side_channel import (
DefaultTrainingAnalyticsSideChannel,
)
from mlagents_envs.side_channel.side_channel_manager import SideChannelManager
from mlagents_envs import env_utils

Expand Down Expand Up @@ -186,6 +189,16 @@ def __init__(
self._timeout_wait: int = timeout_wait
self._communicator = self._get_communicator(worker_id, base_port, timeout_wait)
self._worker_id = worker_id
if side_channels is None:
side_channels = []
default_training_side_channel: Optional[
DefaultTrainingAnalyticsSideChannel
] = None
if "TrainingAnalyticsSideChannel" not in [
x.__class__.__name__ for x in side_channels
]:
default_training_side_channel = DefaultTrainingAnalyticsSideChannel()
side_channels.append(default_training_side_channel)
self._side_channel_manager = SideChannelManager(side_channels)
self._log_folder = log_folder
self.academy_capabilities: UnityRLCapabilitiesProto = None # type: ignore
Expand Down Expand Up @@ -246,6 +259,8 @@ def __init__(
self._is_first_message = True
self._update_behavior_specs(aca_output)
self.academy_capabilities = aca_params.capabilities
if default_training_side_channel is not None:
default_training_side_channel.environment_initialized()

@staticmethod
def _get_communicator(worker_id, base_port, timeout_wait):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import sys
import uuid
import mlagents_envs

from mlagents_envs.exception import UnityCommunicationException
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from mlagents_envs.communicator_objects.training_analytics_pb2 import (
TrainingEnvironmentInitialized,
)
from google.protobuf.any_pb2 import Any


class DefaultTrainingAnalyticsSideChannel(SideChannel):
"""
Side channel that sends information about the training to the Unity environment so it can be logged.
"""

def __init__(self) -> None:
# >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/TrainingAnalyticsSideChannel")
# UUID('b664a4a9-d86f-5a5f-95cb-e8353a7e8356')
# We purposefully use the SAME side channel as the TrainingAnalyticsSideChannel
super().__init__(uuid.UUID("b664a4a9-d86f-5a5f-95cb-e8353a7e8356"))

def on_message_received(self, msg: IncomingMessage) -> None:
raise UnityCommunicationException(
"The DefaultTrainingAnalyticsSideChannel received a message from Unity, "
+ "this should not have happened."
)

def environment_initialized(self) -> None:
# Tuple of (major, minor, patch)
vi = sys.version_info

msg = TrainingEnvironmentInitialized(
python_version=f"{vi[0]}.{vi[1]}.{vi[2]}",
mlagents_version="Custom",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main piece. "Custom" will be sent instead of the mlagents version.

mlagents_envs_version=mlagents_envs.__version__,
torch_version="Unknown",
torch_device_type="Unknown",
num_envs=0,
num_environment_parameters=0,
)

any_message = Any()
any_message.Pack(msg)

env_init_msg = OutgoingMessage()
env_init_msg.set_raw_bytes(any_message.SerializeToString()) # type: ignore
super().queue_message_to_send(env_init_msg)