Skip to content

Commit 3a9e0bc

Browse files
cwstrykervirtuald
andcommitted
Ported the ProfiledPIDSubsystem from the wpilib java source to Python
Co-authored-by: Dustin Spicuzza <[email protected]>
1 parent bffeb3a commit 3a9e0bc

File tree

3 files changed

+199
-0
lines changed

3 files changed

+199
-0
lines changed

Diff for: commands2/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .pidcommand import PIDCommand
1717
from .pidsubsystem import PIDSubsystem
1818
from .printcommand import PrintCommand
19+
from .profiledpidsubsystem import ProfiledPIDSubsystem
1920
from .proxycommand import ProxyCommand
2021
from .repeatcommand import RepeatCommand
2122
from .runcommand import RunCommand
@@ -51,6 +52,7 @@
5152
"PIDCommand",
5253
"PIDSubsystem",
5354
"PrintCommand",
55+
"ProfiledPIDSubsystem",
5456
"ProxyCommand",
5557
"RepeatCommand",
5658
"RunCommand",

Diff for: commands2/profiledpidsubsystem.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) FIRST and other WPILib contributors.
2+
# Open Source Software; you can modify and/or share it under the terms of
3+
# the WPILib BSD license file in the root directory of this project.
4+
5+
from typing import Union, cast
6+
7+
from wpimath.trajectory import TrapezoidProfile
8+
9+
from .subsystem import Subsystem
10+
11+
12+
class ProfiledPIDSubsystem(Subsystem):
13+
"""
14+
A subsystem that uses a :class:`wpimath.controller.ProfiledPIDController`
15+
or :class:`wpimath.controller.ProfiledPIDControllerRadians` to
16+
control an output. The controller is run synchronously from the subsystem's
17+
:meth:`.periodic` method.
18+
"""
19+
20+
def __init__(
21+
self,
22+
controller,
23+
initial_position: float = 0,
24+
):
25+
"""Creates a new PIDSubsystem."""
26+
super().__init__()
27+
self._controller = controller
28+
self._enabled = False
29+
self.setGoal(initial_position)
30+
31+
def periodic(self):
32+
"""Updates the output of the controller."""
33+
if self._enabled:
34+
self.useOutput(
35+
self._controller.calculate(self.getMeasurement()),
36+
self._controller.getSetpoint(),
37+
)
38+
39+
def getController(
40+
self,
41+
):
42+
"""Returns the controller"""
43+
return self._controller
44+
45+
def setGoal(self, goal):
46+
"""
47+
Sets the goal state for the subsystem.
48+
"""
49+
self._controller.setGoal(goal)
50+
51+
def useOutput(self, output: float, setpoint: TrapezoidProfile.State):
52+
"""
53+
Uses the output from the controller object.
54+
"""
55+
raise NotImplementedError(f"{self.__class__} must implement useOutput")
56+
57+
def getMeasurement(self) -> float:
58+
"""
59+
Returns the measurement of the process variable used by the
60+
controller object.
61+
"""
62+
raise NotImplementedError(f"{self.__class__} must implement getMeasurement")
63+
64+
def enable(self):
65+
"""Enables the PID control. Resets the controller."""
66+
self._enabled = True
67+
self._controller.reset(self.getMeasurement())
68+
69+
def disable(self):
70+
"""Disables the PID control. Sets output to zero."""
71+
self._enabled = False
72+
self.useOutput(0, TrapezoidProfile.State())
73+
74+
def isEnabled(self) -> bool:
75+
"""
76+
Returns whether the controller is enabled.
77+
"""
78+
return self._enabled

