From 49cfc79b83af125dfc19ee724fd031fecc7550ee Mon Sep 17 00:00:00 2001 From: Carsten Patzke <carsten.patzke@desy.de> Date: Fri, 3 Apr 2020 17:30:30 +0200 Subject: [PATCH] [asapo::fabric] Lot of changes. But in general implemented a better timeout detection --- .../cpp/include/asapo_fabric/fabric_error.h | 5 + common/cpp/src/asapo_fabric/CMakeLists.txt | 6 +- .../common/fabric_context_impl.cpp | 61 +++++++--- .../asapo_fabric/common/fabric_context_impl.h | 106 ++++++++++++++---- .../task/fabric_alive_check_response_task.cpp | 28 +++++ .../task/fabric_alive_check_response_task.h | 25 +++++ .../task/fabric_self_deleting_task.cpp | 5 +- .../task/fabric_self_deleting_task.h | 4 +- .../task/fabric_self_requeuing_task.cpp | 47 ++++++++ .../common/task/fabric_self_requeuing_task.h | 38 +++++++ .../common/{ => task}/fabric_task.h | 2 +- .../{ => task}/fabric_waitable_task.cpp | 8 +- .../common/{ => task}/fabric_waitable_task.h | 2 +- .../asapo_fabric/fabric_internal_error.cpp | 5 +- .../server/fabric_server_impl.cpp | 5 +- .../task/fabric_handshake_accepting_task.cpp | 15 ++- .../task/fabric_handshake_accepting_task.h | 4 +- .../server/task/fabric_recv_any_task.cpp | 2 +- .../server/task/fabric_recv_any_task.h | 4 +- .../client_lazy_initialization.cpp | 3 +- .../asapo_fabric/parallel_data_transfer.cpp | 29 +++-- .../asapo_fabric/server_not_running.cpp | 18 ++- .../asapo_fabric/simple_data_transfer.cpp | 86 ++++++++------ tests/automatic/asapo_fabric/timeout_test.cpp | 89 +++++++++++---- .../asapo_fabric/wrong_memory_info.cpp | 26 ++++- 25 files changed, 485 insertions(+), 138 deletions(-) create mode 100644 common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.cpp create mode 100644 common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.h rename common/cpp/src/asapo_fabric/{server => common}/task/fabric_self_deleting_task.cpp (71%) rename common/cpp/src/asapo_fabric/{server => common}/task/fabric_self_deleting_task.h (79%) create mode 
100644 common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.cpp create mode 100644 common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.h rename common/cpp/src/asapo_fabric/common/{ => task}/fabric_task.h (80%) rename common/cpp/src/asapo_fabric/common/{ => task}/fabric_waitable_task.cpp (80%) rename common/cpp/src/asapo_fabric/common/{ => task}/fabric_waitable_task.h (89%) diff --git a/common/cpp/include/asapo_fabric/fabric_error.h b/common/cpp/include/asapo_fabric/fabric_error.h index 6ae2fc377..854d677f4 100644 --- a/common/cpp/include/asapo_fabric/fabric_error.h +++ b/common/cpp/include/asapo_fabric/fabric_error.h @@ -8,6 +8,7 @@ enum class FabricErrorType { kOutdatedLibrary, kInternalError, // An error that was produced by LibFabric kInternalOperationCanceled, // An error that was produced by LibFabric + kInternalConnectionError, // This might occur when the connection is unexpectedly closed kNoDeviceFound, kClientNotInitialized, kTimeout, @@ -46,6 +47,10 @@ auto const kConnectionRefusedError = FabricErrorTemplate { "Connection refused", FabricErrorType::kConnectionRefused }; +auto const kInternalConnectionError = FabricErrorTemplate { + "Connection error (maybe a disconnect?)", + FabricErrorType::kInternalConnectionError +}; } } diff --git a/common/cpp/src/asapo_fabric/CMakeLists.txt b/common/cpp/src/asapo_fabric/CMakeLists.txt index b504d5a72..fefef7639 100644 --- a/common/cpp/src/asapo_fabric/CMakeLists.txt +++ b/common/cpp/src/asapo_fabric/CMakeLists.txt @@ -10,12 +10,14 @@ IF(ENABLE_LIBFABRIC) fabric_factory_impl.cpp common/fabric_context_impl.cpp common/fabric_memory_region_impl.cpp - common/fabric_waitable_task.cpp + common/task/fabric_waitable_task.cpp + common/task/fabric_self_deleting_task.cpp + common/task/fabric_self_requeuing_task.cpp + common/task/fabric_alive_check_response_task.cpp client/fabric_client_impl.cpp server/fabric_server_impl.cpp server/task/fabric_recv_any_task.cpp 
server/task/fabric_handshake_accepting_task.cpp - server/task/fabric_self_deleting_task.cpp ) ELSE() set(SOURCE_FILES ${SOURCE_FILES} diff --git a/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp b/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp index f7355db61..51d627f11 100644 --- a/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp +++ b/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp @@ -32,8 +32,7 @@ std::string __PRETTY_FUNCTION_TO_NAMESPACE__(const std::string& prettyFunction) // TODO: It is super important that version 1.10 is installed, but since its not released yet we go with 1.9 const uint32_t FabricContextImpl::kMinExpectedLibFabricVersion = FI_VERSION(1, 9); -FabricContextImpl::FabricContextImpl() : io__{ GenerateDefaultIO() } { - +FabricContextImpl::FabricContextImpl() : io__{ GenerateDefaultIO() }, alive_check_response_task_(this) { } FabricContextImpl::~FabricContextImpl() { @@ -69,12 +68,13 @@ void FabricContextImpl::InitCommon(const std::string& networkIpHint, uint16_t se if (networkIpHint == "127.0.0.1") { // sockets mode hints->fabric_attr->prov_name = strdup("sockets"); - hints->ep_attr->type = FI_EP_RDM; + hotfix_using_sockets_ = true; } else { // verbs mode hints->fabric_attr->prov_name = strdup("verbs;ofi_rxm"); - hints->caps = FI_TAGGED | FI_RMA | FI_DIRECTED_RECV | additionalFlags; } + hints->ep_attr->type = FI_EP_RDM; + hints->caps = FI_TAGGED | FI_RMA | FI_DIRECTED_RECV | additionalFlags; if (isServer) { hints->src_addr = strdup(networkIpHint.c_str()); @@ -168,32 +168,32 @@ std::unique_ptr<FabricMemoryRegion> FabricContextImpl::ShareMemoryRegion(void* s void FabricContextImpl::Send(FabricAddress dstAddress, FabricMessageId messageId, const void* src, size_t size, Error* error) { - HandleFiCommandWithBasicTaskAndWait(fi_tsend, error, - endpoint_, src, size, nullptr, dstAddress, messageId); + HandleFiCommandWithBasicTaskAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, error, + fi_tsend, src, size, nullptr, 
dstAddress, messageId); } void FabricContextImpl::Recv(FabricAddress srcAddress, FabricMessageId messageId, void* dst, size_t size, Error* error) { - HandleFiCommandWithBasicTaskAndWait(fi_trecv, error, - endpoint_, dst, size, nullptr, srcAddress, messageId, 0); + HandleFiCommandWithBasicTaskAndWait(srcAddress, error, + fi_trecv, dst, size, nullptr, srcAddress, messageId, kRecvTaggedExactMatch); } void FabricContextImpl::RawSend(FabricAddress dstAddress, const void* src, size_t size, Error* error) { - HandleFiCommandWithBasicTaskAndWait(fi_send, error, - endpoint_, src, size, nullptr, dstAddress); + HandleFiCommandWithBasicTaskAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, error, + fi_send, src, size, nullptr, dstAddress); } void FabricContextImpl::RawRecv(FabricAddress srcAddress, void* dst, size_t size, Error* error) { - HandleFiCommandWithBasicTaskAndWait(fi_recv, error, - endpoint_, dst, size, nullptr, srcAddress); + HandleFiCommandWithBasicTaskAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, error, + fi_recv, dst, size, nullptr, srcAddress); } void FabricContextImpl::RdmaWrite(FabricAddress dstAddress, const MemoryRegionDetails* details, const void* buffer, size_t size, Error* error) { - HandleFiCommandWithBasicTaskAndWait(fi_write, error, - endpoint_, buffer, size, nullptr, dstAddress, details->addr, details->key); + HandleFiCommandWithBasicTaskAndWait(dstAddress, error, + fi_write, buffer, size, nullptr, dstAddress, details->addr, details->key); } @@ -203,6 +203,8 @@ void FabricContextImpl::StartBackgroundThreads() { completion_thread_ = io__->NewThread("ASAPO/FI/CQ", [this]() { CompletionThread(); }); + + alive_check_response_task_.Start(); } void FabricContextImpl::StopBackgroundThreads() { @@ -211,6 +213,8 @@ void FabricContextImpl::StopBackgroundThreads() { completion_thread_->join(); completion_thread_ = nullptr; } + + alive_check_response_task_.Stop(); } void FabricContextImpl::CompletionThread() { @@ -225,29 +229,50 @@ void FabricContextImpl::CompletionThread() { 
continue; // No data } - if (ret == -FI_EAVAIL) { + // TODO Refactor, maybe put it in other functions and/or switch case of ret + + if (ret == -FI_EAVAIL) { // An error is available fi_cq_err_entry errEntry{}; ret = fi_cq_readerr(completion_queue_, &errEntry, 0); if (ret != 1) { error = ErrorFromFabricInternal("Unknown error while fi_cq_readerr", ret); } else { auto task = (FabricWaitableTask*)(errEntry.op_context); - task->HandleErrorCompletion(&errEntry); + if (task) { + task->HandleErrorCompletion(&errEntry); + } else if (hotfix_using_sockets_) { + printf("[Known Sockets bug libfabric/#5795] Ignoring nullptr task!\n"); + } else { + error = FabricErrorTemplates::kInternalError.Generate("nullptr context from fi_cq_readerr"); + } } continue; } - if (ret != 1) { + if (ret != 1) { // We expect to receive 1 event error = ErrorFromFabricInternal("Unknown error while fi_cq_readfrom", ret); break; } auto task = (FabricWaitableTask*)(entry.op_context); - task->HandleCompletion(&entry, tmpAddress); + if (task) { + task->HandleCompletion(&entry, tmpAddress); + } else { + error = FabricErrorTemplates::kInternalError.Generate("nullptr context from fi_cq_sreadfrom"); + } } if (error) { throw std::runtime_error("ASAPO Fabric CompletionThread exited with error: " + error->Explain()); } } + +bool FabricContextImpl::TargetIsAliveCheck(FabricAddress address) { + Error error; + + HandleFiCommandWithBasicTaskAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, &error, + fi_tsend, nullptr, 0, nullptr, address, FI_ASAPO_TAG_ALIVE_CHECK); + // If the send was successful, then we are still able to communicate with the peer + return !(error != nullptr); +} diff --git a/common/cpp/src/asapo_fabric/common/fabric_context_impl.h b/common/cpp/src/asapo_fabric/common/fabric_context_impl.h index 8bfd14ebb..0925be332 100644 --- a/common/cpp/src/asapo_fabric/common/fabric_context_impl.h +++ b/common/cpp/src/asapo_fabric/common/fabric_context_impl.h @@ -7,11 +7,35 @@ #include <memory> #include 
<asapo_fabric/asapo_fabric.h> #include <thread> -#include "fabric_waitable_task.h" +#include "task/fabric_waitable_task.h" #include "../fabric_internal_error.h" +#include "task/fabric_alive_check_response_task.h" namespace asapo { namespace fabric { + +#define FI_ASAPO_ADDR_NO_ALIVE_CHECK FI_ADDR_NOTAVAIL +#define FI_ASAPO_TAG_ALIVE_CHECK ((uint64_t) -1) + +/** + * TODO: State of the bandages used in asapo to use RXM + * If you read this _in the future_ there are hopefully fixes for the following topics: + * Since RXM is connectionless, we do not know when a disconnect occurs. + * - Therefore when we try to receive data, we have added a targetAddress to HandleFiCommandAndWait, + * which might check if the peer is still responding to pings when a timeout occurs. + * + * Another issue is that in order to send data all addresses have to be added in an addressVector, + * unfortunately, this is also required to respond to a request. + * - So we added a handshake procedure that sends the local address of the client with a handshake to the server. + * This could be fixed by FI_SOURCE_ERR, which automatically + * adds new connections to the AV which would obsolete the handshake. + * At the time of writing this, FI_SOURCE_ERR is not supported with verbs;ofi_rxm + */ + + +const static uint64_t kRecvTaggedAnyMatch = ~0ULL; +const static uint64_t kRecvTaggedExactMatch = 0; + // TODO Use a serialization framework struct FabricHandshakePayload { // Hostnames can be up to 256 Bytes long. We also need to store the port number.
@@ -19,9 +43,14 @@ struct FabricHandshakePayload { }; class FabricContextImpl : public FabricContext { + friend class FabricSelfRequeuingTask; + friend class FabricAliveCheckResponseTask; public: std::unique_ptr<IO> io__; + protected: + FabricAliveCheckResponseTask alive_check_response_task_; + fi_info* fabric_info_{}; fid_fabric* fabric_{}; fid_domain* domain_{}; @@ -29,10 +58,15 @@ class FabricContextImpl : public FabricContext { fid_av* address_vector_{}; fid_ep* endpoint_{}; - uint64_t requestTimeoutMs_ = 10000; // 10 sec should be enough. TODO: maybe make a public variable setter + uint64_t requestEnqueueTimeoutMs_ = 10000; // 10 sec for queuing a task + uint64_t requestTimeoutMs_ = 20000; // 20 sec to complete a task, otherwise a ping will be sent + uint32_t maxTimeoutRetires_ = 5; // Timeout retries, if one of them fails, the task will fail with a timeout std::unique_ptr<std::thread> completion_thread_; bool background_threads_running_ = false; + private: + // Unfortunately when a client disconnects on sockets, a weird completion is generated. See libfabric/#5795 + bool hotfix_using_sockets_ = false; public: explicit FabricContextImpl(); virtual ~FabricContextImpl(); @@ -50,7 +84,7 @@ class FabricContextImpl : public FabricContext { void Recv(FabricAddress srcAddress, FabricMessageId messageId, void* dst, size_t size, Error* error) override; - /// Without message id + /// Without message id - No alive check! void RawSend(FabricAddress dstAddress, const void* src, size_t size, Error* error); void RawRecv(FabricAddress srcAddress, @@ -68,45 +102,76 @@ class FabricContextImpl : public FabricContext { void StartBackgroundThreads(); void StopBackgroundThreads(); + // If the targetAddress is FI_ASAPO_ADDR_NO_ALIVE_CHECK and a timeout occurs, no further ping is being done + // Alive check is generally only necessary if you are trying to receive data or RDMA send. template<class FuncType, class...
ArgTypes> - inline void HandleFiCommandWithBasicTaskAndWait(FuncType func, Error* error, ArgTypes... args) { + inline void HandleFiCommandWithBasicTaskAndWait(FabricAddress targetAddress, Error* error, + FuncType func, ArgTypes... args) { FabricWaitableTask task; - HandleFiCommandAndWait(func, &task, error, args...); + HandleFiCommandAndWait(targetAddress, &task, error, func, args...); } + // If the targetAddress is FI_ASAPO_ADDR_NO_ALIVE_CHECK and a timeout occurs, no further ping is being done. + // Alive check is generally only necessary if you are trying to receive data or RDMA send. template<class FuncType, class... ArgTypes> - inline void HandleFiCommandAndWait(FuncType func, FabricWaitableTask* task, Error* error, ArgTypes... args) { - HandleFiCommand(func, task, error, args...); + inline void HandleFiCommandAndWait(FabricAddress targetAddress, FabricWaitableTask* task, Error* error, + FuncType func, ArgTypes... args) { + HandleRawFiCommand(task, error, func, args...); + if (!(*error)) { task->Wait(requestTimeoutMs_, error); - - // If a timeout occurs we want to cancel the action, - // which invokes an 'Operation canceled' error in the completion queue. 
if (*error == FabricErrorTemplates::kTimeout) { - fi_cancel(&endpoint_->fid, task); - task->Wait(0, error); - // We expect the task to fail with 'Operation canceled' - if (*error == FabricErrorTemplates::kInternalOperationCanceledError) { - // Switch it to a timeout so its more clearly what happened - *error = FabricErrorTemplates::kTimeout.Generate(); + if (targetAddress != FI_ASAPO_ADDR_NO_ALIVE_CHECK) { + // Handle advanced alive check + bool aliveCheckFailed = false; + for (uint32_t i = 0; i < maxTimeoutRetires_ && *error == FabricErrorTemplates::kTimeout; i++) { + *error = nullptr; + printf("HandleFiCommandAndWait - Tries: %d\n", i); + if (!TargetIsAliveCheck(targetAddress)) { + aliveCheckFailed = true; + break; + } + task->Wait(requestTimeoutMs_, error); + } + + // TODO refactor this if/else mess + if (aliveCheckFailed) { + fi_cancel(&endpoint_->fid, task); + task->Wait(0, error); + *error = FabricErrorTemplates::kInternalConnectionError.Generate(); + } else if(*error == FabricErrorTemplates::kTimeout) { + fi_cancel(&endpoint_->fid, task); + task->Wait(0, error); + *error = FabricErrorTemplates::kTimeout.Generate(); + } + } else { + // If a timeout occurs we want to cancel the action, + // which invokes an 'Operation canceled' error in the completion queue. + fi_cancel(&endpoint_->fid, task); + task->Wait(0, error); + // We expect the task to fail with 'Operation canceled' + if (*error == FabricErrorTemplates::kInternalOperationCanceledError) { + // Switch it to a timeout so its more clearly what happened + *error = FabricErrorTemplates::kTimeout.Generate(); + } } } } } template<class FuncType, class... ArgTypes> - inline void HandleFiCommand(FuncType func, void* context, Error* error, ArgTypes... args) { + inline void HandleRawFiCommand(void* context, Error* error, FuncType func, ArgTypes... 
args) { ssize_t ret; // Since handling timeouts is an overhead, we first try to send the data regularly - ret = func(args..., context); + ret = func(endpoint_, args..., context); if (ret == -FI_EAGAIN) { using namespace std::chrono; using clock = std::chrono::high_resolution_clock; - auto maxTime = clock::now() + milliseconds(requestTimeoutMs_); + auto maxTime = clock::now() + milliseconds(requestEnqueueTimeoutMs_); do { std::this_thread::sleep_for(milliseconds(3)); - ret = func(args..., context); + ret = func(endpoint_, args..., context); } while (ret == -FI_EAGAIN && maxTime >= clock::now()); } @@ -126,6 +191,7 @@ class FabricContextImpl : public FabricContext { } private: + bool TargetIsAliveCheck(FabricAddress address); void CompletionThread(); }; diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.cpp b/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.cpp new file mode 100644 index 000000000..6d6ce1b40 --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.cpp @@ -0,0 +1,28 @@ +#include <rdma/fi_tagged.h> +#include "fabric_alive_check_response_task.h" +#include "../fabric_context_impl.h" + +using namespace asapo; +using namespace fabric; + +void FabricAliveCheckResponseTask::RequeueSelf(FabricContextImpl* parentContext) { + Error tmpError = nullptr; + + parentContext->HandleRawFiCommand(this, &tmpError, + fi_trecv, nullptr, 0, nullptr, FI_ADDR_UNSPEC, FI_ASAPO_TAG_ALIVE_CHECK, kRecvTaggedExactMatch); + + // Error is ignored +} + +void FabricAliveCheckResponseTask::OnCompletion(const fi_cq_tagged_entry*, FabricAddress) { + // We received a ping, LibFabric will automatically notify the sender about the completion. 
+} + +void FabricAliveCheckResponseTask::OnErrorCompletion(const fi_cq_err_entry*) { + // Error is ignored +} + +FabricAliveCheckResponseTask::FabricAliveCheckResponseTask(FabricContextImpl* parentContext) : FabricSelfRequeuingTask( + parentContext) { + +} diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.h new file mode 100644 index 000000000..4afc560dc --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.h @@ -0,0 +1,25 @@ +#ifndef ASAPO_FABRIC_ALIVE_CHECK_RESPONSE_TASK_H +#define ASAPO_FABRIC_ALIVE_CHECK_RESPONSE_TASK_H + +#include "fabric_self_requeuing_task.h" + +namespace asapo { +namespace fabric { + +/** + * This is the counter part of FabricContextImpl.TargetIsAliveCheck + */ +class FabricAliveCheckResponseTask : public FabricSelfRequeuingTask { + public: + explicit FabricAliveCheckResponseTask(FabricContextImpl* parentContext); + protected: + void RequeueSelf(FabricContextImpl* parentContext) override; + + void OnCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) override; + + void OnErrorCompletion(const fi_cq_err_entry* errEntry) override; +}; +} +} + +#endif //ASAPO_FABRIC_ALIVE_CHECK_RESPONSE_TASK_H diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_self_deleting_task.cpp b/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.cpp similarity index 71% rename from common/cpp/src/asapo_fabric/server/task/fabric_self_deleting_task.cpp rename to common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.cpp index b1c742b63..bef89058e 100644 --- a/common/cpp/src/asapo_fabric/server/task/fabric_self_deleting_task.cpp +++ b/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.cpp @@ -1,11 +1,10 @@ #include "fabric_self_deleting_task.h" -void asapo::fabric::FabricSelfDeletingTask::HandleCompletion(const fi_cq_tagged_entry* entry, - 
asapo::fabric::FabricAddress source) { +void asapo::fabric::FabricSelfDeletingTask::HandleCompletion(const fi_cq_tagged_entry*, FabricAddress) { OnDone(); } -void asapo::fabric::FabricSelfDeletingTask::HandleErrorCompletion(fi_cq_err_entry* errEntry) { +void asapo::fabric::FabricSelfDeletingTask::HandleErrorCompletion(const fi_cq_err_entry*) { OnDone(); } diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_self_deleting_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.h similarity index 79% rename from common/cpp/src/asapo_fabric/server/task/fabric_self_deleting_task.h rename to common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.h index 1c10e0262..59d5f627a 100644 --- a/common/cpp/src/asapo_fabric/server/task/fabric_self_deleting_task.h +++ b/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.h @@ -1,7 +1,7 @@ #ifndef ASAPO_FABRIC_SELF_DELETING_TASK_H #define ASAPO_FABRIC_SELF_DELETING_TASK_H -#include "../../common/fabric_task.h" +#include "fabric_task.h" namespace asapo { namespace fabric { @@ -9,7 +9,7 @@ namespace fabric { class FabricSelfDeletingTask : FabricTask { void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) final; - void HandleErrorCompletion(fi_cq_err_entry* errEntry) final; + void HandleErrorCompletion(const fi_cq_err_entry* errEntry) final; private: virtual ~FabricSelfDeletingTask() = default; diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.cpp b/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.cpp new file mode 100644 index 000000000..1abebf8e0 --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.cpp @@ -0,0 +1,47 @@ +#include "fabric_self_requeuing_task.h" +#include "../fabric_context_impl.h" + +using namespace asapo; +using namespace fabric; + +FabricSelfRequeuingTask::~FabricSelfRequeuingTask() { + Stop(); +} + 
+FabricSelfRequeuingTask::FabricSelfRequeuingTask(FabricContextImpl* parentContext) { + parent_context_ = parentContext; +} + +void FabricSelfRequeuingTask::Start() { + if (was_queued_already_) { + throw std::runtime_error("FabricSelfRequeuingTask can only be queued once"); + } + RequeueSelf(parent_context_); +} + +void FabricSelfRequeuingTask::Stop() { + if (was_queued_already_ && still_running_) { + still_running_ = false; + printf("Going to stop FabricSelfRequeuingTask!!!"); + fi_cancel(&parent_context_->endpoint_->fid, this); + stop_response_future_.wait(); + } +} + +void FabricSelfRequeuingTask::HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) { + OnCompletion(entry, source); + AfterCompletion(); +} + +void FabricSelfRequeuingTask::HandleErrorCompletion(const fi_cq_err_entry* errEntry) { + OnErrorCompletion(errEntry); + AfterCompletion(); +} + +void FabricSelfRequeuingTask::AfterCompletion() { + if (still_running_) { + RequeueSelf(parent_context_); + } else { + stop_response_.set_value(); + } +} diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.h new file mode 100644 index 000000000..2a06b28be --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.h @@ -0,0 +1,38 @@ +#ifndef ASAPO_FABRIC_SELF_REQUEUING_TASK_H +#define ASAPO_FABRIC_SELF_REQUEUING_TASK_H + +#include <future> +#include "fabric_task.h" + +namespace asapo { +namespace fabric { +class FabricContextImpl; + +class FabricSelfRequeuingTask : public FabricTask { + private: + FabricContextImpl* parent_context_; + volatile bool still_running_ = true; + bool was_queued_already_ = false; + std::promise<void> stop_response_; + std::future<void> stop_response_future_; + public: + ~FabricSelfRequeuingTask(); + explicit FabricSelfRequeuingTask(FabricContextImpl* parentContext); + + void Start(); + void Stop(); + public: + void HandleCompletion(const 
fi_cq_tagged_entry* entry, FabricAddress source) final; + void HandleErrorCompletion(const fi_cq_err_entry* errEntry) final; + protected: + virtual void RequeueSelf(FabricContextImpl* parentContext) = 0; + virtual void OnCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) = 0; + virtual void OnErrorCompletion(const fi_cq_err_entry* errEntry) = 0; + private: + void AfterCompletion(); +}; +} +} + + +#endif //ASAPO_FABRIC_SELF_REQUEUING_TASK_H diff --git a/common/cpp/src/asapo_fabric/common/fabric_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_task.h similarity index 80% rename from common/cpp/src/asapo_fabric/common/fabric_task.h rename to common/cpp/src/asapo_fabric/common/task/fabric_task.h index 544900b5c..3802a558d 100644 --- a/common/cpp/src/asapo_fabric/common/fabric_task.h +++ b/common/cpp/src/asapo_fabric/common/task/fabric_task.h @@ -9,7 +9,7 @@ namespace fabric { class FabricTask { public: virtual void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) = 0; - virtual void HandleErrorCompletion(fi_cq_err_entry* errEntry) = 0; + virtual void HandleErrorCompletion(const fi_cq_err_entry* errEntry) = 0; }; } } diff --git a/common/cpp/src/asapo_fabric/common/fabric_waitable_task.cpp b/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.cpp similarity index 80% rename from common/cpp/src/asapo_fabric/common/fabric_waitable_task.cpp rename to common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.cpp index 650ccc291..47efa2fe8 100644 --- a/common/cpp/src/asapo_fabric/common/fabric_waitable_task.cpp +++ b/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.cpp @@ -1,19 +1,19 @@ #include "fabric_waitable_task.h" -#include "../fabric_internal_error.h" +#include "../../fabric_internal_error.h" using namespace asapo; using namespace fabric; -FabricWaitableTask::FabricWaitableTask() : future_{promise_.get_future()} { +FabricWaitableTask::FabricWaitableTask() : future_{promise_.get_future()}, 
source_{FI_ADDR_NOTAVAIL} { } -void FabricWaitableTask::HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) { +void FabricWaitableTask::HandleCompletion(const fi_cq_tagged_entry*, FabricAddress source) { source_ = source; promise_.set_value(); } -void FabricWaitableTask::HandleErrorCompletion(fi_cq_err_entry* errEntry) { +void FabricWaitableTask::HandleErrorCompletion(const fi_cq_err_entry* errEntry) { error_ = ErrorFromFabricInternal("FabricWaitableTask", -errEntry->err); promise_.set_value(); } diff --git a/common/cpp/src/asapo_fabric/common/fabric_waitable_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.h similarity index 89% rename from common/cpp/src/asapo_fabric/common/fabric_waitable_task.h rename to common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.h index 94eda0bd0..24a6b5659 100644 --- a/common/cpp/src/asapo_fabric/common/fabric_waitable_task.h +++ b/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.h @@ -19,7 +19,7 @@ class FabricWaitableTask : FabricTask { explicit FabricWaitableTask(); void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) override; - void HandleErrorCompletion(fi_cq_err_entry* errEntry) override; + void HandleErrorCompletion(const fi_cq_err_entry* errEntry) override; void Wait(uint32_t sleepInMs, Error* error); diff --git a/common/cpp/src/asapo_fabric/fabric_internal_error.cpp b/common/cpp/src/asapo_fabric/fabric_internal_error.cpp index d2f888454..fb8629e09 100644 --- a/common/cpp/src/asapo_fabric/fabric_internal_error.cpp +++ b/common/cpp/src/asapo_fabric/fabric_internal_error.cpp @@ -4,8 +4,11 @@ asapo::Error asapo::fabric::ErrorFromFabricInternal(const std::string& where, int internalStatusCode) { std::string errText = where + ": " + fi_strerror(-internalStatusCode); - if (internalStatusCode == -FI_ECANCELED) { + switch (-internalStatusCode) { + case FI_ECANCELED: return FabricErrorTemplates::kInternalOperationCanceledError.Generate(errText); + 
case FI_EIO: + return FabricErrorTemplates::kInternalConnectionError.Generate(errText); } return FabricErrorTemplates::kInternalError.Generate(errText); } diff --git a/common/cpp/src/asapo_fabric/server/fabric_server_impl.cpp b/common/cpp/src/asapo_fabric/server/fabric_server_impl.cpp index 0b9b08018..7051a97ce 100644 --- a/common/cpp/src/asapo_fabric/server/fabric_server_impl.cpp +++ b/common/cpp/src/asapo_fabric/server/fabric_server_impl.cpp @@ -41,9 +41,8 @@ FabricServerImpl::RdmaWrite(FabricAddress dstAddress, const MemoryRegionDetails* void FabricServerImpl::RecvAny(FabricAddress* srcAddress, FabricMessageId* messageId, void* dst, size_t size, Error* error) { FabricRecvAnyTask anyTask; - - HandleFiCommandAndWait(fi_trecv, &anyTask, error, - endpoint_, dst, size, nullptr, FI_ADDR_UNSPEC, 0, ~0ULL); + HandleFiCommandAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, &anyTask, error, + fi_trecv, dst, size, nullptr, FI_ADDR_UNSPEC, 0, kRecvTaggedAnyMatch); if (!(*error)) { if (anyTask.GetSource() == FI_ADDR_NOTAVAIL) { diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.cpp b/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.cpp index 565db79c6..f811a2da5 100644 --- a/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.cpp +++ b/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.cpp @@ -1,7 +1,7 @@ #include <rdma/fi_endpoint.h> #include "fabric_handshake_accepting_task.h" #include "../fabric_server_impl.h" -#include "fabric_self_deleting_task.h" +#include "../../common/task/fabric_self_deleting_task.h" using namespace asapo; using namespace fabric; @@ -13,7 +13,7 @@ FabricHandshakeAcceptingTask::~FabricHandshakeAcceptingTask() { FabricHandshakeAcceptingTask::FabricHandshakeAcceptingTask(FabricServerImpl* server) : server_{server} { } -void FabricHandshakeAcceptingTask::HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) { +void 
FabricHandshakeAcceptingTask::HandleCompletion(const fi_cq_tagged_entry*, FabricAddress) { Error error; HandleAccept(&error); if (error) { @@ -23,7 +23,7 @@ void FabricHandshakeAcceptingTask::HandleCompletion(const fi_cq_tagged_entry* en StartRequest(); } -void FabricHandshakeAcceptingTask::HandleErrorCompletion(fi_cq_err_entry* errEntry) { +void FabricHandshakeAcceptingTask::HandleErrorCompletion(const fi_cq_err_entry* errEntry) { Error error; error = ErrorFromFabricInternal("FabricWaitableTask", -errEntry->err); OnError(&error); @@ -34,9 +34,8 @@ void FabricHandshakeAcceptingTask::HandleErrorCompletion(fi_cq_err_entry* errEnt void FabricHandshakeAcceptingTask::StartRequest() { if (server_->accepting_task_running) { Error error; - server_->HandleFiCommand(fi_recv, this, &error, - server_->endpoint_, &handshake_payload_, sizeof(handshake_payload_), - nullptr, FI_ADDR_UNSPEC); + server_->HandleRawFiCommand(this, &error, + fi_recv, &handshake_payload_, sizeof(handshake_payload_), nullptr, FI_ADDR_UNSPEC); if (error) { OnError(&error); @@ -74,8 +73,8 @@ void FabricHandshakeAcceptingTask::HandleAccept(Error* error) { // TODO: This could slow down the whole complete queue process, maybe use another thread? 
:/ // Send and forget - server_->HandleFiCommand(fi_send, new FabricSelfDeletingTask(), error, - server_->endpoint_, nullptr, 0, nullptr, tmpAddr); + server_->HandleRawFiCommand(new FabricSelfDeletingTask(), error, + fi_send, nullptr, 0, nullptr, tmpAddr); if (*error) { return; } diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.h b/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.h index b79c16ae6..5b4f045ea 100644 --- a/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.h +++ b/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.h @@ -1,7 +1,7 @@ #ifndef ASAPO_FABRIC_HANDSHAKE_ACCEPTING_TASK_H #define ASAPO_FABRIC_HANDSHAKE_ACCEPTING_TASK_H -#include "../../common/fabric_task.h" +#include "../../common/task/fabric_task.h" #include "../../common/fabric_context_impl.h" namespace asapo { @@ -23,7 +23,7 @@ class FabricHandshakeAcceptingTask : public FabricTask { explicit FabricHandshakeAcceptingTask(FabricServerImpl* server); void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) override; - void HandleErrorCompletion(fi_cq_err_entry* errEntry) override; + void HandleErrorCompletion(const fi_cq_err_entry* errEntry) override; void StartRequest(); void DeleteRequest(); diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.cpp b/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.cpp index 3666db205..6e703a43e 100644 --- a/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.cpp +++ b/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.cpp @@ -8,7 +8,7 @@ void FabricRecvAnyTask::HandleCompletion(const fi_cq_tagged_entry* entry, Fabric FabricWaitableTask::HandleCompletion(entry, source); } -void FabricRecvAnyTask::HandleErrorCompletion(fi_cq_err_entry* errEntry) { +void FabricRecvAnyTask::HandleErrorCompletion(const fi_cq_err_entry* errEntry) { messageId_ = errEntry->tag; 
FabricWaitableTask::HandleErrorCompletion(errEntry); } diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.h b/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.h index 9221a6ebd..066310824 100644 --- a/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.h +++ b/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.h @@ -3,7 +3,7 @@ #include <asapo_fabric/asapo_fabric.h> #include <rdma/fi_eq.h> -#include "../../common/fabric_waitable_task.h" +#include "../../common/task/fabric_waitable_task.h" namespace asapo { namespace fabric { @@ -13,7 +13,7 @@ class FabricRecvAnyTask : public FabricWaitableTask { FabricMessageId messageId_; public: void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) override; - void HandleErrorCompletion(fi_cq_err_entry* errEntry) override; + void HandleErrorCompletion(const fi_cq_err_entry* errEntry) override; FabricMessageId GetMessageId() const; }; diff --git a/tests/automatic/asapo_fabric/client_lazy_initialization.cpp b/tests/automatic/asapo_fabric/client_lazy_initialization.cpp index 0e7749c23..751385618 100644 --- a/tests/automatic/asapo_fabric/client_lazy_initialization.cpp +++ b/tests/automatic/asapo_fabric/client_lazy_initialization.cpp @@ -14,9 +14,10 @@ int main(int argc, char* argv[]) { M_AssertEq("", client->GetAddress()); - int dummyBuffer; + int dummyBuffer = 0; auto mr = client->ShareMemoryRegion(&dummyBuffer, sizeof(dummyBuffer), &err); M_AssertEq(FabricErrorTemplates::kClientNotInitializedError, err, "client->ShareMemoryRegion"); + err = nullptr; // Other methods require an serverAddress which initializes the client diff --git a/tests/automatic/asapo_fabric/parallel_data_transfer.cpp b/tests/automatic/asapo_fabric/parallel_data_transfer.cpp index 813e3b5fa..fc99295ef 100644 --- a/tests/automatic/asapo_fabric/parallel_data_transfer.cpp +++ b/tests/automatic/asapo_fabric/parallel_data_transfer.cpp @@ -48,12 +48,12 @@ void ServerChildThread(FabricServer* 
server, std::atomic<int>* serverTotalReques std::cerr << "A Server is done" << std::endl; } -void ServerMasterThread(char* expectedRdmaBuffer) { +void ServerMasterThread(const std::string& hostname, uint16_t port, char* expectedRdmaBuffer) { Error err; auto log = CreateDefaultLoggerBin("AutomaticTesting"); auto factory = GenerateDefaultFabricFactory(); - auto server = factory->CreateAndBindServer(log.get(), "127.0.0.1", 1816, &err); + auto server = factory->CreateAndBindServer(log.get(), hostname, port, &err); M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); std::atomic<int> serverTotalRequests(0); @@ -71,14 +71,14 @@ void ServerMasterThread(char* expectedRdmaBuffer) { serverIsDone.set_value(); } -void ClientChildThread(int index, char* expectedRdmaBuffer) { +void ClientChildThread(const std::string& hostname, uint16_t port, int index, char* expectedRdmaBuffer) { auto factory = GenerateDefaultFabricFactory(); Error err; auto client = factory->CreateClient(&err); M_AssertEq(nullptr, err, "factory->CreateClient"); - auto serverAddress = client->AddServerAddress("127.0.0.1:1816", &err); + auto serverAddress = client->AddServerAddress(hostname + ":" + std::to_string(port), &err); M_AssertEq(nullptr, err, "client->AddServerAddress"); auto actualRdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); @@ -114,10 +114,10 @@ void ClientChildThread(int index, char* expectedRdmaBuffer) { std::cout << "A Client is done" << std::endl; } -void ClientMasterThread(char* expectedRdmaBuffer) { +void ClientMasterThread(const std::string& hostname, uint16_t port, char* expectedRdmaBuffer) { std::thread threads[kClientThreads]; for (int i = 0; i < kClientThreads; i++) { - threads[i] = std::thread(ClientChildThread, i, expectedRdmaBuffer); + threads[i] = std::thread(ClientChildThread, hostname, port, i, expectedRdmaBuffer); } for (auto& thread : threads) { @@ -130,6 +130,19 @@ void ClientMasterThread(char* expectedRdmaBuffer) { } int main(int argc, char* argv[]) { + 
std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 10); + } std::cout << "Client is writing to std::cout" << std::endl; std::cerr << "Server is writing to std::cerr" << std::endl; @@ -139,10 +152,10 @@ int main(int argc, char* argv[]) { expectedRdmaBuffer[i] = (char)i; } - std::thread serverMasterThread(ServerMasterThread, expectedRdmaBuffer.get()); + std::thread serverMasterThread(ServerMasterThread, hostname, port, expectedRdmaBuffer.get()); std::this_thread::sleep_for(std::chrono::seconds(2)); - ClientMasterThread(expectedRdmaBuffer.get()); + ClientMasterThread(hostname, port, expectedRdmaBuffer.get()); std::cout << "Done testing. Joining server" << std::endl; serverMasterThread.join(); diff --git a/tests/automatic/asapo_fabric/server_not_running.cpp b/tests/automatic/asapo_fabric/server_not_running.cpp index cf971685f..aa11d6936 100644 --- a/tests/automatic/asapo_fabric/server_not_running.cpp +++ b/tests/automatic/asapo_fabric/server_not_running.cpp @@ -1,19 +1,35 @@ #include <common/error.h> #include <asapo_fabric/asapo_fabric.h> #include <testing.h> +#include <iostream> using namespace asapo; using namespace fabric; int main(int argc, char* argv[]) { + std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 10); + } + Error err; auto factory = GenerateDefaultFabricFactory(); auto client = factory->CreateClient(&err); M_AssertEq(nullptr, err, "factory->CreateClient"); - auto serverAddress = client->AddServerAddress("127.0.0.1:1234", &err); + auto serverAddress = client->AddServerAddress(hostname + ":" + 
std::to_string(port), &err); M_AssertEq(FabricErrorTemplates::kConnectionRefusedError, err, "client->AddServerAddress"); + err = nullptr; return 0; } diff --git a/tests/automatic/asapo_fabric/simple_data_transfer.cpp b/tests/automatic/asapo_fabric/simple_data_transfer.cpp index 860410359..5c35f5a2c 100644 --- a/tests/automatic/asapo_fabric/simple_data_transfer.cpp +++ b/tests/automatic/asapo_fabric/simple_data_transfer.cpp @@ -20,42 +20,45 @@ constexpr int kTotalRuns = 3; constexpr int kEachInstanceRuns = 5; constexpr size_t kRdmaSize = 5 * 1024 * 1024; -void ServerMasterThread(char* expectedRdmaBuffer) { - Error err; - auto log = CreateDefaultLoggerBin("AutomaticTesting"); - - auto factory = GenerateDefaultFabricFactory(); - auto server = factory->CreateAndBindServer(log.get(), "127.0.0.1", 1816, &err); - M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); - - for (int run = 0; run < kTotalRuns; run++) { - for (int instanceRuns = 0; instanceRuns < kEachInstanceRuns; instanceRuns++) { - GenericRequestHeader request{}; - - FabricAddress clientAddress; - FabricMessageId messageId; - server->RecvAny(&clientAddress, &messageId, &request, sizeof(request), &err); - M_AssertEq(nullptr, err, "server->RecvAny"); - M_AssertEq(123 + instanceRuns, messageId); - M_AssertEq("Hello World", request.message); - - server->RdmaWrite(clientAddress, (MemoryRegionDetails*) &request.substream, expectedRdmaBuffer, kRdmaSize, - &err); - M_AssertEq(nullptr, err, "server->RdmaWrite"); - - GenericNetworkResponse response{}; - strcpy(response.message, "Hey, I am the Server"); - server->Send(clientAddress, messageId, &response, sizeof(response), &err); - M_AssertEq(nullptr, err, "server->Send"); +void ServerMasterThread(const std::string& hostname, uint16_t port, char* expectedRdmaBuffer) { + { + Error err; + auto log = CreateDefaultLoggerBin("AutomaticTesting"); + + auto factory = GenerateDefaultFabricFactory(); + auto server = factory->CreateAndBindServer(log.get(), hostname, port, 
&err); + M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); + + for (int run = 0; run < kTotalRuns; run++) { + for (int instanceRuns = 0; instanceRuns < kEachInstanceRuns; instanceRuns++) { + GenericRequestHeader request{}; + + FabricAddress clientAddress; + FabricMessageId messageId; + server->RecvAny(&clientAddress, &messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "server->RecvAny"); + M_AssertEq(123 + instanceRuns, messageId); + M_AssertEq("Hello World", request.message); + + server->RdmaWrite(clientAddress, (MemoryRegionDetails*) &request.substream, expectedRdmaBuffer, kRdmaSize, + &err); + M_AssertEq(nullptr, err, "server->RdmaWrite"); + + GenericNetworkResponse response{}; + strcpy(response.message, "Hey, I am the Server"); + server->Send(clientAddress, messageId, &response, sizeof(response), &err); + M_AssertEq(nullptr, err, "server->Send"); + } } - } - std::cout << "[SERVER] Waiting for client to finish" << std::endl; - clientIsDoneFuture.get(); + std::cout << "[SERVER] Waiting for client to finish" << std::endl; + clientIsDoneFuture.get(); + } + std::cout << "[SERVER] Server is done" << std::endl; serverIsDone.set_value(); } -void ClientThread(char* expectedRdmaBuffer) { +void ClientThread(const std::string& hostname, uint16_t port, char* expectedRdmaBuffer) { Error err; for (int run = 0; run < kTotalRuns; run++) { @@ -66,7 +69,7 @@ void ClientThread(char* expectedRdmaBuffer) { auto client = factory->CreateClient(&err); M_AssertEq(nullptr, err, "factory->CreateClient"); - auto serverAddress = client->AddServerAddress("127.0.0.1:1816", &err); + auto serverAddress = client->AddServerAddress(hostname + ":" + std::to_string(port), &err); M_AssertEq(nullptr, err, "client->AddServerAddress"); auto actualRdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); @@ -95,20 +98,35 @@ void ClientThread(char* expectedRdmaBuffer) { } } } + clientIsDone.set_value(); serverIsDoneFuture.get(); } int main(int argc, char* argv[]) { + 
std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 10); + } + auto expectedRdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); for (size_t i = 0; i < kRdmaSize; i++) { expectedRdmaBuffer[i] = (char)i; } - std::thread serverThread(ServerMasterThread, expectedRdmaBuffer.get()); + std::thread serverThread(ServerMasterThread, hostname, port, expectedRdmaBuffer.get()); std::this_thread::sleep_for(std::chrono::seconds(2)); - ClientThread(expectedRdmaBuffer.get()); + ClientThread(hostname, port, expectedRdmaBuffer.get()); std::cout << "Done testing. Joining server" << std::endl; serverThread.join(); diff --git a/tests/automatic/asapo_fabric/timeout_test.cpp b/tests/automatic/asapo_fabric/timeout_test.cpp index d4b616ce1..42abd6408 100644 --- a/tests/automatic/asapo_fabric/timeout_test.cpp +++ b/tests/automatic/asapo_fabric/timeout_test.cpp @@ -8,24 +8,40 @@ using namespace asapo; using namespace fabric; -void ServerMasterThread() { - Error err; - auto log = CreateDefaultLoggerBin("AutomaticTesting"); +std::promise<void> serverShutdown; +std::future<void> serverShutdown_future = serverShutdown.get_future(); - auto factory = GenerateDefaultFabricFactory(); +std::promise<void> serverShutdownAck; +std::future<void> serverShutdownAck_future = serverShutdownAck.get_future(); + +void ServerMasterThread(const std::string& hostname, uint16_t port) { + { + Error err; + auto log = CreateDefaultLoggerBin("AutomaticTesting"); + + auto factory = GenerateDefaultFabricFactory(); - auto server = factory->CreateAndBindServer(log.get(), "127.0.0.1", 1816, &err); - M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); + auto server = factory->CreateAndBindServer(log.get(), hostname, port, &err); + M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); - 
// Wait for client to send a request and then shutdown the server - int dummyBuffer; - FabricAddress clientAddress; - FabricMessageId messageId; - server->RecvAny(&clientAddress, &messageId, &dummyBuffer, sizeof(dummyBuffer), &err); - M_AssertEq(nullptr, err, "server->RecvAny"); + // Wait for client to send a request and then shutdown the server + int dummyBuffer; + FabricAddress clientAddress; + FabricMessageId messageId; + server->RecvAny(&clientAddress, &messageId, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(nullptr, err, "server->RecvAny"); + + server->Send(clientAddress, messageId, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(nullptr, err, "server->Send"); + + serverShutdown_future.wait(); + } + + printf("Server is now down!\n"); + serverShutdownAck.set_value(); } -void ClientThread() { +void ClientThread(const std::string& hostname, uint16_t port) { Error err; auto factory = GenerateDefaultFabricFactory(); @@ -33,23 +49,56 @@ void ClientThread() { auto client = factory->CreateClient(&err); M_AssertEq(nullptr, err, "factory->CreateClient"); - auto serverAddress = client->AddServerAddress("127.0.0.1:1816", &err); + auto serverAddress = client->AddServerAddress(hostname + ":" + std::to_string(port), &err); M_AssertEq(nullptr, err, "client->AddServerAddress"); - int dummyBuffer; - client->Send(serverAddress, 1, &dummyBuffer, sizeof(dummyBuffer), &err); + int dummyBuffer = 0; + client->Send(serverAddress, 0, &dummyBuffer, sizeof(dummyBuffer), &err); M_AssertEq(nullptr, err, "client->Send"); - // The server should shut down now! 
- client->Recv(serverAddress, 1, &dummyBuffer, sizeof(dummyBuffer), &err); + client->Recv(serverAddress, 0, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(nullptr, err, "client->Recv"); + + // Server should not respond to this message + std::cout << + "The following call might take a while since its able to reach the server but the server is not responding" + << std::endl; + client->Recv(serverAddress, 0, &dummyBuffer, sizeof(dummyBuffer), &err); M_AssertEq(FabricErrorTemplates::kTimeout, err, "client->Recv"); + err = nullptr; + + serverShutdown.set_value(); + serverShutdownAck_future.wait(); + + // Server is now down + client->Recv(serverAddress, 1, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(FabricErrorTemplates::kInternalConnectionError, err, "client->Recv"); + err = nullptr; + + client->Send(serverAddress, 2, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(FabricErrorTemplates::kInternalConnectionError, err, "client->Send"); + err = nullptr; } int main(int argc, char* argv[]) { - std::thread serverThread(ServerMasterThread); + std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 10); + } + + std::thread serverThread(ServerMasterThread, hostname, port); std::this_thread::sleep_for(std::chrono::seconds(2)); - ClientThread(); + ClientThread(hostname, port); std::cout << "Done testing. 
Joining server" << std::endl; serverThread.join(); diff --git a/tests/automatic/asapo_fabric/wrong_memory_info.cpp b/tests/automatic/asapo_fabric/wrong_memory_info.cpp index bbf67f27f..f94fc6c65 100644 --- a/tests/automatic/asapo_fabric/wrong_memory_info.cpp +++ b/tests/automatic/asapo_fabric/wrong_memory_info.cpp @@ -19,12 +19,12 @@ std::future<void> serverIsDoneFuture = serverIsDone.get_future(); constexpr size_t kRdmaSize = 5 * 1024; constexpr size_t kDummyDataSize = 512; -void ServerMasterThread() { +void ServerMasterThread(const std::string& hostname, uint16_t port) { Error err; auto log = CreateDefaultLoggerBin("AutomaticTesting"); auto factory = GenerateDefaultFabricFactory(); - auto server = factory->CreateAndBindServer(log.get(), "127.0.0.1", 1816, &err); + auto server = factory->CreateAndBindServer(log.get(), hostname, port, &err); M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); GenericRequestHeader request{}; @@ -68,7 +68,7 @@ void ServerMasterThread() { serverIsDone.set_value(); } -void ClientThread() { +void ClientThread(const std::string& hostname, uint16_t port) { Error err; auto factory = GenerateDefaultFabricFactory(); @@ -76,7 +76,7 @@ void ClientThread() { auto client = factory->CreateClient(&err); M_AssertEq(nullptr, err, "factory->CreateClient"); - auto serverAddress = client->AddServerAddress("127.0.0.1:1816", &err); + auto serverAddress = client->AddServerAddress(hostname + ":" + std::to_string(port), &err); M_AssertEq(nullptr, err, "client->AddServerAddress"); auto actualRdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); @@ -120,10 +120,24 @@ void ClientThread() { } int main(int argc, char* argv[]) { - std::thread serverThread(ServerMasterThread); + std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 
10); + } + + std::thread serverThread(ServerMasterThread, hostname, port); std::this_thread::sleep_for(std::chrono::seconds(2)); - ClientThread(); + ClientThread(hostname, port); std::cout << "Done testing. Joining server" << std::endl; serverThread.join(); -- GitLab