@@ -30,6 +30,7 @@ pub struct Device {
3030 device : ash:: Device ,
3131 memory_properties : vk:: PhysicalDeviceMemoryProperties ,
3232 buffers : Option < DeviceBuffers > ,
33+ direction : CorrelationDirection ,
3334 max_buffer_size : usize ,
3435 descriptor_sets : DescriptorSets ,
3536 pipelines : HashMap < ShaderModuleType , ShaderPipeline > ,
@@ -64,8 +65,7 @@ struct DeviceBuffers {
6465
6566struct DescriptorSets {
6667 descriptor_pool : vk:: DescriptorPool ,
67- regular_layout : vk:: DescriptorSetLayout ,
68- cross_check_layout : vk:: DescriptorSetLayout ,
68+ layout : vk:: DescriptorSetLayout ,
6969 pipeline_layout : vk:: PipelineLayout ,
7070 descriptor_sets : Vec < vk:: DescriptorSet > ,
7171}
@@ -229,6 +229,7 @@ impl Device {
229229 instance. destroy_instance ( None ) ;
230230 err
231231 } ;
232+ let direction = CorrelationDirection :: Forward ;
232233 // Init control struct - queues, fences, command buffer.
233234 let control =
234235 unsafe { Device :: create_control ( & device, compute_queue_index) . map_err ( cleanup_err) ? } ;
@@ -238,6 +239,7 @@ impl Device {
238239 device,
239240 memory_properties,
240241 buffers : None ,
242+ direction,
241243 max_buffer_size,
242244 descriptor_sets,
243245 pipelines,
@@ -709,27 +711,20 @@ impl Device {
709711 . ty ( vk:: DescriptorType :: STORAGE_BUFFER )
710712 . descriptor_count ( 6 ) ] ;
711713 let descriptor_pool_info = vk:: DescriptorPoolCreateInfo :: default ( )
712- . max_sets ( 2 )
714+ . max_sets ( 1 )
713715 . pool_sizes ( & descriptor_pool_size) ;
714716 let descriptor_pool = device. create_descriptor_pool ( & descriptor_pool_info, None ) ?;
715717 let cleanup_err = |err| {
716718 device. destroy_descriptor_pool ( descriptor_pool, None ) ;
717719 err
718720 } ;
719- let regular_layout = create_layout_bindings ( 6 ) . map_err ( cleanup_err) ?;
721+ let layout = create_layout_bindings ( 6 ) . map_err ( cleanup_err) ?;
720722 let cleanup_err = |err| {
721- device. destroy_descriptor_set_layout ( regular_layout , None ) ;
723+ device. destroy_descriptor_set_layout ( layout , None ) ;
722724 device. destroy_descriptor_pool ( descriptor_pool, None ) ;
723725 err
724726 } ;
725- let cross_check_layout = create_layout_bindings ( 2 ) . map_err ( cleanup_err) ?;
726- let cleanup_err = |err| {
727- device. destroy_descriptor_set_layout ( cross_check_layout, None ) ;
728- device. destroy_descriptor_set_layout ( regular_layout, None ) ;
729- device. destroy_descriptor_pool ( descriptor_pool, None ) ;
730- err
731- } ;
732- let layouts = [ regular_layout, cross_check_layout] ;
727+ let layouts = [ layout] ;
733728 let push_constant_ranges = vk:: PushConstantRange :: default ( )
734729 . offset ( 0 )
735730 . size ( std:: mem:: size_of :: < ShaderParams > ( ) as u32 )
@@ -744,8 +739,7 @@ impl Device {
744739 . map_err ( cleanup_err) ?;
745740 let cleanup_err = |err| {
746741 device. destroy_pipeline_layout ( pipeline_layout, None ) ;
747- device. destroy_descriptor_set_layout ( cross_check_layout, None ) ;
748- device. destroy_descriptor_set_layout ( regular_layout, None ) ;
742+ device. destroy_descriptor_set_layout ( layout, None ) ;
749743 device. destroy_descriptor_pool ( descriptor_pool, None ) ;
750744 err
751745 } ;
@@ -758,8 +752,7 @@ impl Device {
758752
759753 Ok ( DescriptorSets {
760754 descriptor_pool,
761- regular_layout,
762- cross_check_layout,
755+ layout,
763756 pipeline_layout,
764757 descriptor_sets,
765758 } )
@@ -836,6 +829,57 @@ impl Device {
836829 Ok ( result)
837830 }
838831
832+ fn set_buffer_layout ( & mut self , shader : & ShaderModuleType ) -> Result < ( ) , GpuError > {
833+ let direction = self . direction ;
834+ let descriptor_sets = & self . descriptor_sets ;
835+ let buffers = & self . buffers ( ) ?;
836+ let ( buffer_internal_img1, buffer_internal_img2, buffer_out, buffer_out_reverse) =
837+ match direction {
838+ CorrelationDirection :: Forward => (
839+ buffers. buffer_internal_img1 ,
840+ buffers. buffer_internal_img2 ,
841+ buffers. buffer_out ,
842+ buffers. buffer_out_reverse ,
843+ ) ,
844+ CorrelationDirection :: Reverse => (
845+ buffers. buffer_internal_img2 ,
846+ buffers. buffer_internal_img1 ,
847+ buffers. buffer_out_reverse ,
848+ buffers. buffer_out ,
849+ ) ,
850+ } ;
851+ let buffer_list = if matches ! ( shader, ShaderModuleType :: CrossCheckFilter ) {
852+ vec ! [ buffer_out, buffer_out_reverse]
853+ } else {
854+ vec ! [
855+ buffers. buffer_img,
856+ buffer_internal_img1,
857+ buffer_internal_img2,
858+ buffers. buffer_internal_int,
859+ buffer_out,
860+ buffers. buffer_out_corr,
861+ ]
862+ } ;
863+ let buffer_infos = buffer_list
864+ . iter ( )
865+ . map ( |buf| {
866+ vk:: DescriptorBufferInfo :: default ( )
867+ . buffer ( buf. buffer )
868+ . offset ( 0 )
869+ . range ( vk:: WHOLE_SIZE )
870+ } )
871+ . collect :: < Vec < _ > > ( ) ;
872+ let write_descriptor = vk:: WriteDescriptorSet :: default ( )
873+ . dst_set ( descriptor_sets. descriptor_sets [ 0 ] )
874+ . dst_binding ( 0 )
875+ . descriptor_type ( vk:: DescriptorType :: STORAGE_BUFFER )
876+ . buffer_info ( buffer_infos. as_slice ( ) ) ;
877+ unsafe {
878+ self . device . update_descriptor_sets ( & [ write_descriptor] , & [ ] ) ;
879+ }
880+ Ok ( ( ) )
881+ }
882+
839883 unsafe fn create_control (
840884 device : & ash:: Device ,
841885 queue_family_index : u32 ,
@@ -876,57 +920,7 @@ impl Device {
876920
877921impl super :: Device for Device {
878922 fn set_buffer_direction ( & mut self , direction : & CorrelationDirection ) -> Result < ( ) , GpuError > {
879- let descriptor_sets = & self . descriptor_sets ;
880- let buffers = & self . buffers ( ) ?;
881- let create_buffer_infos = |buffers : & [ Buffer ] | {
882- buffers
883- . iter ( )
884- . map ( |buf| {
885- vk:: DescriptorBufferInfo :: default ( )
886- . buffer ( buf. buffer )
887- . offset ( 0 )
888- . range ( vk:: WHOLE_SIZE )
889- } )
890- . collect :: < Vec < _ > > ( )
891- } ;
892- let ( buffer_internal_img1, buffer_internal_img2, buffer_out, buffer_out_reverse) =
893- match direction {
894- CorrelationDirection :: Forward => (
895- buffers. buffer_internal_img1 ,
896- buffers. buffer_internal_img2 ,
897- buffers. buffer_out ,
898- buffers. buffer_out_reverse ,
899- ) ,
900- CorrelationDirection :: Reverse => (
901- buffers. buffer_internal_img2 ,
902- buffers. buffer_internal_img1 ,
903- buffers. buffer_out_reverse ,
904- buffers. buffer_out ,
905- ) ,
906- } ;
907- let regular_buffer_infos = create_buffer_infos ( & [
908- buffers. buffer_img ,
909- buffer_internal_img1,
910- buffer_internal_img2,
911- buffers. buffer_internal_int ,
912- buffer_out,
913- buffers. buffer_out_corr ,
914- ] ) ;
915- let regular_write_descriptor = vk:: WriteDescriptorSet :: default ( )
916- . dst_set ( descriptor_sets. descriptor_sets [ 0 ] )
917- . dst_binding ( 0 )
918- . descriptor_type ( vk:: DescriptorType :: STORAGE_BUFFER )
919- . buffer_info ( regular_buffer_infos. as_slice ( ) ) ;
920- let cross_check_buffer_infos = create_buffer_infos ( & [ buffer_out, buffer_out_reverse] ) ;
921- let cross_check_write_descriptor = vk:: WriteDescriptorSet :: default ( )
922- . dst_set ( descriptor_sets. descriptor_sets [ 1 ] )
923- . dst_binding ( 0 )
924- . descriptor_type ( vk:: DescriptorType :: STORAGE_BUFFER )
925- . buffer_info ( cross_check_buffer_infos. as_slice ( ) ) ;
926- let write_descriptors = [ regular_write_descriptor, cross_check_write_descriptor] ;
927- unsafe {
928- self . device . update_descriptor_sets ( & write_descriptors, & [ ] ) ;
929- }
923+ self . direction = direction. to_owned ( ) ;
930924 Ok ( ( ) )
931925 }
932926
@@ -952,6 +946,7 @@ impl super::Device for Device {
952946 vk:: PipelineBindPoint :: COMPUTE ,
953947 pipeline_config. pipeline ,
954948 ) ;
949+ self . set_buffer_layout ( & shader_type) ?;
955950 // It's way easier to map all descriptor sets identically, instead of ensuring that every
956951 // kernel gets to use set = 0.
957952 // The cross correlation kernel will need to switch to descriptor set = 1.
@@ -1107,8 +1102,7 @@ impl DescriptorSets {
11071102 unsafe fn destroy ( & self , device : & ash:: Device ) {
11081103 let _ = device. free_descriptor_sets ( self . descriptor_pool , self . descriptor_sets . as_slice ( ) ) ;
11091104 device. destroy_pipeline_layout ( self . pipeline_layout , None ) ;
1110- device. destroy_descriptor_set_layout ( self . cross_check_layout , None ) ;
1111- device. destroy_descriptor_set_layout ( self . regular_layout , None ) ;
1105+ device. destroy_descriptor_set_layout ( self . layout , None ) ;
11121106 device. destroy_descriptor_pool ( self . descriptor_pool , None ) ;
11131107 }
11141108}
0 commit comments