#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/macros/Macros.h>
#include <ATen/native/cuda/LaunchUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/max_pool2d_with_indices_native.h>
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
#endif

namespace at::native {
namespace {

// Device-side minimum of two ints; when the values are equal, `a` is returned.
__device__ inline int min(int a, int b) {
  if (b < a) {
    return b;
  }
  return a;
}

// Launch-tuning constants for the channels-last (NHWC) kernels below.
// CUDA_MAX_THREADS is the launch bound of the NHWC kernels and caps the block
// size chosen by the host launchers. BLOCK_STRIDE_FWD / BLOCK_STRIDE_BWD divide
// the y/z grid extents, so each block covers several output (forward) or input
// (backward) rows/columns per thread.
#ifdef USE_ROCM
#define CUDA_MAX_THREADS 256
#define BLOCK_STRIDE_FWD 2 // increasing block_stride to lower # of blocks launched
#define BLOCK_STRIDE_BWD 4 // increasing block_stride to lower # of blocks launched
#else
#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit
#define BLOCK_STRIDE_FWD 2 // increasing block_stride to lower # of blocks launched
#define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched
#endif

// Smallest pooled index p whose dilated window, starting at p * stride - pad
// and spanning (kernel - 1) * dilation + 1 inputs, extends past input position
// `size`. Returns 0 when even the first window covers it.
template <typename index_t>
static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) {
  const index_t window_span = static_cast<index_t>((kernel - 1) * dilation + 1);
  const index_t padded_pos = size + pad;
  if (padded_pos < window_span) {
    return index_t(0);
  }
  return (padded_pos - window_span) / stride + 1;
}

// One past the largest pooled index p whose window start p * stride - pad is
// at or before input position `size`, clamped to pooled_size.
template <typename index_t>
static __device__ inline index_t p_end(index_t size, int pad, index_t pooled_size, int stride) {
  const index_t last_plus_one = (size + pad) / stride + 1;
  return pooled_size < last_plus_one ? pooled_size : last_plus_one;
}

// Returns true when every offset the pooling kernels form fits in a 32-bit
// signed int, so the int32_t index specialisations may be used.
//
// Three conditions are required:
//  * the largest input element offset (last batch, channel, row and column)
//    under the given input strides is <= INT_MAX;
//  * the TOTAL number of output elements, nbatch * channels * pooled_height *
//    pooled_width, is <= INT_MAX. The kernels build flat output offsets up to
//    this count minus one, and the contiguous forward launch narrows the
//    output element count itself to int32_t. Checking only the start of the
//    last batch is not enough: e.g. 2 x 1 x 32768 x 32768 with a 1x1 window
//    has a last-batch start of 2^30 but 2^31 output elements;
//  * height * width <= INT_MAX, the range of the flat spatial argmax
//    (h * width + w) stored in the indices tensor.
static inline bool can_use_int32_nhwc(
    int64_t nbatch, int64_t channels,
    int64_t height, int64_t width,
    int64_t pooled_height, int64_t pooled_width,
    int64_t in_stride_n, int64_t in_stride_c,
    int64_t in_stride_h, int64_t in_stride_w)
{
  constexpr int64_t int_max = std::numeric_limits<int>::max();

  // Offset of the last element within one batch; zero-sized dims contribute 0.
  int64_t max_intra_batch =
      (height ? (height - 1) * in_stride_h : 0) +
      (width ? (width - 1) * in_stride_w : 0) +
      (channels? (channels - 1) * in_stride_c : 0);

  int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch;

  if (max_input_offset > int_max) return false;

  // Every flat output offset (and the contiguous launch's element count) must fit.
  const int64_t out_batch_stride = pooled_height * pooled_width * channels;
  if (nbatch * out_batch_stride > int_max) return false;

  if (height * width > int_max) return false;

  return true;
}

// Convenience wrapper around can_use_int32_nhwc for densely packed NCHW
// tensors: derives the contiguous NCHW strides from the sizes and forwards them.
static inline bool can_use_int32_nchw(
    int64_t nbatch, int64_t channels,
    int64_t height, int64_t width,
    int64_t pooled_height, int64_t pooled_width) {
  const int64_t stride_w = 1;
  const int64_t stride_h = width;
  const int64_t stride_c = height * width;
  const int64_t stride_n = channels * stride_c;
  return can_use_int32_nhwc(nbatch, channels, height, width,
                            pooled_height, pooled_width,
                            stride_n, stride_c, stride_h, stride_w);
}

// kernels borrowed from Caffe
// Forward max pooling for contiguous NCHW inputs.
//
// Each logical thread handles one output element. The flat output index is
// decomposed into (n, c, ph, pw). The dilated window is clipped to the input,
// and the max value and its flat spatial position h * width + w are written to
// top_data[index] / top_mask[index]. NaN inputs win the comparison, so NaNs
// propagate.
//
// All index math is done in index_t, including the size arguments, window
// bounds and loop counters:
//  * the int32_t specialisation (chosen by the host only when
//    can_use_int32_nchw allows it) really performs 32-bit arithmetic;
//  * the int64_t specialisation never truncates through a 32-bit int.
template <typename scalar_t, typename index_t>
__global__ void max_pool_forward_nchw(
    const index_t nthreads,
    const scalar_t* bottom_data,
    const index_t channels,
    const index_t height,
    const index_t width,
    const int pooled_height,
    const int pooled_width,
    const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w,
    const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w,
    scalar_t* top_data,
    int64_t* top_mask) {
  // Extent of the dilated pooling window; the same for every output element.
  const index_t kernel_extent_h = static_cast<index_t>((kernel_h - 1) * dilation_h + 1);
  const index_t kernel_extent_w = static_cast<index_t>((kernel_w - 1) * dilation_w + 1);
  CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
    // Decompose the flat NCHW output index.
    index_t pw = index % pooled_width;
    index_t ph = (index / pooled_width) % pooled_height;
    index_t c = (index / pooled_width / pooled_height) % channels;
    index_t n = index / pooled_width / pooled_height / channels;
    // Window in input coordinates, clipped on the bottom/right to the input.
    index_t hstart = ph * stride_h - pad_h;
    index_t wstart = pw * stride_w - pad_w;
    index_t hend = hstart + kernel_extent_h;
    index_t wend = wstart + kernel_extent_w;
    if (hend > height) {
      hend = height;
    }
    if (wend > width) {
      wend = width;
    }
    // Skip taps in the top/left padding while staying on the dilation grid.
    while (hstart < 0)
      hstart += dilation_h;
    while (wstart < 0)
      wstart += dilation_w;
    scalar_t maxval = at::numeric_limits<scalar_t>::lower_bound(); // -Infinity
    index_t maxidx = hstart * width + wstart;
    const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width;
    for (index_t h = hstart; h < hend; h += dilation_h) {
      for (index_t w = wstart; w < wend; w += dilation_w) {
        scalar_t val = btm_data[h * width + w];
        // NaN compares false against everything; test it explicitly so it propagates.
        if ((val > maxval) || at::_isnan(val)) {
          maxidx = h * width + w;
          maxval = val;
        }
      }
    }
    top_data[index] = maxval;
    top_mask[index] = maxidx;
  }
}

