@@ -16,12 +16,13 @@ struct LaunchArgs {
1616 int num_threads;
1717 int smem_size;
1818 int cluster_dim;
19+ bool enable_pdl;
1920
20- LaunchArgs (const int & grid_dim_x, const int & num_threads, const int & smem_size = 0 , const int & cluster_dim = 1 ):
21- grid_dim ({grid_dim_x, 1 }), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {}
21+ LaunchArgs (const int & grid_dim_x, const int & num_threads, const int & smem_size = 0 , const int & cluster_dim = 1 , const bool & enable_pdl = true ):
22+ grid_dim ({grid_dim_x, 1 }), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {}
2223
23- LaunchArgs (const std::pair<int , int >& grid_dim, const int & num_threads, const int & smem_size = 0 , const int & cluster_dim = 1 ):
24- grid_dim (grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {}
24+ LaunchArgs (const std::pair<int , int >& grid_dim, const int & num_threads, const int & smem_size = 0 , const int & cluster_dim = 1 , const bool & enable_pdl = true ):
25+ grid_dim (grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {}
2526};
2627
2728class KernelRuntime final {
@@ -127,20 +128,24 @@ class LaunchRuntime {
127128 static void launch (const std::shared_ptr<KernelRuntime>& kernel_runtime, const Args& args) {
128129 const auto kernel = kernel_runtime->kernel ;
129130 const auto stream = at::cuda::getCurrentCUDAStream ();
130- const LaunchArgs launch_args = args.launch_args ;
131+ LaunchArgs launch_args = args.launch_args ;
132+
133+ // Allow runtime override from Python.
134+ // NOTES: the default is enabled.
135+ launch_args.enable_pdl = device_runtime->get_pdl ();
131136
132137 const dim3 grid_dim = {static_cast <unsigned >(launch_args.grid_dim .first ),
133138 static_cast <unsigned >(launch_args.grid_dim .second ),
134139 1 };
135140 const dim3 block_dim = {static_cast <unsigned >(launch_args.num_threads ), 1 , 1 };
136141 auto config = construct_launch_config (kernel, stream, launch_args.smem_size ,
137- grid_dim, block_dim, launch_args.cluster_dim );
142+ grid_dim, block_dim, launch_args.cluster_dim , launch_args. enable_pdl );
138143
139144 // Launch in the derived class
140145 if (get_env<int >(" DG_JIT_DEBUG" )) {
141- printf (" Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n " ,
146+ printf (" Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, enable_pdl: %d, stream: %ld\n " ,
142147 launch_args.grid_dim .first , launch_args.grid_dim .second , launch_args.num_threads ,
143- launch_args.smem_size , launch_args.cluster_dim , stream.id ());
148+ launch_args.smem_size , launch_args.cluster_dim , launch_args. enable_pdl , stream.id ());
144149 }
145150 Derived::launch_impl (kernel, config, args);
146151 }
0 commit comments