diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py
index 3ec382d0f3d..c825233aca2 100644
--- a/packages/python/plotly/plotly/express/__init__.py
+++ b/packages/python/plotly/plotly/express/__init__.py
@@ -39,6 +39,11 @@
     choropleth,
     density_contour,
     density_heatmap,
+    pie,
+    sunburst,
+    treemap,
+    funnel,
+    funnel_area,
 )
 
 from ._imshow import imshow
@@ -77,6 +82,11 @@
     "strip",
     "histogram",
     "choropleth",
+    "pie",
+    "sunburst",
+    "treemap",
+    "funnel",
+    "funnel_area",
     "imshow",
     "data",
     "colors",
diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py
index cbf2fff85cd..8cbd4d85b65 100644
--- a/packages/python/plotly/plotly/express/_chart_types.py
+++ b/packages/python/plotly/plotly/express/_chart_types.py
@@ -1115,3 +1115,208 @@ def parallel_categories(
 
 
 parallel_categories.__doc__ = make_docstring(parallel_categories)
+
+
+def pie(
+    data_frame=None,
+    names=None,
+    values=None,
+    color=None,
+    color_discrete_sequence=None,
+    color_discrete_map={},
+    hover_name=None,
+    hover_data=None,
+    custom_data=None,
+    labels={},
+    title=None,
+    template=None,
+    width=None,
+    height=None,
+    opacity=None,
+    hole=None,
+):
+    """
+    In a pie plot, each row of `data_frame` is represented as a sector of a pie.
+    """
+    if color_discrete_sequence is not None:
+        layout_patch = {"piecolorway": color_discrete_sequence}
+    else:
+        layout_patch = {}
+    return make_figure(
+        args=locals(),
+        constructor=go.Pie,
+        trace_patch=dict(showlegend=(names is not None), hole=hole),
+        layout_patch=layout_patch,
+    )
+
+
+pie.__doc__ = make_docstring(
+    pie,
+    override_dict=dict(
+        hole=[
+            "float",
+            "Sets the fraction of the radius to cut out of the pie."
+            "Use this to make a donut chart.",
+        ],
+    ),
+)
+
+
+def sunburst(
+    data_frame=None,
+    names=None,
+    values=None,
+    parents=None,
+    ids=None,
+    color=None,
+    color_continuous_scale=None,
+    range_color=None,
+    color_continuous_midpoint=None,
+    color_discrete_sequence=None,
+    color_discrete_map={},
+    hover_name=None,
+    hover_data=None,
+    custom_data=None,
+    labels={},
+    title=None,
+    template=None,
+    width=None,
+    height=None,
+    branchvalues=None,
+    maxdepth=None,
+):
+    """
+    A sunburst plot represents hierarchial data as sectors laid out over
+    several levels of concentric rings.
+    """
+    if color_discrete_sequence is not None:
+        layout_patch = {"sunburstcolorway": color_discrete_sequence}
+    else:
+        layout_patch = {}
+    return make_figure(
+        args=locals(),
+        constructor=go.Sunburst,
+        trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth),
+        layout_patch=layout_patch,
+    )
+
+
+sunburst.__doc__ = make_docstring(sunburst)
+
+
+def treemap(
+    data_frame=None,
+    names=None,
+    values=None,
+    parents=None,
+    ids=None,
+    color=None,
+    color_continuous_scale=None,
+    range_color=None,
+    color_continuous_midpoint=None,
+    color_discrete_sequence=None,
+    color_discrete_map={},
+    hover_name=None,
+    hover_data=None,
+    custom_data=None,
+    labels={},
+    title=None,
+    template=None,
+    width=None,
+    height=None,
+    branchvalues=None,
+    maxdepth=None,
+):
+    """
+    A treemap plot represents hierarchial data as nested rectangular sectors.
+    """
+    if color_discrete_sequence is not None:
+        layout_patch = {"treemapcolorway": color_discrete_sequence}
+    else:
+        layout_patch = {}
+    return make_figure(
+        args=locals(),
+        constructor=go.Treemap,
+        trace_patch=dict(branchvalues=branchvalues, maxdepth=maxdepth),
+        layout_patch=layout_patch,
+    )
+
+
+treemap.__doc__ = make_docstring(treemap)
+
+
+def funnel(
+    data_frame=None,
+    x=None,
+    y=None,
+    color=None,
+    facet_row=None,
+    facet_col=None,
+    facet_col_wrap=0,
+    hover_name=None,
+    hover_data=None,
+    custom_data=None,
+    text=None,
+    animation_frame=None,
+    animation_group=None,
+    category_orders={},
+    labels={},
+    color_discrete_sequence=None,
+    color_discrete_map={},
+    opacity=None,
+    orientation="h",
+    log_x=False,
+    log_y=False,
+    range_x=None,
+    range_y=None,
+    title=None,
+    template=None,
+    width=None,
+    height=None,
+):
+    """
+    In a funnel plot, each row of `data_frame` is represented as a rectangular sector of a funnel.
+    """
+    return make_figure(
+        args=locals(),
+        constructor=go.Funnel,
+        trace_patch=dict(opacity=opacity, orientation=orientation),
+    )
+
+
+funnel.__doc__ = make_docstring(funnel)
+
+
+def funnel_area(
+    data_frame=None,
+    names=None,
+    values=None,
+    color=None,
+    color_discrete_sequence=None,
+    color_discrete_map={},
+    hover_name=None,
+    hover_data=None,
+    custom_data=None,
+    labels={},
+    title=None,
+    template=None,
+    width=None,
+    height=None,
+    opacity=None,
+):
+    """
+    In a funnel area plot, each row of `data_frame` is represented as a trapezoidal sector of a funnel.
+    """
+    if color_discrete_sequence is not None:
+        layout_patch = {"funnelareacolorway": color_discrete_sequence}
+    else:
+        layout_patch = {}
+    return make_figure(
+        args=locals(),
+        constructor=go.Funnelarea,
+        trace_patch=dict(showlegend=(names is not None)),
+        layout_patch=layout_patch,
+    )
+
+
+funnel_area.__doc__ = make_docstring(funnel_area)
diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index 503edaf7b9c..923e6ea3dfc 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -291,6 +291,28 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
                     result["z"] = g[v]
                     result["coloraxis"] = "coloraxis1"
                     mapping_labels[v_label] = "%{z}"