// Forward max pooling for channels-last (NHWC) inputs.
//
// Launch layout (as set up by the ChannelsLast branch of the forward launcher):
//   blockIdx.x  -> (batch, channel chunk): batch_id = blockIdx.x % nbatch,
//                  channel_id = blockIdx.x / nbatch.
//   threadIdx.x -> channel lane. A lane visits channels channel_offset,
//                  channel_offset + blockDim.x * kernel_stride_C, ...
//   (blockIdx.y, threadIdx.y) -> output columns; (blockIdx.z, threadIdx.z) -> output rows.
//                  Each block covers a ceil(pooled_width / gridDim.y) x
//                  ceil(pooled_height / gridDim.z) tile of the output.
// Dynamic shared memory: kernel_size_C * blockDim.x * blockDim.y * blockDim.z
// index_t argmax slots, followed by the same number of scalar_t value slots.
// Writes the max to top_data and the flat spatial argmax (ih * width + iw) to
// top_mask. NaN inputs win the comparison, so NaNs propagate.
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void max_pool_forward_nhwc(
    const scalar_t* bottom_data,
    const int nbatch,
    const index_t channels, const index_t height, const index_t width,
    const index_t pooled_height, const index_t pooled_width,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w,
    const index_t in_stride_n, const index_t in_stride_c,
    const index_t in_stride_h, const index_t in_stride_w,
    const int kernel_stride_C, const int kernel_size_C,
    scalar_t* top_data, int64_t* top_mask) {

  // Layout: [index_t mask slots][scalar_t value slots].
  // NOTE(review): the value region starts right after an index_t array, so
  // with index_t = int32_t, 8-byte scalar_t and an odd slot count it may be
  // misaligned. Confirm whether that configuration is reachable.
  extern __shared__ unsigned char smem_raw[];
  index_t *out_mask_cached = reinterpret_cast<index_t*>(smem_raw);
  scalar_t *out_cached = reinterpret_cast<scalar_t*>(
      out_mask_cached + kernel_size_C*blockDim.x*blockDim.y*blockDim.z);

  // flattening cta for pre-computation & smem initialization;
  int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  int block_size = blockDim.x * blockDim.y * blockDim.z;

  // use shared memory to store temporary output value. This is simply to
  // reduce register usage.
  for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) {
    out_cached[i] = at::numeric_limits<scalar_t>::lower_bound();
    out_mask_cached[i] = 0;
  }

  __syncthreads();

  // Split blockIdx.x into the batch index and the channel chunk it owns.
  int batch_id = blockIdx.x % nbatch;
  int channel_id = blockIdx.x / nbatch;
  int channel_offset = threadIdx.x + channel_id * blockDim.x;

  top_data = top_data + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
  top_mask = top_mask + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
  bottom_data = bottom_data + static_cast<index_t>(batch_id) * in_stride_n;

  // Each (threadIdx.y, threadIdx.z) pair gets its own kernel_size_C * blockDim.x slice.
  out_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
  out_mask_cached  += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;

  // Output tile owned by this block; threads stride through it by blockDim.z / blockDim.y.
  int oH = (static_cast<int>(pooled_height) + gridDim.z - 1) / gridDim.z;
  int oW = (static_cast<int>(pooled_width)  + gridDim.y - 1) / gridDim.y;
  int ostartH = threadIdx.z + blockIdx.z*oH;
  int oendH = ::min(ostartH+oH, static_cast<int>(pooled_height));
  int ostartW = threadIdx.y + blockIdx.y*oW;
  int oendW = ::min(ostartW+oW, static_cast<int>(pooled_width));

  for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
    index_t hstart = static_cast<index_t>(oh) * stride_h - pad_h;
    index_t hend = std::min(hstart + static_cast<index_t>((kernel_h - 1) * dilation_h + 1), height);
    for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
      index_t wstart = static_cast<index_t>(ow) * stride_w - pad_w;
      index_t wend = std::min(wstart + static_cast<index_t>((kernel_w - 1) * dilation_w + 1), width);
      // Skip taps in the top/left padding while staying on the dilation grid.
      // hstart only depends on oh, so this adjustment is a no-op after the first ow.
      while(hstart < 0)
        hstart += dilation_h;
      while(wstart < 0)
        wstart += dilation_w;