Diff for: tests/test_profiledpidsubsystem.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from types import MethodType
2+
from typing import Any
3+
4+
import pytest
5+
from wpimath.controller import ProfiledPIDController, ProfiledPIDControllerRadians
6+
from wpimath.trajectory import TrapezoidProfile, TrapezoidProfileRadians
7+
8+
from commands2 import ProfiledPIDSubsystem
9+
10+
MAX_VELOCITY = 30 # Radians per second
11+
MAX_ACCELERATION = 500 # Radians per sec squared
12+
PID_KP = 50
13+
14+
15+
class EvalSubsystem(ProfiledPIDSubsystem):
16+
def __init__(self, controller, state_factory):
17+
self._state_factory = state_factory
18+
super().__init__(controller, 0)
19+
20+
21+
def simple_use_output(self, output: float, setpoint: Any):
22+
"""A simple useOutput method that saves the current state of the controller."""
23+
self._output = output
24+
self._setpoint = setpoint
25+
26+
27+
def simple_get_measurement(self) -> float:
28+
"""A simple getMeasurement method that returns zero (frozen or stuck plant)."""
29+
return 0.0
30+
31+
32+
controller_types = [
33+
(
34+
ProfiledPIDControllerRadians,
35+
TrapezoidProfileRadians.Constraints,
36+
TrapezoidProfileRadians.State,
37+
),
38+
(ProfiledPIDController, TrapezoidProfile.Constraints, TrapezoidProfile.State),
39+
]
40+
controller_ids = ["radians", "dimensionless"]
41+
42+
43+
@pytest.fixture(params=controller_types, ids=controller_ids)
44+
def subsystem(request):
45+
"""
46+
Fixture that returns an EvalSubsystem object for each type of controller.
47+
"""
48+
controller, profile_factory, state_factory = request.param
49+
profile = profile_factory(MAX_VELOCITY, MAX_ACCELERATION)
50+
pid = controller(PID_KP, 0, 0, profile)
51+
return EvalSubsystem(pid, state_factory)
52+
53+
54+
def test_profiled_pid_subsystem_init(subsystem):
55+
"""
56+
Verify that the ProfiledPIDSubsystem can be initialized using
57+
all supported profiled PID controller / trapezoid profile types.
58+
"""
59+
assert isinstance(subsystem, EvalSubsystem)
60+
61+
62+
def test_profiled_pid_subsystem_not_implemented_get_measurement(subsystem):
63+
"""
64+
Verify that the ProfiledPIDSubsystem.getMeasurement method
65+
raises NotImplementedError.
66+
"""
67+
with pytest.raises(NotImplementedError):
68+
subsystem.getMeasurement()
69+
70+
71+
def test_profiled_pid_subsystem_not_implemented_use_output(subsystem):
72+
"""
73+
Verify that the ProfiledPIDSubsystem.useOutput method raises
74+
NotImplementedError.
75+
"""
76+
with pytest.raises(NotImplementedError):
77+
subsystem.useOutput(0, subsystem._state_factory())
78+
79+
80+
@pytest.mark.parametrize("use_float", [True, False])
81+
def test_profiled_pid_subsystem_set_goal(subsystem, use_float):
82+
"""
83+
Verify that the ProfiledPIDSubsystem.setGoal method sets the goal.
84+
"""
85+
if use_float:
86+
subsystem.setGoal(1.0)
87+
assert subsystem.getController().getGoal().position == 1.0
88+
assert subsystem.getController().getGoal().velocity == 0.0
89+
else:
90+
subsystem.setGoal(subsystem._state_factory(1.0, 2.0))
91+
assert subsystem.getController().getGoal().position == 1.0
92+
assert subsystem.getController().getGoal().velocity == 2.0
93+
94+
95+
def test_profiled_pid_subsystem_enable_subsystem(subsystem):
96+
"""
97+
Verify the subsystem can be enabled.
98+
"""
99+
# Dynamically add useOutput and getMeasurement methods so the
100+
# system can be enabled
101+
setattr(subsystem, "useOutput", MethodType(simple_use_output, subsystem))
102+
setattr(subsystem, "getMeasurement", MethodType(simple_get_measurement, subsystem))
103+
# Enable the subsystem
104+
subsystem.enable()
105+
assert subsystem.isEnabled()
106+
107+
108+
def test_profiled_pid_subsystem_disable_subsystem(subsystem):
109+
"""
110+
Verify the subsystem can be disabled.
111+
"""
112+
# Dynamically add useOutput and getMeasurement methods so the
113+
# system can be enabled
114+
setattr(subsystem, "useOutput", MethodType(simple_use_output, subsystem))
115+
setattr(subsystem, "getMeasurement", MethodType(simple_get_measurement, subsystem))
116+
# Enable and then disable the subsystem
117+
subsystem.enable()
118+
subsystem.disable()
119+
assert not subsystem.isEnabled()

0 commit comments

Comments
 (0)