+                elif trace_spec.constructor in [
+                    go.Sunburst,
+                    go.Treemap,
+                    go.Pie,
+                    go.Funnelarea,
+                ]:
+                    if "marker" not in result:
+                        result["marker"] = dict()
+
+                    if args.get("color_is_continuous"):
+                        result["marker"]["colors"] = g[v]
+                        result["marker"]["coloraxis"] = "coloraxis1"
+                        mapping_labels[v_label] = "%{color}"
+                    else:
+                        result["marker"]["colors"] = []
+                        mapping = {}
+                        for cat in g[v]:
+                            if mapping.get(cat) is None:
+                                mapping[cat] = args["color_discrete_sequence"][
+                                    len(mapping) % len(args["color_discrete_sequence"])
+                                ]
+                            result["marker"]["colors"].append(mapping[cat])
                 else:
                     colorable = "marker"
                     if trace_spec.constructor in [go.Parcats, go.Parcoords]:
@@ -305,11 +327,38 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
             elif k == "locations":
                 result[k] = g[v]
                 mapping_labels[v_label] = "%{location}"
+            elif k == "values":
+                result[k] = g[v]
+                _label = "value" if v_label == "values" else v_label
+                mapping_labels[_label] = "%{value}"
+            elif k == "parents":
+                result[k] = g[v]
+                _label = "parent" if v_label == "parents" else v_label
+                mapping_labels[_label] = "%{parent}"
+            elif k == "ids":
+                result[k] = g[v]
+                _label = "id" if v_label == "ids" else v_label
+                mapping_labels[_label] = "%{id}"
+            elif k == "names":
+                if trace_spec.constructor in [
+                    go.Sunburst,
+                    go.Treemap,
+                    go.Pie,
+                    go.Funnelarea,
+                ]:
+                    result["labels"] = g[v]
+                    _label = "label" if v_label == "names" else v_label
+                    mapping_labels[_label] = "%{label}"
+                else:
+                    result[k] = g[v]
             else:
                 if v:
                     result[k] = g[v]
                 mapping_labels[v_label] = "%%{%s}" % k