#if defined (USE_ROCM)
// Max h,w and c for using prefetch path
#define MAXh 3
#define MAXw 3
#define MAXc 1
      // Prefetch if conditions met...
      if (kernel_h<=MAXh &&
          kernel_w<=MAXw &&
          channels<=MAXc*(blockDim.x*kernel_stride_C)) {
        scalar_t val [MAXh][MAXw][MAXc] = {};
        for (int ih = 0; ih < MAXh; ih++) {
          int ih_ = ih*dilation_h+hstart;
          for (int iw = 0; iw < MAXw; iw++) {
            int iw_ = iw*dilation_w+wstart;
            const scalar_t *ptr_input = bottom_data + ih_ * in_stride_h + iw_ * in_stride_w;
            for(int c = 0; c < MAXc; c++) {
              int c_ = c*blockDim.x*kernel_stride_C+channel_offset;
              if (ih_>=hend || iw_>=wend || c_>=channels) continue;
              val[ih][iw][c] = ptr_input[c_*in_stride_c];
            }
          }
        }
        for (int ih = 0; ih < MAXh; ih++) {
          int ih_ = ih*dilation_h+hstart;
          for (int iw = 0; iw < MAXw; iw++) {
            int iw_ = iw*dilation_w+wstart;
            int cached_index = threadIdx.x;
            for(int c = 0; c < MAXc; c++) {
              int c_ = c*blockDim.x*kernel_stride_C+channel_offset;
              if (ih_>=hend || iw_>=wend || c_>=channels) continue;
              if ((val[ih][iw][c] > out_cached[cached_index]) || at::_isnan(val[ih][iw][c])) {
                out_cached[cached_index] = val[ih][iw][c];
                out_mask_cached[cached_index] = ih_ * width + iw_;
              }
              cached_index += blockDim.x;
            }
          }
        }
      }
      // Else do it Non-Prefetch...
      else
#endif
      // Generic path: running max per channel slot over the clipped window.
      for (index_t ih = hstart; ih < hend; ih += dilation_h) {
        for (index_t iw = wstart; iw < wend; iw += dilation_w) {
          int cached_index = threadIdx.x;
          const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
          for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
            scalar_t val = ptr_input[c * in_stride_c];
            if ((val > out_cached[cached_index]) || at::_isnan(val)) {
              out_cached[cached_index] = val;
              out_mask_cached[cached_index] = ih * width + iw;
            }
            cached_index += blockDim.x;
          }
        }
      }

      // Output is packed NHWC within the batch: offset (oh * pooled_width + ow) * channels + c.
      scalar_t *ptr_output_data = top_data + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
      int64_t *ptr_output_mask = top_mask + (static_cast<index_t>(oh) * pooled_width + ow) * channels;

      // Flush the per-channel results and reset the slots for the next output position.
      int cached_index = threadIdx.x;
      for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
        ptr_output_data[c] = out_cached[cached_index];
        ptr_output_mask[c] = static_cast<int64_t>(out_mask_cached[cached_index]);
        out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound();
        out_mask_cached[cached_index] = index_t(0);
        cached_index += blockDim.x;
      }
    }
  }
}


