// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCPP_BARRIER
#define _LIBCPP_BARRIER

/*
    barrier synopsis

namespace std
{

  template<class CompletionFunction = see below>
  class barrier                                   // since C++20
  {
  public:
    using arrival_token = see below;

    static constexpr ptrdiff_t max() noexcept;

    constexpr explicit barrier(ptrdiff_t phase_count,
                               CompletionFunction f = CompletionFunction());
    ~barrier();

    barrier(const barrier&) = delete;
    barrier& operator=(const barrier&) = delete;

    [[nodiscard]] arrival_token arrive(ptrdiff_t update = 1);
    void wait(arrival_token&& arrival) const;

    void arrive_and_wait();
    void arrive_and_drop();

  private:
    CompletionFunction completion; // exposition only
  };

}

*/

#if __cplusplus < 201103L && defined(_LIBCPP_USE_FROZEN_CXX03_HEADERS)
#  include <__cxx03/barrier>
#else
#  include <__config>

#  if _LIBCPP_HAS_THREADS

#    include <__assert>
#    include <__atomic/atomic.h>
#    include <__atomic/memory_order.h>
#    include <__cstddef/ptrdiff_t.h>
#    include <__memory/unique_ptr.h>
#    include <__thread/poll_with_backoff.h>
#    include <__thread/timed_backoff_policy.h>
#    include <__utility/move.h>
#    include <cstdint>
#    include <limits>
#    include <version>

#    if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
#      pragma GCC system_header
#    endif

_LIBCPP_PUSH_MACROS
#    include <__undef_macros>

#    if _LIBCPP_STD_VER >= 20

_LIBCPP_BEGIN_NAMESPACE_STD

// Default completion function for std::barrier: a callable that does nothing
// when invoked and never throws.
struct __empty_completion {
  _LIBCPP_HIDE_FROM_ABI void operator()() noexcept {}
};

/*

The default implementation of __barrier_base is a classic tree barrier.

It looks different from the pseudocode found in the literature for two main reasons:
 1. Threads that call into std::barrier functions do not provide indices,
    so a numbering step is added before the actual barrier algorithm,
    appearing as an extra (N+1)th round on top of the N rounds of the tree barrier.
 2. A great deal of attention has been paid to avoiding cache line thrashing
    by flattening the tree structure into cache-line sized arrays that
    are indexed in an efficient way.

*/

// Integer type used to identify a barrier phase. It doubles as the
// arrival_token type handed out by __barrier_base::arrive(). Being an 8-bit
// unsigned type, phase values wrap around.
using __barrier_phase_t _LIBCPP_NODEBUG = uint8_t;

// Opaque state of the barrier algorithm. Only declared here; this header
// manipulates it exclusively through the three exported functions below.
class __barrier_algorithm_base;

// Creates the algorithm state for a barrier. __expected is taken by non-const
// reference, so the callee is able to modify the caller's count.
// NOTE(review): whether (and how) it actually adjusts __expected is not visible
// from this header -- confirm against the out-of-line definition.
_LIBCPP_AVAILABILITY_SYNC _LIBCPP_EXPORTED_FROM_ABI __barrier_algorithm_base*
__construct_barrier_algorithm_base(ptrdiff_t& __expected);

// Registers one arrival for the phase identified by __old_phase.
// __barrier_base::arrive() treats a `true` result as "the phase is complete":
// it then runs the completion function and advances the phase.
_LIBCPP_AVAILABILITY_SYNC _LIBCPP_EXPORTED_FROM_ABI bool
__arrive_barrier_algorithm_base(__barrier_algorithm_base* __barrier, __barrier_phase_t __old_phase) noexcept;

// Releases state obtained from __construct_barrier_algorithm_base(). Used as
// the deleter of the unique_ptr that owns the state in __barrier_base.
_LIBCPP_AVAILABILITY_SYNC _LIBCPP_EXPORTED_FROM_ABI void
__destroy_barrier_algorithm_base(__barrier_algorithm_base* __barrier) noexcept;

// Implementation class behind std::barrier. It owns the opaque algorithm state
// (__barrier_algorithm_base), the user's completion function and the current
// phase counter, and implements arrive / wait / arrive_and_drop on top of the
// exported __*_barrier_algorithm_base functions.
template <class _CompletionF>
class __barrier_base {
  // Expected number of arrivals. Must stay declared before __base_: the
  // constructor passes it by reference while initializing __base_, and members
  // are initialized in declaration order.
  ptrdiff_t __expected_;
  // Owns the algorithm state; released through __destroy_barrier_algorithm_base.
  unique_ptr<__barrier_algorithm_base, void (*)(__barrier_algorithm_base*)> __base_;
  // Pending change to __expected_. Decremented by arrive_and_drop() and folded
  // into __expected_ (then reset to 0) when a phase completes.
  atomic<ptrdiff_t> __expected_adjustment_;
  // Invoked once each time a phase completes.
  _CompletionF __completion_;
  // Current phase. Starts at 0 and advances by 2 on every phase completion.
  atomic<__barrier_phase_t> __phase_;

public:
  // Token returned by arrive() and consumed by wait(): the phase value that
  // was current when the thread arrived.
  using arrival_token = __barrier_phase_t;