-    if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
+    if trace_spec.constructor not in [
+        go.Parcoords,
+        go.Parcats,
+    ]:
         hover_lines = [k + "=" + v for k, v in mapping_labels.items()]
         result["hovertemplate"] = hover_header + "<br>".join(hover_lines)
     return result, fit_results
@@ -674,6 +723,7 @@ def one_group(x):
 
 def apply_default_cascade(args):
     # first we apply px.defaults to unspecified args
+
     for param in (
         ["color_discrete_sequence", "color_continuous_scale"]
         + ["symbol_sequence", "line_dash_sequence", "template"]
@@ -956,6 +1006,7 @@ def infer_config(args, constructor, trace_patch):
     attrables = (
         ["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"]
         + ["custom_data", "hover_name", "hover_data", "text"]
+        + ["names", "values", "parents", "ids"]
         + ["error_x", "error_x_minus"]
         + ["error_y", "error_y_minus", "error_z", "error_z_minus"]
         + ["lat", "lon", "locations", "animation_group"]
@@ -989,14 +1040,34 @@ def infer_config(args, constructor, trace_patch):
                     and args["data_frame"][args["color"]].dtype.kind in "bifc"
                 ):
                     attrs.append("color")
+                    args["color_is_continuous"] = True
+                elif constructor in [go.Sunburst, go.Treemap]:
+                    attrs.append("color")
+                    args["color_is_continuous"] = False
                 else:
                     grouped_attrs.append("marker.color")
         elif "line_group" in args or constructor == go.Histogram2dContour:
             grouped_attrs.append("line.color")
+        elif constructor in [go.Pie, go.Funnelarea]:
+            attrs.append("color")
+            if args["color"]:
+                if args["hover_data"] is None:
+                    args["hover_data"] = []
+                args["hover_data"].append(args["color"])
         else:
             grouped_attrs.append("marker.color")
 
-        show_colorbar = bool("color" in attrs and args["color"])
+        show_colorbar = bool(
+            "color" in attrs
+            and args["color"]
+            and constructor not in [go.Pie, go.Funnelarea]
+            and (
+                constructor not in [go.Treemap, go.Sunburst]
+                or args.get("color_is_continuous")
+            )
+        )
+    else:
+        show_colorbar = False
 
     # Compute line_dash grouping attribute
     if "line_dash" in args:
@@ -1148,6 +1219,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
                 go.Parcoords,
                 go.Choropleth,
                 go.Histogram2d,
+                go.Sunburst,
+                go.Treemap,
             ]:
                 trace.update(
                     legendgroup=trace_name,
diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py
index 950a3953be7..8948e6b321d 100644
--- a/packages/python/plotly/plotly/express/_doc.py
+++ b/packages/python/plotly/plotly/express/_doc.py
@@ -66,6 +66,21 @@
         colref_desc,
         "Values from this column or array_like are used to position marks along the angular axis in polar coordinates.",
     ],
+    values=[
+        colref_type,
+        colref_desc,
+        "Values from this column or array_like are used to set values associated to sectors.",
+    ],
+    parents=[
+        colref_type,
+        colref_desc,
+        "Values from this column or array_like are used as parents in sunburst and treemap charts.",
+    ],
+    ids=[
+        colref_type,
+        colref_desc,
+        "Values from this column or array_like are used to set ids of sectors",
+    ],
     lat=[
         colref_type,
         colref_desc,
@@ -168,6 +183,11 @@
         colref_desc,
         "Values from this column or array_like appear in the figure as text labels.",
     ],