// Block size for the contiguous (NCHW) kernels. It is used as their launch
// bound and is capped by the device's maxThreadsPerBlock at launch time.
static constexpr int BLOCK_THREADS = 256;

// Backward of max pooling for contiguous NCHW tensors, written as a gather.
//
// Grid layout:
//  * grid x strides over the height * width input pixels of a plane;
//  * blockIdx.y / gridDim.y stride over the batch;
//  * blockIdx.z / gridDim.z stride over channels.
// For each input pixel, the thread visits every pooled window in
// [ph_begin, ph_end) x [pw_begin, pw_end) that can contain it. It sums top_diff
// wherever top_mask recorded this pixel (row * width + col) as the argmax.
// Every bottom_diff element is written exactly once, so no atomics are needed.
template <typename scalar_t, typename accscalar_t, typename index_t>
#if defined (USE_ROCM)
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4)
#else
C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8)
#endif
__global__ void max_pool_backward_nchw(
    const scalar_t* top_diff,
    const int64_t* top_mask,
    const index_t num,
    const index_t channels,
    const index_t height,
    const index_t width,
    const index_t pooled_height,
    const index_t pooled_width,
    const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w,
    const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w,
    scalar_t* bottom_diff) {
  const index_t plane_size = height * width;
  const index_t pooled_plane_size = pooled_height * pooled_width;
  CUDA_KERNEL_LOOP_TYPE(index, plane_size, index_t) {
    const index_t row = index / width;
    const index_t col = index - row * width;
    const index_t ph_begin = p_start(row, pad_h, kernel_h, dilation_h, stride_h);
    const index_t ph_end = p_end(row, pad_h, pooled_height, stride_h);
    const index_t pw_begin = p_start(col, pad_w, kernel_w, dilation_w, stride_w);
    const index_t pw_end = p_end(col, pad_w, pooled_width, stride_w);
    for (index_t n = blockIdx.y; n < num; n += gridDim.y) {
      for (index_t c = blockIdx.z; c < channels; c += gridDim.z) {
        const index_t plane = n * channels + c;
        const int64_t* mask_plane = top_mask + plane * pooled_plane_size;
        const scalar_t* grad_plane = top_diff + plane * pooled_plane_size;
        accscalar_t acc = accscalar_t(0);
        for (index_t ph = ph_begin; ph < ph_end; ++ph) {
          const index_t row_offset = ph * pooled_width;
          for (index_t pw = pw_begin; pw < pw_end; ++pw) {
            if (mask_plane[row_offset + pw] == row * width + col) {
              acc += static_cast<accscalar_t>(grad_plane[row_offset + pw]);
            }
          }
        }
        bottom_diff[plane * plane_size + index] = static_cast<scalar_t>(acc);
      }
    }
  }
}