  // Largest expected count this implementation reports supporting.
  static _LIBCPP_HIDE_FROM_ABI constexpr ptrdiff_t max() noexcept { return numeric_limits<ptrdiff_t>::max(); }

  // Builds a barrier expecting __expected arrivals per phase and stores
  // __completion to be run on phase completion. The initializers below are
  // written in member declaration order, which is the order they execute in.
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI
  __barrier_base(ptrdiff_t __expected, _CompletionF __completion = _CompletionF())
      : __expected_(__expected),
        __base_(std::__construct_barrier_algorithm_base(this->__expected_), &__destroy_barrier_algorithm_base),
        __expected_adjustment_(0),
        __completion_(std::move(__completion)),
        __phase_(0) {}
  // Registers __update arrivals for the current phase and returns the phase
  // value observed on entry as the arrival token.
  [[nodiscard]] _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI arrival_token arrive(ptrdiff_t __update) {
    // Precondition: a single call may not arrive more often than the phase expects.
    _LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(
        __update <= __expected_, "update is greater than the expected count for the current barrier phase");

    auto const __old_phase = __phase_.load(memory_order_relaxed);
    // One call into the algorithm per unit of __update.
    for (; __update; --__update)
      // A true result is handled as "this arrival completed the phase".
      // NOTE(review): presumably true is only returned for the last expected
      // arrival of the phase -- confirm against the out-of-line definition.
      if (__arrive_barrier_algorithm_base(__base_.get(), __old_phase)) {
        __completion_();
        // Apply the drops recorded by arrive_and_drop() to the expected count,
        // then clear the pending adjustment.
        __expected_ += __expected_adjustment_.load(memory_order_relaxed);
        __expected_adjustment_.store(0, memory_order_relaxed);
        // Publish the new phase. The release store pairs with the acquire load
        // in wait(), so waiters that observe the new phase also observe the
        // writes above. The phase moves in steps of 2.
        // NOTE(review): the intermediate value is presumably used internally by
        // the algorithm -- confirm against the out-of-line definition.
        __phase_.store(__old_phase + 2, memory_order_release);
        // Notifies threads blocked in an atomic wait on __phase_. Note that
        // wait() in this class polls with backoff instead of calling
        // __phase_.wait().
        __phase_.notify_all();
      }
    return __old_phase;
  }
  // Blocks until the phase differs from the one recorded in the token, polling
  // __phase_ with the timed backoff policy.
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void wait(arrival_token&& __old_phase) const {
    // The token is captured by value; the acquire load pairs with the release
    // store performed in arrive() when the phase completes.
    auto const __test_fn = [this, __old_phase]() -> bool { return __phase_.load(memory_order_acquire) != __old_phase; };
    std::__libcpp_thread_poll_with_backoff(__test_fn, __libcpp_timed_backoff_policy());
  }
  // Arrives once for the current phase and records that the expected count
  // should drop by one when the phase completes. Does not wait.
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void arrive_and_drop() {
    // Record the drop before arriving so that, if this arrival completes the
    // phase, the adjustment is already visible to the code in arrive().
    __expected_adjustment_.fetch_sub(1, memory_order_relaxed);
    // The arrival token is intentionally discarded.
    (void)arrive(1);
  }
};

template <class _CompletionF = __empty_completion>
class barrier {
  __barrier_base<_CompletionF> __b_;

public:
  using arrival_token = typename __barrier_base<_CompletionF>::arrival_token;

  static _LIBCPP_HIDE_FROM_ABI constexpr ptrdiff_t max() noexcept { return __barrier_base<_CompletionF>::max(); }

  _LIBCPP_AVAILABILITY_SYNC
  _LIBCPP_HIDE_FROM_ABI explicit barrier(ptrdiff_t __count, _CompletionF __completion = _CompletionF())
      : __b_(__count, std::move(__completion)) {
    _LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(
        __count >= 0,
        "barrier::barrier(ptrdiff_t, CompletionFunction): barrier cannot be initialized with a negative value");
    _LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(
        __count <= max(),
        "barrier::barrier(ptrdiff_t, CompletionFunction): barrier cannot be initialized with "
        "a value greater than max()");
  }

  barrier(barrier const&)            = delete;
  barrier& operator=(barrier const&) = delete;

  [[nodiscard]] _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI arrival_token arrive(ptrdiff_t __update = 1) {
    _LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(__update > 0, "barrier:arrive must be called with a value greater than 0");
    return __b_.arrive(__update);
  }
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void wait(arrival_token&& __phase) const {
    __b_.wait(std::move(__phase));
  }
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void arrive_and_wait() { wait(arrive()); }
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void arrive_and_drop() { __b_.arrive_and_drop(); }
};

_LIBCPP_END_NAMESPACE_STD

#    endif // _LIBCPP_STD_VER >= 20

_LIBCPP_POP_MACROS

#  endif // _LIBCPP_HAS_THREADS

#  if !defined(_LIBCPP_REMOVE_TRANSITIVE_INCLUDES) && _LIBCPP_STD_VER <= 20
#    include <atomic>
#    include <concepts>
#    include <iterator>
#    include <memory>
#    include <stdexcept>
#    include <variant>
#  endif
#endif // __cplusplus < 201103L && defined(_LIBCPP_USE_FROZEN_CXX03_HEADERS)

#endif // _LIBCPP_BARRIER