+    names=[
+        colref_type,
+        colref_desc,
+        "Values from this column or array_like are used as labels for sectors.",
+    ],
     locationmode=[
         "str",
         "One of 'ISO-3', 'USA-states', or 'country names'",
@@ -442,21 +462,41 @@
     nbins=["int", "Positive integer.", "Sets the number of bins."],
     nbinsx=["int", "Positive integer.", "Sets the number of bins along the x axis."],
     nbinsy=["int", "Positive integer.", "Sets the number of bins along the y axis."],
+    branchvalues=[
+        "str",
+        "'total' or 'remainder'",
+        "Determines how the items in `values` are summed. When"
+        "set to 'total', items in `values` are taken to be value"
+        "of all its descendants. When set to 'remainder', items"
+        "in `values` corresponding to the root and the branches"
+        ":sectors are taken to be the extra part not part of the"
+        "sum of the values at their leaves.",
+    ],
+    maxdepth=[
+        "int",
+        "Positive integer",
+        "Sets the number of rendered sectors from any given `level`. Set `maxdepth` to -1 to render all the"
+        "levels in the hierarchy.",
+    ],
 )
 
 
-def make_docstring(fn):
+def make_docstring(fn, override_dict={}):
     tw = TextWrapper(width=77, initial_indent="    ", subsequent_indent="    ")
     result = (fn.__doc__ or "") + "\nParameters\n----------\n"
     for param in inspect.getargspec(fn)[0]:
-        param_desc_list = docs[param][1:]
+        if override_dict.get(param):
+            param_doc = override_dict[param]
+        else:
+            param_doc = docs[param]
+        param_desc_list = param_doc[1:]
         param_desc = (
             tw.fill(" ".join(param_desc_list or ""))
             if param in docs
             else "(documentation missing from map)"
         )
 
-        param_type = docs[param][0]
+        param_type = param_doc[0]
         result += "%s: %s\n%s\n" % (param, param_type, param_desc)
     result += "\nReturns\n-------\n"
     result += "    A `Figure` object."
diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py
new file mode 100644
index 00000000000..339accf9d57
--- /dev/null
+++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py
@@ -0,0 +1,141 @@
+import plotly.express as px
+import plotly.graph_objects as go
+from numpy.testing import assert_array_equal
+import numpy as np
+
+
+def _compare_figures(go_trace, px_fig):
+    """Compare a figure created with a go trace and a figure created with
+    a px function call. Check that all values inside the go Figure are the
+    same in the px figure (which sets more parameters).
+    """
+    go_fig = go.Figure(go_trace)
+    go_fig = go_fig.to_plotly_json()
+    px_fig = px_fig.to_plotly_json()
+    del go_fig["layout"]["template"]
+    del px_fig["layout"]["template"]
+    for key in go_fig["data"][0]:
+        assert_array_equal(go_fig["data"][0][key], px_fig["data"][0][key])
+    for key in go_fig["layout"]:
+        assert go_fig["layout"][key] == px_fig["layout"][key]
+
+
+def test_pie_like_px():
+    # Pie
+    labels = ["Oxygen", "Hydrogen", "Carbon_Dioxide", "Nitrogen"]
+    values = [4500, 2500, 1053, 500]
+
+    fig = px.pie(names=labels, values=values)
+    trace = go.Pie(labels=labels, values=values)
+    _compare_figures(trace, fig)
+
+    labels = ["Eve", "Cain", "Seth", "Enos", "Noam", "Abel", "Awan", "Enoch", "Azura"]
+    parents = ["", "Eve", "Eve", "Seth", "Seth", "Eve", "Eve", "Awan", "Eve"]
+    values = [10, 14, 12, 10, 2, 6, 6, 4, 4]
+    # Sunburst
+    fig = px.sunburst(names=labels, parents=parents, values=values)
+    trace = go.Sunburst(labels=labels, parents=parents, values=values)
+    _compare_figures(trace, fig)
+    # Treemap
+    fig = px.treemap(names=labels, parents=parents, values=values)
+    trace = go.Treemap(labels=labels, parents=parents, values=values)
+    _compare_figures(trace, fig)
+
+    # Funnel
+    x = ["A", "B", "C"]
+    y = [3, 2, 1]
+    fig = px.funnel(y=y, x=x)
+    trace = go.Funnel(y=y, x=x)
+    _compare_figures(trace, fig)
+    # Funnelarea
+    fig = px.funnel_area(values=y, names=x)
+    trace = go.Funnelarea(values=y, labels=x)
+    _compare_figures(trace, fig)
+
+
+def test_sunburst_treemap_colorscales():
+    labels = ["Eve", "Cain", "Seth", "Enos", "Noam", "Abel", "Awan", "Enoch", "Azura"]
+    parents = ["", "Eve", "Eve", "Seth", "Seth", "Eve", "Eve", "Awan", "Eve"]
+    values = [10, 14, 12, 10, 2, 6, 6, 4, 4]
+    for func, colorway in zip(
+        [px.sunburst, px.treemap], ["sunburstcolorway", "treemapcolorway"]
+    ):
+        # Continuous colorscale
+        fig = func(
+            names=labels,
+            parents=parents,
+            values=values,
+            color=values,
+            color_continuous_scale="Viridis",
+            range_color=(5, 15),
+        )
+        assert fig.layout.coloraxis.cmin, fig.layout.coloraxis.cmax == (5, 15)
+        # Discrete colorscale, color arg passed
+        color_seq = px.colors.sequential.Reds
+        fig = func(
+            names=labels,
+            parents=parents,
+            values=values,
+            color=labels,
+            color_discrete_sequence=color_seq,
+        )
+        assert np.all([col in color_seq for col in fig.data[0].marker.colors])
+        # Numerical color arg passed, fall back to continuous
+        fig = func(names=labels, parents=parents, values=values, color=values,)
+        assert [
+            el[0] == px.colors.sequential.Viridis
+            for i, el in enumerate(fig.layout.coloraxis.colorscale)
+        ]
+        # Numerical color arg passed, continuous colorscale
+        # even if color_discrete_sequence if passed
+        fig = func(
+            names=labels,
+            parents=parents,
+            values=values,
+            color=values,
+            color_discrete_sequence=color_seq,
+        )
+        assert [
+            el[0] == px.colors.sequential.Viridis
+            for i, el in enumerate(fig.layout.coloraxis.colorscale)
+        ]
+
+        # Discrete colorscale, no color arg passed
+        color_seq = px.colors.sequential.Reds
+        fig = func(
+            names=labels,
+            parents=parents,
+            values=values,
+            color_discrete_sequence=color_seq,
+        )
+        assert list(fig.layout[colorway]) == color_seq
+
+
+def test_pie_funnelarea_colorscale():
+    labels = ["A", "B", "C", "D"]
+    values = [3, 2, 1, 4]
+    for func, colorway in zip(
+        [px.sunburst, px.treemap], ["sunburstcolorway", "treemapcolorway"]
+    ):
+        # Discrete colorscale, no color arg passed
+        color_seq = px.colors.sequential.Reds
+        fig = func(names=labels, values=values, color_discrete_sequence=color_seq,)
+        assert list(fig.layout[colorway]) == color_seq
+        # Discrete colorscale, color arg passed
+        color_seq = px.colors.sequential.Reds
+        fig = func(
+            names=labels,
+            values=values,
+            color=labels,
+            color_discrete_sequence=color_seq,
+        )
+        assert np.all([col in color_seq for col in fig.data[0].marker.colors])
+
+
+def test_funnel():
+    fig = px.funnel(
+        x=[5, 4, 3, 3, 2, 1],
+        y=["A", "B", "C", "A", "B", "C"],
+        color=["0", "0", "0", "1", "1", "1"],
+    )
+    assert len(fig.data) == 2