// Backward of max pooling for channels-last (NHWC) tensors, written as a gather.
//
// Launch layout (as set up by the ChannelsLast branch of the backward launcher):
//   blockIdx.x  -> (batch, channel chunk): batch_id = blockIdx.x % nbatch,
//                  channel_id = blockIdx.x / nbatch.
//   threadIdx.x -> channel lane, striding by blockDim.x * kernel_stride_C.
//   (blockIdx.y, threadIdx.y) -> input columns; (blockIdx.z, threadIdx.z) -> input rows.
// Dynamic shared memory holds kernel_size_C * blockDim.x * blockDim.y *
// blockDim.z accscalar_t values. Each (threadIdx.y, threadIdx.z) pair uses its
// own kernel_size_C * blockDim.x slice as per-channel gradient accumulators.
//
// For every input pixel (ih, iw), the pooled windows that can contain it are
// scanned. top_diff is accumulated wherever top_mask (the forward argmax,
// encoded as ih * width + iw) equals the pixel. bottom_diff must be
// zero-initialised by the caller: the single-window path writes only the
// channels whose argmax is this pixel.
//
// Batch and output offsets are formed in int64_t. The operands batch_id,
// pooled_height, pooled_width and the out_stride_* values are all int, so
// their products would otherwise overflow for large outputs.
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void max_pool_backward_nhwc(const scalar_t* top_diff,
                                    const int64_t* top_mask, const int nbatch, const int64_t channels,
                                    const int64_t height, const int64_t width, const int pooled_height,
                                    const int pooled_width, const int kernel_h, const int kernel_w,
                                    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
                                    const int dilation_h, const int dilation_w,
                                    const int out_stride_c, const int out_stride_h, const int out_stride_w,
                                    const int kernel_stride_C, const int kernel_size_C,
                                    scalar_t* bottom_diff) {
  extern __shared__ int smem[];
  accscalar_t *out_cached = reinterpret_cast<accscalar_t*>(smem);

  int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  int block_size = blockDim.x * blockDim.y * blockDim.z;

  int batch_id = blockIdx.x % nbatch;
  int channel_id = blockIdx.x / nbatch;
  int channel_offset = threadIdx.x + channel_id * blockDim.x;

  // Cooperatively zero the whole accumulator buffer before any thread uses it.
  for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) {
    out_cached[i] = accscalar_t(0.0);
  }

  __syncthreads();

  out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];

  // Promote before multiplying so the per-batch offsets are computed in 64 bits.
  const int64_t batch = static_cast<int64_t>(batch_id);
  bottom_diff = bottom_diff + batch * height * width * channels;
  top_mask = top_mask + batch * pooled_height * pooled_width * channels;
  top_diff = top_diff + batch * pooled_height * pooled_width * channels;

  // Input tile owned by this block; threads stride through it by blockDim.z / blockDim.y.
  int iH = (height + gridDim.z-1) / gridDim.z;
  int iW = (width + gridDim.y-1) / gridDim.y;
  int istartH = threadIdx.z + blockIdx.z*iH;
  int iendH = ::min(static_cast<int64_t>(istartH)+iH, height);
  int istartW = threadIdx.y + blockIdx.y*iW;
  int iendW = ::min(static_cast<int64_t>(istartW)+iW, width);

  for (int ih = istartH; ih < iendH; ih+=blockDim.z) {
    int phstart = p_start(ih, pad_h, kernel_h, dilation_h, stride_h);
    int phend = p_end(ih, pad_h, pooled_height, stride_h);
    for (int iw = istartW; iw < iendW; iw+=blockDim.y) {
      int pwstart = p_start(iw, pad_w, kernel_w, dilation_w, stride_w);
      int pwend = p_end(iw, pad_w, pooled_width, stride_w);
      // Flat spatial index of this pixel, in the encoding stored in top_mask.
      const int64_t index_shift = static_cast<int64_t>(ih) * width + iw;
      scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels;
      if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
        // Zero or several candidate windows: accumulate in shared memory, then flush.

#if defined (USE_ROCM)
#define _MAXh 2
#define _MAXw 2
        // ROCm fast path for at most 2 x 2 candidate windows: load every
        // (mask, grad) pair first, then reduce.
        if (phend-phstart<=_MAXh && pwend-pwstart<=_MAXw) {
          int64_t msk[_MAXh][_MAXw];
          scalar_t tpd[_MAXh][_MAXw];
          int cached_index = threadIdx.x;
#pragma unroll
          for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
            const int64_t c_off = static_cast<int64_t>(c) * out_stride_c;
#pragma unroll
            for(int oh = 0; oh < _MAXh; ++oh) {
#pragma unroll
              for(int ow = 0; ow < _MAXw; ++ow) {
                int oh_ = oh+phstart;
                int ow_ = ow+pwstart;
                if (oh_ >= phend || ow_ >= pwend) {
                  // Unused slot: ~index_shift never equals index_shift.
                  msk[oh][ow] = ~index_shift;
                } else {
                  const int64_t out_offset = static_cast<int64_t>(oh_) * out_stride_h +
                                             static_cast<int64_t>(ow_) * out_stride_w + c_off;
                  msk[oh][ow] = top_mask[out_offset];
                  tpd[oh][ow] = top_diff[out_offset];
                }
              }
            }

            accscalar_t acm = 0;
#pragma unroll
            for(int oh = 0; oh < _MAXh; ++oh) {
#pragma unroll
              for(int ow = 0; ow < _MAXw; ++ow) {
                if (msk[oh][ow] == index_shift) {
                  acm += static_cast<accscalar_t>(tpd[oh][ow]);
                }
              }
            }
            out_cached[cached_index] += acm;
            cached_index += blockDim.x;
          }
        }
        else
#undef _MAXh
#undef _MAXw
#endif

        for(int oh = phstart; oh < phend; ++oh) {
          for(int ow = pwstart; ow < pwend; ++ow) {
            int cached_index = threadIdx.x;
            const int64_t out_offset = static_cast<int64_t>(oh) * out_stride_h +
                                       static_cast<int64_t>(ow) * out_stride_w;
            const int64_t* ptr_top_mask = top_mask + out_offset;
            const scalar_t* ptr_top_diff = top_diff + out_offset;
            for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
              const int64_t c_off = static_cast<int64_t>(c) * out_stride_c;
              if (ptr_top_mask[c_off] == index_shift) {
                out_cached[cached_index] += static_cast<accscalar_t>(ptr_top_diff[c_off]);
              }
              cached_index += blockDim.x;
            }
          }
        }
        // Flush the accumulated gradients and reset the slots for the next pixel.
        int cached_index = threadIdx.x;
        for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
          ptr_bottom_diff[c] = static_cast<scalar_t>(out_cached[cached_index]);
          out_cached[cached_index] = accscalar_t(0.0);
          cached_index += blockDim.x;
        }
      } else {
        // Exactly one candidate window: copy the gradient directly where the argmax matches.
        const int64_t out_offset = static_cast<int64_t>(phstart) * out_stride_h +
                                   static_cast<int64_t>(pwstart) * out_stride_w;
        const int64_t* ptr_top_mask = top_mask + out_offset;
        const scalar_t* ptr_top_diff = top_diff + out_offset;
        for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
          const int64_t c_off = static_cast<int64_t>(c) * out_stride_c;
          if (ptr_top_mask[c_off] == index_shift) {
            ptr_bottom_diff[c] = static_cast<scalar_t>(ptr_top_diff[c_off]);
          }
        }
      }
    }
  }
}

} // namespace

