Skip to content

Commit 8b66735

Browse files
committed
Move shaders into pipeline description and unify base pipelines
1 parent 3837298 commit 8b66735

File tree

13 files changed

+80
-83
lines changed

13 files changed

+80
-83
lines changed

examples/hal/compute/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ fn main() {
6464
let entry_point = pso::EntryPoint { entry: "main", module: &shader };
6565
let pipeline = gpu.device
6666
.create_compute_pipelines(&[
67-
(entry_point, pso::ComputePipelineDesc::new(&pipeline_layout))
67+
pso::ComputePipelineDesc::new(entry_point, &pipeline_layout)
6868
])
6969
.remove(0)
7070
.expect("Error creating compute pipeline!");

examples/hal/quad/main.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ fn main() {
232232
let subpass = Subpass { index: 0, main_pass: &render_pass };
233233

234234
let mut pipeline_desc = pso::GraphicsPipelineDesc::new(
235+
shader_entries,
235236
Primitive::TriangleList,
236237
pso::Rasterizer::FILL,
237238
&pipeline_layout,
@@ -264,9 +265,7 @@ fn main() {
264265
});
265266

266267

267-
device.create_graphics_pipelines(&[
268-
(shader_entries, pipeline_desc)
269-
])
268+
device.create_graphics_pipelines(&[pipeline_desc])
270269
};
271270

272271
println!("pipelines: {:?}", pipelines);

src/backend/dx12/src/device.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -765,9 +765,9 @@ impl d::Device<B> for Device {
765765

766766
fn create_graphics_pipelines<'a>(
767767
&self,
768-
descs: &[(pso::GraphicsShaderSet<'a, B>, pso::GraphicsPipelineDesc<'a, B>)],
768+
descs: &[pso::GraphicsPipelineDesc<'a, B>],
769769
) -> Vec<Result<n::GraphicsPipeline, pso::CreationError>> {
770-
descs.iter().map(|&(shaders, ref desc)| {
770+
descs.iter().map(|desc| {
771771
let build_shader = |source: Option<pso::EntryPoint<'a, B>>| {
772772
// TODO: better handle case where looking up shader fails
773773
let shader = source.and_then(|src| src.module.shaders.get(src.entry));
@@ -787,11 +787,11 @@ impl d::Device<B> for Device {
787787
}
788788
};
789789

790-
let vs = build_shader(Some(shaders.vertex));
791-
let fs = build_shader(shaders.fragment);
792-
let gs = build_shader(shaders.geometry);
793-
let ds = build_shader(shaders.domain);
794-
let hs = build_shader(shaders.hull);
790+
let vs = build_shader(Some(desc.shaders.vertex));
791+
let fs = build_shader(desc.shaders.fragment);
792+
let gs = build_shader(desc.shaders.geometry);
793+
let ds = build_shader(desc.shaders.domain);
794+
let hs = build_shader(desc.shaders.hull);
795795

796796
// Define input element descriptions
797797
let mut vs_reflect = shade::reflect_shader(&vs);
@@ -936,12 +936,12 @@ impl d::Device<B> for Device {
936936

937937
fn create_compute_pipelines<'a>(
938938
&self,
939-
descs: &[(pso::EntryPoint<'a, B>, pso::ComputePipelineDesc<'a, B>)],
939+
descs: &[pso::ComputePipelineDesc<'a, B>],
940940
) -> Vec<Result<n::ComputePipeline, pso::CreationError>> {
941-
descs.iter().map(|&(shader, ref desc)| {
941+
descs.iter().map(|desc| {
942942
let cs = {
943943
// TODO: better handle case where looking up shader fails
944-
match shader.module.shaders.get(shader.entry) {
944+
match desc.shader.module.shaders.get(desc.shader.entry) {
945945
Some(shader) => {
946946
winapi::D3D12_SHADER_BYTECODE {
947947
pShaderBytecode: unsafe { (**shader).GetBufferPointer() as *const _ },

src/backend/empty/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,14 @@ impl hal::Device<Backend> for Device {
9797

9898
fn create_graphics_pipelines<'a>(
9999
&self,
100-
_: &[(pso::GraphicsShaderSet<'a, Backend>, pso::GraphicsPipelineDesc<'a, Backend>)],
100+
_: &[pso::GraphicsPipelineDesc<'a, Backend>],
101101
) -> Vec<Result<(), pso::CreationError>> {
102102
unimplemented!()
103103
}
104104

105105
fn create_compute_pipelines<'a>(
106106
&self,
107-
_: &[(pso::EntryPoint<'a, Backend>, pso::ComputePipelineDesc<'a, Backend>)],
107+
_: &[pso::ComputePipelineDesc<'a, Backend>],
108108
) -> Vec<Result<(), pso::CreationError>> {
109109
unimplemented!()
110110
}

src/backend/gl/src/device.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,13 @@ impl d::Device<B> for Device {
233233

234234
fn create_graphics_pipelines<'a>(
235235
&self,
236-
descs: &[(pso::GraphicsShaderSet<'a, B>, pso::GraphicsPipelineDesc<'a, B>)],
236+
descs: &[pso::GraphicsPipelineDesc<'a, B>],
237237
) -> Vec<Result<n::GraphicsPipeline, pso::CreationError>> {
238238
let gl = &self.share.context;
239239
let priv_caps = &self.share.private_caps;
240240
let share = &self.share;
241241
descs.iter()
242-
.map(|&(shaders, ref desc)| {
242+
.map(|desc| {
243243
let subpass = {
244244
let subpass = desc.subpass;
245245
match subpass.main_pass.subpasses.get(subpass.index) {
@@ -259,11 +259,11 @@ impl d::Device<B> for Device {
259259
};
260260

261261
// Attach shaders to program
262-
attach_shader(Some(shaders.vertex));
263-
attach_shader(shaders.hull);
264-
attach_shader(shaders.domain);
265-
attach_shader(shaders.geometry);
266-
attach_shader(shaders.fragment);
262+
attach_shader(Some(desc.shaders.vertex));
263+
attach_shader(desc.shaders.hull);
264+
attach_shader(desc.shaders.domain);
265+
attach_shader(desc.shaders.geometry);
266+
attach_shader(desc.shaders.fragment);
267267

268268
if !priv_caps.program_interface && priv_caps.frag_data_location {
269269
for i in 0..subpass.color_attachments.len() {
@@ -302,7 +302,7 @@ impl d::Device<B> for Device {
302302

303303
fn create_compute_pipelines<'a>(
304304
&self,
305-
_descs: &[(pso::EntryPoint<'a, B>, pso::ComputePipelineDesc<'a, B>)],
305+
_descs: &[pso::ComputePipelineDesc<'a, B>],
306306
) -> Vec<Result<n::ComputePipeline, pso::CreationError>> {
307307
unimplemented!()
308308
}

src/backend/metal/src/device.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,7 @@ impl Device {
257257

258258
fn create_graphics_pipeline<'a>(
259259
&self,
260-
&(ref shader_set, ref pipeline_desc):
261-
&(pso::GraphicsShaderSet<'a, Backend>, pso::GraphicsPipelineDesc<'a, Backend>),
260+
pipeline_desc: &pso::GraphicsPipelineDesc<'a, Backend>,
262261
) -> Result<n::GraphicsPipeline, pso::CreationError> {
263262
let pipeline = metal::RenderPipelineDescriptor::new();
264263
let pipeline_layout = &pipeline_desc.layout;
@@ -277,21 +276,21 @@ impl Device {
277276
pipeline.set_input_primitive_topology(primitive_class);
278277

279278
// Shaders
280-
let vs_lib = match shader_set.vertex.module {
279+
let vs_lib = match pipeline_desc.shaders.vertex.module {
281280
&n::ShaderModule::Compiled(ref lib) => lib.to_owned(),
282281
&n::ShaderModule::Raw(ref data) => {
283282
//TODO: cache them all somewhere!
284283
self.compile_shader_library(data, &pipeline_layout.res_overrides).unwrap()
285284
},
286285
};
287286
let mtl_vertex_function = vs_lib
288-
.get_function(shader_set.vertex.entry)
287+
.get_function(pipeline_desc.shaders.vertex.entry)
289288
.ok_or_else(|| {
290289
error!("invalid vertex shader entry point");
291290
pso::CreationError::Other
292291
})?;
293292
pipeline.set_vertex_function(Some(&mtl_vertex_function));
294-
let fs_lib = if let Some(fragment_entry) = shader_set.fragment {
293+
let fs_lib = if let Some(fragment_entry) = pipeline_desc.shaders.fragment {
295294
let fs_lib = match fragment_entry.module {
296295
&n::ShaderModule::Compiled(ref lib) => lib.to_owned(),
297296
&n::ShaderModule::Raw(ref data) => {
@@ -309,15 +308,15 @@ impl Device {
309308
} else {
310309
None
311310
};
312-
if shader_set.hull.is_some() {
311+
if pipeline_desc.shaders.hull.is_some() {
313312
error!("Metal tesselation shaders are not supported");
314313
return Err(pso::CreationError::Other);
315314
}
316-
if shader_set.domain.is_some() {
315+
if pipeline_desc.shaders.domain.is_some() {
317316
error!("Metal tesselation shaders are not supported");
318317
return Err(pso::CreationError::Other);
319318
}
320-
if shader_set.geometry.is_some() {
319+
if pipeline_desc.shaders.geometry.is_some() {
321320
error!("Metal geometry shaders are not supported");
322321
return Err(pso::CreationError::Other);
323322
}
@@ -544,7 +543,7 @@ impl hal::Device<Backend> for Device {
544543

545544
fn create_graphics_pipelines<'a>(
546545
&self,
547-
params: &[(pso::GraphicsShaderSet<'a, Backend>, pso::GraphicsPipelineDesc<'a, Backend>)],
546+
params: &[pso::GraphicsPipelineDesc<'a, Backend>],
548547
) -> Vec<Result<n::GraphicsPipeline, pso::CreationError>> {
549548
let mut output = Vec::with_capacity(params.len());
550549
for param in params {
@@ -555,7 +554,7 @@ impl hal::Device<Backend> for Device {
555554

556555
fn create_compute_pipelines<'a>(
557556
&self,
558-
_pipelines: &[(pso::EntryPoint<'a, Backend>, pso::ComputePipelineDesc<'a, Backend>)],
557+
_pipelines: &[pso::ComputePipelineDesc<'a, Backend>],
559558
) -> Vec<Result<n::ComputePipeline, pso::CreationError>> {
560559
unimplemented!()
561560
}

src/backend/vulkan/src/device.rs

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ impl d::Device<B> for Device {
249249

250250
fn create_graphics_pipelines<'a>(
251251
&self,
252-
descs: &[(pso::GraphicsShaderSet<'a, B>, pso::GraphicsPipelineDesc<'a, B>)],
252+
descs: &[pso::GraphicsPipelineDesc<'a, B>],
253253
) -> Vec<Result<n::GraphicsPipeline, pso::CreationError>> {
254254
debug!("create_graphics_pipelines {:?}", descs);
255255
// Store pipeline parameters to avoid stack usage
@@ -283,26 +283,26 @@ impl d::Device<B> for Device {
283283
}
284284
};
285285

286-
let infos = descs.iter().map(|&(shaders, ref desc)| {
286+
let infos = descs.iter().map(|desc| {
287287
let mut stages = Vec::new();
288288
// Vertex stage
289289
if true { //vertex shader is required
290-
stages.push(make_stage(vk::SHADER_STAGE_VERTEX_BIT, shaders.vertex));
290+
stages.push(make_stage(vk::SHADER_STAGE_VERTEX_BIT, desc.shaders.vertex));
291291
}
292292
// Pixel stage
293-
if let Some(entry) = shaders.fragment {
293+
if let Some(entry) = desc.shaders.fragment {
294294
stages.push(make_stage(vk::SHADER_STAGE_FRAGMENT_BIT, entry));
295295
}
296296
// Geometry stage
297-
if let Some(entry) = shaders.geometry {
297+
if let Some(entry) = desc.shaders.geometry {
298298
stages.push(make_stage(vk::SHADER_STAGE_GEOMETRY_BIT, entry));
299299
}
300300
// Domain stage
301-
if let Some(entry) = shaders.domain {
301+
if let Some(entry) = desc.shaders.domain {
302302
stages.push(make_stage(vk::SHADER_STAGE_TESSELLATION_EVALUATION_BIT, entry));
303303
}
304304
// Hull stage
305-
if let Some(entry) = shaders.hull {
305+
if let Some(entry) = desc.shaders.hull {
306306
stages.push(make_stage(vk::SHADER_STAGE_TESSELLATION_CONTROL_BIT, entry));
307307
}
308308

@@ -363,7 +363,7 @@ impl d::Device<B> for Device {
363363
p_next: ptr::null(),
364364
flags: vk::PipelineRasterizationStateCreateFlags::empty(),
365365
depth_clamp_enable: if desc.rasterizer.depth_clamping { vk::VK_TRUE } else { vk::VK_FALSE },
366-
rasterizer_discard_enable: if shaders.fragment.is_none() { vk::VK_TRUE } else { vk::VK_FALSE },
366+
rasterizer_discard_enable: if desc.shaders.fragment.is_none() { vk::VK_TRUE } else { vk::VK_FALSE },
367367
polygon_mode: polygon_mode,
368368
cull_mode: desc.rasterizer.cull_face.map(conv::map_cull_face).unwrap_or(vk::CULL_MODE_NONE),
369369
front_face: conv::map_front_face(desc.rasterizer.front_face),
@@ -374,7 +374,7 @@ impl d::Device<B> for Device {
374374
line_width: line_width,
375375
});
376376

377-
let is_tessellated = shaders.hull.is_some() && shaders.domain.is_some();
377+
let is_tessellated = desc.shaders.hull.is_some() && desc.shaders.domain.is_some();
378378
if is_tessellated {
379379
info_tessellation_states.push(vk::PipelineTessellationStateCreateInfo {
380380
s_type: vk::StructureType::PipelineTessellationStateCreateInfo,
@@ -484,13 +484,13 @@ impl d::Device<B> for Device {
484484
});
485485

486486
let (base_handle, base_index) = match desc.parent {
487-
pso::BaseGraphics::Pipeline(pipeline) => (pipeline.0, -1),
488-
pso::BaseGraphics::Index(index) => (vk::Pipeline::null(), index as _),
489-
pso::BaseGraphics::None => (vk::Pipeline::null(), -1),
487+
pso::BasePipeline::Pipeline(pipeline) => (pipeline.0, -1),
488+
pso::BasePipeline::Index(index) => (vk::Pipeline::null(), index as _),
489+
pso::BasePipeline::None => (vk::Pipeline::null(), -1),
490490
};
491491

492492
let mut flags = vk::PipelineCreateFlags::empty();
493-
if let pso::BaseGraphics::None = desc.parent {
493+
if let pso::BasePipeline::None = desc.parent {
494494
flags |= vk::PIPELINE_CREATE_DERIVATIVE_BIT;
495495
}
496496
if desc.flags.contains(pso::PipelineCreationFlags::DISABLE_OPTIMIZATION) {
@@ -557,11 +557,11 @@ impl d::Device<B> for Device {
557557

558558
fn create_compute_pipelines<'a>(
559559
&self,
560-
descs: &[(pso::EntryPoint<'a, B>, pso::ComputePipelineDesc<'a, B>)],
560+
descs: &[pso::ComputePipelineDesc<'a, B>],
561561
) -> Vec<Result<n::ComputePipeline, pso::CreationError>> {
562562
let mut c_strings = Vec::new(); // hold the C strings temporarily
563-
let infos = descs.iter().map(|&(entry_point, ref desc)| {
564-
let string = CString::new(entry_point.entry).unwrap();
563+
let infos = descs.iter().map(|desc| {
564+
let string = CString::new(desc.shader.entry).unwrap();
565565
let p_name = string.as_ptr();
566566
c_strings.push(string);
567567

@@ -570,19 +570,19 @@ impl d::Device<B> for Device {
570570
p_next: ptr::null(),
571571
flags: vk::PipelineShaderStageCreateFlags::empty(),
572572
stage: vk::SHADER_STAGE_COMPUTE_BIT,
573-
module: entry_point.module.raw,
573+
module: desc.shader.module.raw,
574574
p_name,
575575
p_specialization_info: ptr::null(),
576576
};
577577

578578
let (base_handle, base_index) = match desc.parent {
579-
pso::BaseCompute::Pipeline(pipeline) => (pipeline.0, -1),
580-
pso::BaseCompute::Index(index) => (vk::Pipeline::null(), index as _),
581-
pso::BaseCompute::None => (vk::Pipeline::null(), -1),
579+
pso::BasePipeline::Pipeline(pipeline) => (pipeline.0, -1),
580+
pso::BasePipeline::Index(index) => (vk::Pipeline::null(), index as _),
581+
pso::BasePipeline::None => (vk::Pipeline::null(), -1),
582582
};
583583

584584
let mut flags = vk::PipelineCreateFlags::empty();
585-
if let pso::BaseCompute::None = desc.parent {
585+
if let pso::BasePipeline::None = desc.parent {
586586
flags |= vk::PIPELINE_CREATE_DERIVATIVE_BIT;
587587
}
588588
if desc.flags.contains(pso::PipelineCreationFlags::DISABLE_OPTIMIZATION) {

src/hal/src/device.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ pub trait Device<B: Backend> {
154154
/// Create graphics pipelines.
155155
fn create_graphics_pipelines<'a>(
156156
&self,
157-
&[(pso::GraphicsShaderSet<'a, B>, pso::GraphicsPipelineDesc<'a, B>)],
157+
&[pso::GraphicsPipelineDesc<'a, B>],
158158
) -> Vec<Result<B::GraphicsPipeline, pso::CreationError>>;
159159

160160
/// Destroys a graphics pipeline.
@@ -166,7 +166,7 @@ pub trait Device<B: Backend> {
166166
/// Create compute pipelines.
167167
fn create_compute_pipelines<'a>(
168168
&self,
169-
&[(pso::EntryPoint<'a, B>, pso::ComputePipelineDesc<'a, B>)],
169+
&[pso::ComputePipelineDesc<'a, B>],
170170
) -> Vec<Result<B::ComputePipeline, pso::CreationError>>;
171171

172172
/// Destroys a compute pipeline.

src/hal/src/pso/compute.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
//! Compute pipeline descriptor.
22
33
use Backend;
4-
use super::{BaseCompute, PipelineCreationFlags};
4+
use super::{BaseCompute, BasePipeline, EntryPoint, PipelineCreationFlags};
55

66
///
77
#[derive(Debug)]
88
pub struct ComputePipelineDesc<'a, B: Backend> {
9+
///
10+
pub shader: EntryPoint<'a, B>,
911
/// Pipeline layout.
1012
pub layout: &'a B::PipelineLayout,
1113
///
@@ -17,12 +19,14 @@ pub struct ComputePipelineDesc<'a, B: Backend> {
1719
impl<'a, B: Backend> ComputePipelineDesc<'a, B> {
1820
/// Create a new empty PSO descriptor.
1921
pub fn new(
22+
shader: EntryPoint<'a, B>,
2023
layout: &'a B::PipelineLayout,
2124
) -> Self {
2225
ComputePipelineDesc {
26+
shader,
2327
layout,
2428
flags: PipelineCreationFlags::empty(),
25-
parent: BaseCompute::None,
29+
parent: BasePipeline::None,
2630
}
2731
}
2832
}

0 commit comments

Comments
 (0)