// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier:  MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>

// Maps an input element type to the set of companion types used by smoothquant.
// Only the primary template is declared here; the explicit specializations that
// follow in this file supply the aliases for ck_tile::half_t and ck_tile::bf16_t.
// Using any other DataType yields an incomplete type.
template <typename DataType>
struct SmoothquantTypeConfig;

// Type bundle selected when the element type is ck_tile::half_t.
template <>
struct SmoothquantTypeConfig<ck_tile::half_t>
{
    // Element types: half_t for X, int8 for QY.
    using XDataType  = ck_tile::half_t;
    using QYDataType = ck_tile::int8_t;

    // Both scale types and the compute type are single-precision float.
    using ComputeDataType     = float;
    using SmoothScaleDataType = float;
    using YScaleDataType      = float;
};

// Type bundle selected when the element type is ck_tile::bf16_t.
template <>
struct SmoothquantTypeConfig<ck_tile::bf16_t>
{
    // Element types: bf16_t for X, int8 for QY.
    using XDataType  = ck_tile::bf16_t;
    using QYDataType = ck_tile::int8_t;

    // Both scale types and the compute type are single-precision float.
    using ComputeDataType     = float;
    using SmoothScaleDataType = float;
    using YScaleDataType      = float;
};

// Runtime arguments for the smoothquant entry points.
// Adds no members of its own: it is ck_tile::SmoothquantHostArgs under a local name
// (struct inheritance is public by default).
struct smoothquant_args : ck_tile::SmoothquantHostArgs
{
};

// This is used to pattern-match the internal kernel implementation, not to instantiate the kernel.
// Bundles the element type and the tiling parameters of one configuration as
// compile-time members (constants and ck_tile::sequence / block-shape aliases).
template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // per-thread repeat count along M
          ck_tile::index_t Repeat_N_,         // per-thread repeat count along N
          ck_tile::index_t ThreadPerBlock_M_, // number of threads along M
          ck_tile::index_t ThreadPerBlock_N_, // number of threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kTwoPass_>
struct smoothquant_traits_
{
    // Element type with cv/reference qualifiers removed.
    using DataType = ck_tile::remove_cvref_t<DataType_>;

    // Boolean switches, forwarded unchanged from the template arguments.
    static constexpr bool kTwoPass = kTwoPass_;
    static constexpr bool kPadN    = kPadN_;

    // Per-thread repeat counts, forwarded unchanged.
    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;

    // Block tile extents:
    //   M = repeats * threads
    //   N = repeats * threads * vector size
    static constexpr ck_tile::index_t Block_N = Repeat_N * ThreadPerBlock_N_ * Vector_N_;
    static constexpr ck_tile::index_t Block_M = Repeat_M * ThreadPerBlock_M_;

    // <M, N> descriptors; the vector size along M is fixed to 1.
    using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
    using Vector         = ck_tile::sequence<1, Vector_N_>;
    using BlockTile      = ck_tile::sequence<Block_M, Block_N>;

    // Block shape assembled from the three descriptors above.
    using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
};

// Entry point templated on a traits type; declared here, defined elsewhere.
// Traits_ is presumably an instantiation of smoothquant_traits_ — TODO confirm.
//   s: stream configuration, taken by const reference.
//   a: runtime arguments, taken by value.
// The meaning of the returned float is not visible in this header.
template <typename Traits_>
float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);

// This is the public API, will be generated by script
// Runtime (non-template) description of a configuration. It carries only the
// data type, as a string; the accepted string values are not visible in this header.
struct smoothquant_traits
{
    std::string data_type;
};

// Public entry point templated on the element type; declared here, defined elsewhere.
// Takes the runtime arguments by value and the stream configuration by const reference.
// Note that the argument order is the reverse of smoothquant_ (arguments first, stream second).
// The meaning of the returned float is not visible in this header.
template <typename DataType>
float smoothquant(smoothquant_args, const ck_tile::stream_config&);