// CUDA forward of max_pool2d_with_indices.
//
// output and indices are expected to already have the pooled shape (their
// sizes are read below); this function only fills them. A 4-D input is
// (N, C, H, W); any other rank is treated as a single batch of (C, H, W).
// ChannelsLast inputs use the shared-memory NHWC kernel; Contiguous inputs use
// the one-thread-per-output NCHW kernel. Both use int32_t index math when
// can_use_int32_nhwc / can_use_int32_nchw allow it, and int64_t otherwise.
TORCH_IMPL_FUNC(max_pool2d_with_indices_out_cuda)
(const Tensor& input_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
  NoNamesGuard guard;

  TensorArg output_arg{ output, "output", 1 };
  TensorArg indices_arg{ indices, "indices", 2 };
  TensorArg input_arg{ input_, "input_", 3 };

  checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg});
  if (output.numel() == 0) {
    return;
  }

  // Single-element lists apply to both spatial dims; an empty stride defaults
  // to the kernel size.
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  const auto memory_format = input_.suggest_memory_format();

  const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1;
  const int64_t nInputPlane = input_.size(-3);
  const int64_t inputHeight = input_.size(-2);
  const int64_t inputWidth = input_.size(-1);

  const int64_t outputHeight = output.size(-2);
  const int64_t outputWidth = output.size(-1);

  // Densify in the suggested layout so the kernels can rely on its strides.
  Tensor input = input_.contiguous(memory_format);

  const int64_t in_stride_n = input_.ndimension() == 4 ? input.stride(-4) : 0;
  const int64_t in_stride_c = input.stride(-3);
  const int64_t in_stride_h = input.stride(-2);
  const int64_t in_stride_w = input.stride(-1);

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "max_pool2d_with_indices_out_cuda_frame",
    [&] {
      using accscalar_t = acc_type<scalar_t, true>;

      scalar_t *output_data = output.mutable_data_ptr<scalar_t>();
      const scalar_t *input_data = input.const_data_ptr<scalar_t>();
      int64_t *indices_data = indices.mutable_data_ptr<int64_t>();

      switch (memory_format) {
        case MemoryFormat::ChannelsLast: {
          // Block: x over channels (at most a warp at first), y over output
          // columns, z over output rows, within max_threads. block_x is then
          // widened again to use any remaining thread budget.
          const int max_threads = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
          int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
          int block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
          int block_y = std::min<int>(
              maxThreadsDim[1], std::min<int>(lastPow2(outputWidth), max_threads / block_x));
          int block_z = std::min<int>(
              maxThreadsDim[2], std::min<int>(lastPow2(outputHeight), max_threads / block_x / block_y));
          block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
          const dim3 block(block_x, block_y, block_z);

          bool use_int32 = can_use_int32_nhwc(
              nbatch, nInputPlane, inputHeight, inputWidth,
              outputHeight, outputWidth,
              in_stride_n, in_stride_c, in_stride_h, in_stride_w);

          // Channels are split into kernel_stride_C chunks (grid x) of at most
          // block_x * 4 channels; each lane handles kernel_size_C of them.
          int kernel_stride_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
          int kernel_size_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);

          int grid_x = nbatch*kernel_stride_C;
          int grid_y = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
              ceil_div(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE_FWD));
          int grid_z = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
              ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD));
          const dim3 grid(grid_x, grid_y, grid_z);

          // Shared memory: one index_t argmax slot plus one scalar_t value slot per element.
          size_t shmem_size;
          size_t mask_elems = static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z;

          if (use_int32) {
            shmem_size = mask_elems * (sizeof(int32_t) + sizeof(scalar_t));
            TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
                        "shared memory too small");
            max_pool_forward_nhwc<scalar_t, int32_t>
              <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
                input_data, static_cast<int>(nbatch),
                static_cast<int32_t>(nInputPlane),
                static_cast<int32_t>(inputHeight),
                static_cast<int32_t>(inputWidth),
                static_cast<int32_t>(outputHeight),
                static_cast<int32_t>(outputWidth),
                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                static_cast<int32_t>(in_stride_n),
                static_cast<int32_t>(in_stride_c),
                static_cast<int32_t>(in_stride_h),
                static_cast<int32_t>(in_stride_w),
                kernel_stride_C, kernel_size_C,
                output_data, indices_data);
          } else {
            shmem_size = mask_elems * (sizeof(int64_t) + sizeof(scalar_t));
            TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
                        "shared memory too small");
            max_pool_forward_nhwc<scalar_t, int64_t>
              <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
                input_data, static_cast<int>(nbatch),
                nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
                kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                in_stride_n, in_stride_c, in_stride_h, in_stride_w,
                kernel_stride_C, kernel_size_C,
                output_data, indices_data);
          }
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        case MemoryFormat::Contiguous: {
          // One logical thread per output element. The grid is clamped to
          // maxGridSize[0], and the kernel's CUDA_KERNEL_LOOP_TYPE loop is
          // relied on to cover the rest.
          const int threads = std::min(
              at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
              BLOCK_THREADS);
          const int64_t nthreads = output.numel();
          bool use_int32 = can_use_int32_nchw(
              nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
          const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
          const int blocks = static_cast<int>(std::min<int64_t>(
              ceil_div(nthreads, static_cast<int64_t>(threads)),
              static_cast<int64_t>(maxGridX)));
          auto stream = at::cuda::getCurrentCUDAStream();
          if (use_int32) {
            max_pool_forward_nchw<scalar_t, int32_t>
                <<<blocks, threads, 0, stream>>>(
                    static_cast<int32_t>(nthreads),
                    input_data,
                    static_cast<int32_t>(nInputPlane),
                    static_cast<int32_t>(inputHeight),
                    static_cast<int32_t>(inputWidth),
                    static_cast<int32_t>(outputHeight),
                    static_cast<int32_t>(outputWidth),
                    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                    output_data, indices_data);
          } else {
            max_pool_forward_nchw<scalar_t, int64_t>
                <<<blocks, threads, 0, stream>>>(
                    nthreads,
                    input_data,
                    nInputPlane,
                    inputHeight,
                    inputWidth,
                    outputHeight,
                    outputWidth,
                    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                    output_data, indices_data);
          }
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
      }
    }
  );
}

// CUDA backward of max_pool2d_with_indices: scatters gradOutput back to
// gradInput at the argmax positions recorded in indices.
//
// A 4-D input is (N, C, H, W); any other rank is treated as a single batch.
// gradInput is zeroed first, then:
//  * ChannelsLast uses the shared-memory NHWC gather kernel;
//  * Contiguous uses the NCHW gather kernel, with int32_t index math when
//    can_use_int32_nchw allows it.
TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_cuda)
(const Tensor& gradOutput_,
const Tensor& input_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode,
const Tensor& indices_,
const Tensor& gradInput) {
  NoNamesGuard guard;

  TensorArg gradInput_arg{ gradInput, "gradInput", 1 };
  TensorArg gradOutput_arg{ gradOutput_, "gradOutput_", 2 };
  TensorArg input_arg{ input_, "input_", 3 };
  TensorArg indices_arg{ indices_, "indices", 4 };

  checkAllSameGPU(__func__,
                  {gradInput_arg, gradOutput_arg, input_arg, indices_arg});
  if (gradOutput_.numel() == 0) {
    return;
  }

  // Single-element lists apply to both spatial dims; an empty stride defaults
  // to the kernel size.
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  const auto memory_format = input_.suggest_memory_format();

  const Tensor input = input_.contiguous(memory_format);

  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  const Tensor gradOutput = gradOutput_.contiguous(memory_format);

  const int64_t outputHeight = gradOutput.size(-2);
  const int64_t outputWidth = gradOutput.size(-1);

  const int64_t out_stride_c = gradOutput.stride(-3);
  const int64_t out_stride_h = gradOutput.stride(-2);
  const int64_t out_stride_w = gradOutput.stride(-1);

  const Tensor indices = indices_.contiguous(memory_format);

  // Required: the NHWC kernel's single-window path writes only channels whose
  // argmax matches, leaving every other gradInput entry untouched.
  gradInput.zero_();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
    "max_pool2d_with_indices_out_cuda_frame",
    [&] {
      using accscalar_t = acc_type<scalar_t, true>;

      const scalar_t *gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
      scalar_t *gradInput_data = gradInput.mutable_data_ptr<scalar_t>();
      const int64_t *indices_data = indices.const_data_ptr<int64_t>();

      switch (memory_format) {
        case MemoryFormat::ChannelsLast: {
          // Block: x over channels, y over input columns, z over input rows.
          const int max_threads = std::min<int>(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
          int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
          int block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), at::cuda::warp_size()));
          int block_y = std::min<int>(
              maxThreadsDim[1], std::min<int>(lastPow2(inputWidth), max_threads / block_x));
          int block_z = std::min<int>(
              maxThreadsDim[2], std::min<int>(lastPow2(inputHeight), max_threads / block_x / block_y));
          block_x = std::min<int>(
              maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
          const dim3 block(block_x, block_y, block_z);

          int kernel_stride_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
          int kernel_size_C = ceil_div(
              safe_downcast<int, int64_t>(nInputPlane), block_x * kernel_stride_C);

          int grid_x = nbatch*kernel_stride_C;
          int grid_y = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
              ceil_div(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE_BWD));
          int grid_z = std::min<int>(
              at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
              ceil_div(safe_downcast<int, int64_t>(inputHeight), block_z*BLOCK_STRIDE_BWD));
          const dim3 grid(grid_x, grid_y, grid_z);

          // One accscalar_t accumulator per (channel slot, thread); computed in
          // size_t and checked with a user-facing error, like the forward launcher.
          const size_t shmem_size =
              static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z * sizeof(accscalar_t);
          TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
                      "shared memory too small");

          // The backward kernel is launched on input instead output.
          // If it is launched on output layer, atomic_add would not provide much benefit on FP16.
          // Please check comments at https://github.com/pytorch/pytorch/pull/34519.
          max_pool_backward_nhwc<scalar_t, accscalar_t>
          <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
                  gradOutput_data,
                  indices_data,
                  nbatch,
                  nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
                  kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                  out_stride_c, out_stride_h, out_stride_w,
                  kernel_stride_C, kernel_size_C,
                  gradInput_data);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        case MemoryFormat::Contiguous: {
          const int threads = std::min(
              at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock,
              BLOCK_THREADS);
          // Pixels per plane, kept in 64 bits: on the int64 path height * width
          // may exceed INT_MAX. The grid is clamped to maxGridSize[0], as the
          // forward contiguous path also does; the kernel's CUDA_KERNEL_LOOP_TYPE
          // loop is relied on to cover the remaining pixels.
          const int64_t imgcount = inputWidth * inputHeight;
          const int maxGridX = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
          const int maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
          const int maxGridZ = at::cuda::getCurrentDeviceProperties()->maxGridSize[2];
          const int blocks_x = static_cast<int>(std::min<int64_t>(
              ceil_div(imgcount, static_cast<int64_t>(threads)),
              static_cast<int64_t>(maxGridX)));
          dim3 grid(blocks_x, static_cast<unsigned>(std::min<int64_t>(nbatch, maxGridY)), static_cast<unsigned>(std::min<int64_t>(nInputPlane, maxGridZ)));
          bool use_int32 = can_use_int32_nchw(
              nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth);
          auto stream = at::cuda::getCurrentCUDAStream();
          if (use_int32) {
            max_pool_backward_nchw<scalar_t, accscalar_t, int32_t>
                <<<grid, threads, 0, stream>>>(
                    gradOutput_data,
                    indices_data,
                    static_cast<int32_t>(nbatch),
                    static_cast<int32_t>(nInputPlane),
                    static_cast<int32_t>(inputHeight),
                    static_cast<int32_t>(inputWidth),
                    static_cast<int32_t>(outputHeight),
                    static_cast<int32_t>(outputWidth),
                    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                    gradInput_data);
          } else {
            max_pool_backward_nchw<scalar_t, accscalar_t, int64_t>
                <<<grid, threads, 0, stream>>>(
                    gradOutput_data,
                    indices_data,
                    nbatch,
                    nInputPlane,
                    inputHeight,
                    inputWidth,
                    outputHeight,
                    outputWidth,
                    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
                    gradInput_data);
          }
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          break;
        }
        default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
      }
    }
  );
}

} // at::native
