diff --git a/CMakeLists.txt b/CMakeLists.txt index 9eca15771bdc8a9f261a509e131bf3b21bb91a21..f444089d0209112a31fa486ea7da8003692bac36 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,10 +17,11 @@ ENDIF(WIN32) #TODO: Better way then GLOBAL PROPERTY IF(WIN32) find_package(Threads REQUIRED) - SET_PROPERTY(GLOBAL PROPERTY ASAPO_COMMON_IO_LIBRARIES ${CMAKE_THREAD_LIBS_INIT} wsock32 ws2_32) + SET(ASAPO_COMMON_IO_LIBRARIES ${CMAKE_THREAD_LIBS_INIT} wsock32 ws2_32) ELSEIF(UNIX) - SET_PROPERTY(GLOBAL PROPERTY ASAPO_COMMON_IO_LIBRARIES Threads::Threads) + SET(ASAPO_COMMON_IO_LIBRARIES Threads::Threads) ENDIF(WIN32) +SET_PROPERTY(GLOBAL PROPERTY ASAPO_COMMON_IO_LIBRARIES ${ASAPO_COMMON_IO_LIBRARIES}) if (CMAKE_BUILD_TYPE STREQUAL "Debug") add_definitions(-DUNIT_TESTS) @@ -41,6 +42,8 @@ option(BUILD_PYTHON_DOCS "Uses sphinx to build the Python documentaion" OFF) option(BUILD_CONSUMER_TOOLS "Build consumer tools" OFF) option(BUILD_EXAMPLES "Build examples" OFF) +option(ENABLE_LIBFABRIC "Enables LibFabric support for RDMA transfers" OFF) + set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/CMakeModules/) set (ASAPO_CXX_COMMON_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/common/cpp/include) @@ -62,6 +65,18 @@ endif() message (STATUS "Using Python: ${Python_EXECUTABLE}") +SET(ASAPO_COMMON_FABRIC_LIBRARIES ${ASAPO_COMMON_IO_LIBRARIES}) +IF(ENABLE_LIBFABRIC) + find_package(LibFabric) + if(NOT LIBFABRIC_LIBRARY) + message(FATAL_ERROR "Did not find libfabric") + endif() + message(STATUS "LibFabric support enabled") + message(STATUS "LIB_FABRIC: Path: ${LIBFABRIC_LIBRARY} Include: ${LIBFABRIC_INCLUDE_DIR}") + add_definitions(-DLIBFABRIC_ENABLED) + SET(ASAPO_COMMON_FABRIC_LIBRARIES ${ASAPO_COMMON_FABRIC_LIBRARIES} fabric) +ENDIF() +SET_PROPERTY(GLOBAL PROPERTY ASAPO_COMMON_FABRIC_LIBRARIES ${ASAPO_COMMON_FABRIC_LIBRARIES}) # format sources include(astyle) diff --git a/CMakeModules/FindLibFabric.cmake b/CMakeModules/FindLibFabric.cmake new file mode 100644 index 0000000000000000000000000000000000000000..24b54d5b1991820ae24ae3777d05c5ba5e98c41f --- /dev/null +++ b/CMakeModules/FindLibFabric.cmake @@ -0,0 +1,15 @@ +# FindLibFabric +# ------------- +# +# Tries to find LibFabric on the system +# +# Available variables +# LIBFABRIC_LIBRARY - Path to the library +# LIBFABRIC_INCLUDE_DIR - Path to the include dir + +cmake_minimum_required(VERSION 2.6) + +find_path(LIBFABRIC_INCLUDE_DIR fabric.h) +find_library(LIBFABRIC_LIBRARY fabric) + +mark_as_advanced(LIBFABRIC_INCLUDE_DIR LIBFABRIC_LIBRARY) diff --git a/common/cpp/CMakeLists.txt b/common/cpp/CMakeLists.txt index d7770ef6ae91190d5033aca819c860dfbf6d58d0..cfa9198b4acab97307960a1b03f3aae701f094a1 100644 --- a/common/cpp/CMakeLists.txt +++ b/common/cpp/CMakeLists.txt @@ -12,6 +12,8 @@ add_subdirectory(src/logger) add_subdirectory(src/request) +add_subdirectory(src/asapo_fabric) + if(BUILD_MONGODB_CLIENTLIB) add_subdirectory(src/database) endif() diff --git a/common/cpp/include/asapo_fabric/asapo_fabric.h b/common/cpp/include/asapo_fabric/asapo_fabric.h new file mode 100644 index 0000000000000000000000000000000000000000..a9b9f8da1e9729d4b636dcc92ec804ce4f1ba926 --- /dev/null +++ b/common/cpp/include/asapo_fabric/asapo_fabric.h @@ -0,0 +1,85 @@ +#ifndef ASAPO_FABRIC_H +#define ASAPO_FABRIC_H + +#include <cstdint> +#include <string> +#include <memory> +#include <common/error.h> +#include <logger/logger.h> +#include "fabric_error.h" + +namespace asapo { +namespace fabric { +typedef uint64_t FabricAddress; +typedef uint64_t FabricMessageId; + +// TODO Use a serialization 
framework +struct MemoryRegionDetails { + uint64_t addr; + uint64_t length; + uint64_t key; +}; + +class FabricMemoryRegion { + public: + virtual ~FabricMemoryRegion() = default; + virtual const MemoryRegionDetails* GetDetails() const = 0; +}; + +class FabricContext { + public: + virtual std::string GetAddress() const = 0; + + virtual std::unique_ptr<FabricMemoryRegion> ShareMemoryRegion(void* src, size_t size, Error* error) = 0; + + virtual void Send(FabricAddress dstAddress, FabricMessageId messageId, + const void* src, size_t size, Error* error) = 0; + + virtual void Recv(FabricAddress srcAddress, FabricMessageId messageId, + void* dst, size_t size, Error* error) = 0; + + virtual void RdmaWrite(FabricAddress dstAddress, + const MemoryRegionDetails* details, const void* buffer, size_t size, + Error* error) = 0; + + // Since RdmaRead heavily impacts the performance we will not implement this + // virtual void RdmaRead(...) = 0; + + +}; + +class FabricClient : public FabricContext { + public: + virtual ~FabricClient() = default; + + /// The serverAddress must be in this format: "hostname:port" + virtual FabricAddress AddServerAddress(const std::string& serverAddress, Error* error) = 0; +}; + +class FabricServer : public FabricContext { + public: + virtual ~FabricServer() = default; + + virtual void RecvAny(FabricAddress* srcAddress, FabricMessageId* messageId, void* dst, size_t size, Error* error) = 0; +}; + +class FabricFactory { + public: + /** + * Creates a new server and will immediately allocate and listen to the given host:port + */ + virtual std::unique_ptr<FabricServer> + CreateAndBindServer(const AbstractLogger* logger, const std::string& host, uint16_t port, + Error* error) const = 0; + + /** + * Will allocate a proper domain as soon as the client gets his first server address added + */ + virtual std::unique_ptr<FabricClient> CreateClient(Error* error) const = 0; +}; + +std::unique_ptr<FabricFactory> GenerateDefaultFabricFactory(); +} +} + +#endif //ASAPO_FABRIC_H diff --git a/common/cpp/include/asapo_fabric/fabric_error.h b/common/cpp/include/asapo_fabric/fabric_error.h new file mode 100644 index 0000000000000000000000000000000000000000..854d677f4b4d0773038eb183fead83323148f350 --- /dev/null +++ b/common/cpp/include/asapo_fabric/fabric_error.h @@ -0,0 +1,58 @@ +#ifndef ASAPO_FABRIC_ERROR_H +#define ASAPO_FABRIC_ERROR_H + +namespace asapo { +namespace fabric { +enum class FabricErrorType { + kNotSupported, + kOutdatedLibrary, + kInternalError, // An error that was produced by LibFabric + kInternalOperationCanceled, // An error that was produced by LibFabric + kInternalConnectionError, // This might occur when the connection is unexpectedly closed + kNoDeviceFound, + kClientNotInitialized, + kTimeout, + kConnectionRefused, +}; + + +using FabricError = ServiceError<FabricErrorType, ErrorType::kFabricError>; +using FabricErrorTemplate = ServiceErrorTemplate<FabricErrorType, ErrorType::kFabricError>; + +namespace FabricErrorTemplates { +auto const kNotSupportedOnBuildError = FabricErrorTemplate { + "This build of ASAPO does not support LibFabric", FabricErrorType::kNotSupported +}; +auto const kOutdatedLibraryError = FabricErrorTemplate { + "LibFabric outdated", FabricErrorType::kOutdatedLibrary +}; +auto const kInternalError = FabricErrorTemplate { + "Internal LibFabric error", FabricErrorType::kInternalError +}; +auto const kInternalOperationCanceledError = FabricErrorTemplate { + "Internal LibFabric operation canceled error", FabricErrorType::kInternalOperationCanceled +}; +auto 
const kNoDeviceFoundError = FabricErrorTemplate { + "No device was found (Check your config)", FabricErrorType::kNoDeviceFound +}; +auto const kClientNotInitializedError = FabricErrorTemplate { + "The client was not initialized. Add server address first!", + FabricErrorType::kClientNotInitialized +}; +auto const kTimeout = FabricErrorTemplate { + "Timeout", + FabricErrorType::kTimeout +}; +auto const kConnectionRefusedError = FabricErrorTemplate { + "Connection refused", + FabricErrorType::kConnectionRefused +}; +auto const kInternalConnectionError = FabricErrorTemplate { + "Connection error (maybe a disconnect?)", + FabricErrorType::kInternalConnectionError +}; +} + +} +} +#endif //ASAPO_FABRIC_ERROR_H diff --git a/common/cpp/include/common/error.h b/common/cpp/include/common/error.h index e2a49f858bf78b1961948d2ae73dbfec85205402..c2259b79b551f0b3a1449ce30a55443aa75951da 100644 --- a/common/cpp/include/common/error.h +++ b/common/cpp/include/common/error.h @@ -18,6 +18,7 @@ enum class ErrorType { kConsumerError, kMemoryAllocationError, kEndOfFile, + kFabricError, }; class ErrorInterface; @@ -213,7 +214,7 @@ class ServiceErrorTemplate : public SimpleErrorTemplate { } inline Error Generate(const std::string& suffix) const noexcept override { - return Error(new ServiceError<ServiceErrorType, MainErrorType>(error_ + " :" + suffix, error_type_)); + return Error(new ServiceError<ServiceErrorType, MainErrorType>(error_ + ": " + suffix, error_type_)); } inline bool operator==(const Error& rhs) const override { diff --git a/common/cpp/include/common/networking.h b/common/cpp/include/common/networking.h index 5348d85a9414cd3642edf9400f897496dc4c4e2a..d1c79909bd44b9fe83f7833ffe461c57d8c05d99 100644 --- a/common/cpp/include/common/networking.h +++ b/common/cpp/include/common/networking.h @@ -47,13 +47,13 @@ struct GenericRequestHeader { uint64_t i_data_size = 0, uint64_t i_meta_size = 0, const std::string& i_message = "", const std::string& i_substream = ""): op_code{i_op_code}, data_id{i_data_id}, data_size{i_data_size}, meta_size{i_meta_size} { - strncpy(message, i_message.c_str(), kMaxMessageSize); + strncpy(message, i_message.c_str(), kMaxMessageSize); // TODO must be memcpy in order to send raw MemoryDetails strncpy(substream, i_substream.c_str(), kMaxMessageSize); } GenericRequestHeader(const GenericRequestHeader& header) { op_code = header.op_code, data_id = header.data_id, data_size = header.data_size, meta_size = header.meta_size, memcpy(custom_data, header.custom_data, kNCustomParams * sizeof(uint64_t)), - strncpy(message, header.message, kMaxMessageSize); + strncpy(message, header.message, kMaxMessageSize); // TODO must be memcpy in order to send raw MemoryDetails strncpy(substream, header.substream, kMaxMessageSize); } diff --git a/common/cpp/include/io/io.h b/common/cpp/include/io/io.h index dade78385798a06a20ba647f009d2f0b7a5c153b..25266e88a991015ac084b1a7fba56090fbc050dd 100644 --- a/common/cpp/include/io/io.h +++ b/common/cpp/include/io/io.h @@ -92,6 +92,10 @@ class IO { * @param err Since CloseSocket if often used in an error case, it's able to accept err as nullptr. 
*/ virtual void CloseSocket(SocketDescriptor socket_fd, Error* err) const = 0; + virtual std::string AddressFromSocket(SocketDescriptor socket) const noexcept = 0; + virtual std::string GetHostName(Error* err) const noexcept = 0; + virtual std::unique_ptr<std::tuple<std::string, uint16_t>> SplitAddressToHostnameAndPort( + const std::string& address) const = 0; /* * Filesystem @@ -117,8 +121,6 @@ class IO { virtual std::vector<FileInfo> FilesInFolder (const std::string& folder, Error* err) const = 0; virtual std::string ReadFileToString (const std::string& fname, Error* err) const = 0; virtual Error GetLastError() const = 0; - virtual std::string AddressFromSocket(SocketDescriptor socket) const noexcept = 0; - virtual std::string GetHostName(Error* err) const noexcept = 0; virtual FileInfo GetFileInfo(const std::string& name, Error* err) const = 0; virtual ~IO() = default; diff --git a/common/cpp/include/unittests/MockFabric.h b/common/cpp/include/unittests/MockFabric.h new file mode 100644 index 0000000000000000000000000000000000000000..f433814860e1eaaa1f566a3413242db0d9762a3b --- /dev/null +++ b/common/cpp/include/unittests/MockFabric.h @@ -0,0 +1,99 @@ +#ifndef ASAPO_MOCKFABRIC_H +#define ASAPO_MOCKFABRIC_H + +#include <asapo_fabric/asapo_fabric.h> + +namespace asapo { +namespace fabric { + +class MockFabricMemoryRegion : public FabricMemoryRegion { + MOCK_CONST_METHOD0(GetDetails, const MemoryRegionDetails * ()); +}; + +class MockFabricContext : public FabricContext { + MOCK_CONST_METHOD0(GetAddress, std::string()); + + std::unique_ptr<FabricMemoryRegion> ShareMemoryRegion(void* src, size_t size, Error* error) override { + ErrorInterface* err = nullptr; + auto data = ShareMemoryRegion_t(src, size, &err); + error->reset(err); + return std::unique_ptr<FabricMemoryRegion> {data}; + } + MOCK_METHOD3(ShareMemoryRegion_t, FabricMemoryRegion * (void* src, size_t size, ErrorInterface** err)); + + void Send(FabricAddress dstAddress, FabricMessageId messageId, + const void* src, size_t size, Error* error) override { + ErrorInterface* err = nullptr; + Send_t(dstAddress, messageId, src, size, &err); + error->reset(err); + } + MOCK_METHOD5(Send_t, void(FabricAddress dstAddress, FabricMessageId messageId, + const void* src, size_t size, ErrorInterface** err)); + + void Recv(FabricAddress srcAddress, FabricMessageId messageId, + void* dst, size_t size, Error* error) override { + ErrorInterface* err = nullptr; + Recv_t(srcAddress, messageId, dst, size, &err); + error->reset(err); + } + MOCK_METHOD5(Recv_t, void(FabricAddress dstAddress, FabricMessageId messageId, + const void* src, size_t size, ErrorInterface** err)); + + void RdmaWrite(FabricAddress dstAddress, + const MemoryRegionDetails* details, const void* buffer, size_t size, + Error* error) override { + ErrorInterface* err = nullptr; + RdmaWrite_t(dstAddress, details, buffer, size, &err); + error->reset(err); + } + MOCK_METHOD5(RdmaWrite_t, void(FabricAddress dstAddress, const MemoryRegionDetails* details, const void* buffer, + size_t size, ErrorInterface** error)); +}; + +class MockFabricClient : public MockFabricContext, public FabricClient { + FabricAddress AddServerAddress(const std::string& serverAddress, Error* error) override { + ErrorInterface* err = nullptr; + auto data = AddServerAddress_t(serverAddress, &err); + error->reset(err); + return data; + } + MOCK_METHOD2(AddServerAddress_t, FabricAddress (const std::string& serverAddress, ErrorInterface** err)); +}; + +class MockFabricServer : public MockFabricContext, public FabricServer { 
+ void RecvAny(FabricAddress* srcAddress, FabricMessageId* messageId, void* dst, size_t size, Error* error) override { + ErrorInterface* err = nullptr; + RecvAny_t(srcAddress, messageId, dst, size, &err); + error->reset(err); + } + MOCK_METHOD5(RecvAny_t, void(FabricAddress* srcAddress, FabricMessageId* messageId, + void* dst, size_t size, ErrorInterface** err)); +}; + +class MockFabricFactory : public FabricFactory { + public: + std::unique_ptr<FabricServer> + CreateAndBindServer(const AbstractLogger* logger, const std::string& host, uint16_t port, + Error* error) const override { + ErrorInterface* err = nullptr; + auto data = CreateAndBindServer_t(logger, host, port, &err); + error->reset(err); + return std::unique_ptr<FabricServer> {data}; + } + MOCK_CONST_METHOD4(CreateAndBindServer_t, + FabricServer * (const AbstractLogger* logger, const std::string& host, + uint16_t port, ErrorInterface** err)); + + std::unique_ptr<FabricClient> CreateClient(Error* error) const override { + ErrorInterface* err = nullptr; + auto data = CreateClient_t(&err); + error->reset(err); + return std::unique_ptr<FabricClient> {data}; + } + MOCK_CONST_METHOD1(CreateClient_t, + FabricClient * (ErrorInterface** err)); +}; +} +} + +#endif //ASAPO_MOCKFABRIC_H diff --git a/common/cpp/include/unittests/MockIO.h b/common/cpp/include/unittests/MockIO.h index 90204e58c4e4909761e74751214171189711f6ad..3eb3e0c8b2d59c45b8a55871dcdbb8232da1cfe4 100644 --- a/common/cpp/include/unittests/MockIO.h +++ b/common/cpp/include/unittests/MockIO.h @@ -145,6 +145,10 @@ class MockIO : public IO { } MOCK_CONST_METHOD4(Send_t, size_t(SocketDescriptor socket_fd, const void* buf, size_t length, ErrorInterface** err)); + + MOCK_CONST_METHOD1(SplitAddressToHostnameAndPort, + std::unique_ptr<std::tuple<std::string, uint16_t>>(const std::string& address)); + void Skip(SocketDescriptor socket_fd, size_t length, Error* err) const override { ErrorInterface* error = nullptr; Skip_t(socket_fd, length, &error); diff --git a/common/cpp/src/asapo_fabric/CMakeLists.txt b/common/cpp/src/asapo_fabric/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fefef76399ebc6e0d9cc04e2335d7fb8e84663de --- /dev/null +++ b/common/cpp/src/asapo_fabric/CMakeLists.txt @@ -0,0 +1,34 @@ +set(TARGET_NAME asapo-fabric) + +include_directories(include) + +set(SOURCE_FILES asapo_fabric.cpp) + +IF(ENABLE_LIBFABRIC) + set(SOURCE_FILES ${SOURCE_FILES} + fabric_internal_error.cpp + fabric_factory_impl.cpp + common/fabric_context_impl.cpp + common/fabric_memory_region_impl.cpp + common/task/fabric_waitable_task.cpp + common/task/fabric_self_deleting_task.cpp + common/task/fabric_self_requeuing_task.cpp + common/task/fabric_alive_check_response_task.cpp + client/fabric_client_impl.cpp + server/fabric_server_impl.cpp + server/task/fabric_recv_any_task.cpp + server/task/fabric_handshake_accepting_task.cpp + ) +ELSE() + set(SOURCE_FILES ${SOURCE_FILES} + fabric_factory_not_supported.cpp + ) +ENDIF() + +################################ +# Library +################################ + +add_library(${TARGET_NAME} STATIC ${SOURCE_FILES} $<TARGET_OBJECTS:system_io>) + +target_include_directories(${TARGET_NAME} PUBLIC ${ASAPO_CXX_COMMON_INCLUDE_DIR}) diff --git a/common/cpp/src/asapo_fabric/asapo_fabric.cpp b/common/cpp/src/asapo_fabric/asapo_fabric.cpp new file mode 100644 index 0000000000000000000000000000000000000000..96aa4fd5215db008e4ec0a9548b8db7bba9e31a9 --- /dev/null +++ b/common/cpp/src/asapo_fabric/asapo_fabric.cpp @@ -0,0 +1,17 @@ +#include 
<asapo_fabric/asapo_fabric.h> + +#ifdef LIBFABRIC_ENABLED +#include "fabric_factory_impl.h" +#else +#include "fabric_factory_not_supported.h" +#endif + +using namespace asapo::fabric; + +std::unique_ptr<FabricFactory> asapo::fabric::GenerateDefaultFabricFactory() { +#ifdef LIBFABRIC_ENABLED + return std::unique_ptr<FabricFactory>(new FabricFactoryImpl()); +#else + return std::unique_ptr<FabricFactory>(new FabricFactoryNotSupported()); +#endif +} diff --git a/common/cpp/src/asapo_fabric/client/fabric_client_impl.cpp b/common/cpp/src/asapo_fabric/client/fabric_client_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..edd4d14552730921271087cdcb382071a75dd25a --- /dev/null +++ b/common/cpp/src/asapo_fabric/client/fabric_client_impl.cpp @@ -0,0 +1,93 @@ +#include "fabric_client_impl.h" +#include <rdma/fi_domain.h> +#include <cstring> + +using namespace asapo; +using namespace fabric; + +std::string FabricClientImpl::GetAddress() const { + if (!domain_) { + return ""; + } + return FabricContextImpl::GetAddress(); +} + +std::unique_ptr<FabricMemoryRegion> FabricClientImpl::ShareMemoryRegion(void* src, size_t size, Error* error) { + if (!domain_) { + *error = FabricErrorTemplates::kClientNotInitializedError.Generate(); + return nullptr; + } + return FabricContextImpl::ShareMemoryRegion(src, size, error); +} + +void FabricClientImpl::Send(FabricAddress dstAddress, FabricMessageId messageId, const void* src, size_t size, + Error* error) { + if (!domain_) { + *error = FabricErrorTemplates::kClientNotInitializedError.Generate(); + return; + } + FabricContextImpl::Send(dstAddress, messageId, src, size, error); +} + +void FabricClientImpl::Recv(FabricAddress srcAddress, FabricMessageId messageId, void* dst, size_t size, Error* error) { + if (!domain_) { + *error = FabricErrorTemplates::kClientNotInitializedError.Generate(); + return; + } + FabricContextImpl::Recv(srcAddress, messageId, dst, size, error); +} + +void +FabricClientImpl::RdmaWrite(FabricAddress dstAddress, const MemoryRegionDetails* details, const void* buffer, + size_t size, + Error* error) { + if (!domain_) { + *error = FabricErrorTemplates::kClientNotInitializedError.Generate(); + return; + } + FabricContextImpl::RdmaWrite(dstAddress, details, buffer, size, error); +} + +FabricAddress FabricClientImpl::AddServerAddress(const std::string& serverAddress, Error* error) { + std::string hostname; + uint16_t port; + std::tie(hostname, port) = *io__->SplitAddressToHostnameAndPort(serverAddress); + std::string serverIp = io__->ResolveHostnameToIp(hostname, error); + + InitIfNeeded(serverIp, error); + if (*error) { + return FI_ADDR_NOTAVAIL; + } + + int result; + FabricAddress addrIdx; + + result = fi_av_insertsvc(address_vector_, serverIp.c_str(), std::to_string(port).c_str(), + &addrIdx, 0, nullptr); + if (result != 1) { + *error = ErrorFromFabricInternal("fi_av_insertsvc", result); + return FI_ADDR_NOTAVAIL; + } + + FabricHandshakePayload handshake {}; + strcpy(handshake.hostnameAndPort, GetAddress().c_str()); + RawSend(addrIdx, &handshake, sizeof(handshake), error); + if (*error) { + return 0; + } + + // Zero sized payload + RawRecv(addrIdx, nullptr, 0, error); + + return addrIdx; +} + +void FabricClientImpl::InitIfNeeded(const std::string& targetIpHint, Error* error) { + const std::lock_guard<std::mutex> lock(initMutex_); // Will be released when scope is cleared + + if (domain_) { + return; // Was already initialized + } + + InitCommon(targetIpHint, 0, error); +} diff --git 
a/common/cpp/src/asapo_fabric/client/fabric_client_impl.h b/common/cpp/src/asapo_fabric/client/fabric_client_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..fe2747ac682870092de881f10bfc992ab1c3a290 --- /dev/null +++ b/common/cpp/src/asapo_fabric/client/fabric_client_impl.h @@ -0,0 +1,37 @@ +#ifndef ASAPO_FABRIC_CLIENT_IMPL_H +#define ASAPO_FABRIC_CLIENT_IMPL_H + +#include <asapo_fabric/asapo_fabric.h> +#include "../common/fabric_context_impl.h" + +namespace asapo { +namespace fabric { + +class FabricClientImpl : public FabricClient, public FabricContextImpl { + private: + std::mutex initMutex_; + public: // Link to FabricContext + std::string GetAddress() const override; + + std::unique_ptr<FabricMemoryRegion> ShareMemoryRegion(void* src, size_t size, Error* error) override; + + void Send(FabricAddress dstAddress, FabricMessageId messageId, + const void* src, size_t size, Error* error) override; + + void Recv(FabricAddress srcAddress, FabricMessageId messageId, + void* dst, size_t size, Error* error) override; + + void RdmaWrite(FabricAddress dstAddress, + const MemoryRegionDetails* details, const void* buffer, size_t size, + Error* error) override; + public: + FabricAddress AddServerAddress(const std::string& serverAddress, Error* error) override; + + private: + void InitIfNeeded(const std::string& targetIpHint, Error* error); +}; + +} +} + +#endif //ASAPO_FABRIC_CLIENT_IMPL_H diff --git a/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp b/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b4cac75c1e61600bd5013fd0c6c2c1bf19ff132f --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/fabric_context_impl.cpp @@ -0,0 +1,327 @@ +#include <io/io_factory.h> +#include <cstring> +#include <rdma/fi_cm.h> +#include <rdma/fi_domain.h> +#include <rdma/fi_endpoint.h> +#include <rdma/fi_rma.h> +#include <netinet/in.h> +#include <arpa/inet.h> +#include <rdma/fi_tagged.h> +#include "fabric_context_impl.h" +#include "fabric_memory_region_impl.h" + +using namespace asapo; +using namespace fabric; + +std::string __PRETTY_FUNCTION_TO_NAMESPACE__(const std::string& prettyFunction) { + auto functionParamBegin = prettyFunction.find('('); + auto spaceBegin = prettyFunction.substr(0, functionParamBegin).find(' '); + return prettyFunction.substr(spaceBegin + 1, functionParamBegin - spaceBegin - 1); +} + +// This marco checks if the call that is being made returns FI_SUCCESS. 
Should only be used with LibFabric functions // *error is set to the corresponding LibFabric error #define FI_OK(functionCall) \ + do { \ + int tmp_fi_status = functionCall; \ + if(__builtin_expect(tmp_fi_status, FI_SUCCESS)) { \ + std::string tmp_fi_s = #functionCall; \ + *error = ErrorFromFabricInternal(__PRETTY_FUNCTION_TO_NAMESPACE__(__PRETTY_FUNCTION__) + " Line " + std::to_string(__LINE__) + ": " + tmp_fi_s.substr(0, tmp_fi_s.find('(')), tmp_fi_status);\ + return; \ + } \ + } while(0) // Enforce ';' + +// TODO: It is super important that version 1.10 is installed, but since it's not released yet we go with 1.9 +const uint32_t FabricContextImpl::kMinExpectedLibFabricVersion = FI_VERSION(1, 9); + +FabricContextImpl::FabricContextImpl() : io__{ GenerateDefaultIO() }, alive_check_response_task_(this) { +} + +FabricContextImpl::~FabricContextImpl() { + StopBackgroundThreads(); + + if (endpoint_) + fi_close(&endpoint_->fid); + + if (completion_queue_) + fi_close(&completion_queue_->fid); + + if (address_vector_) + fi_close(&address_vector_->fid); + + if (domain_) + fi_close(&domain_->fid); + + if (fabric_) + fi_close(&fabric_->fid); + + if (fabric_info_) + fi_freeinfo(fabric_info_); +} + +void FabricContextImpl::InitCommon(const std::string& networkIpHint, uint16_t serverListenPort, Error* error) { + const bool isServer = serverListenPort != 0; + + // The server must know where the packets are coming from, FI_SOURCE allows this. + uint64_t additionalFlags = isServer ? FI_SOURCE : 0; + + fi_info* hints = fi_allocinfo(); + if (networkIpHint == "127.0.0.1") { + // sockets mode + hints->fabric_attr->prov_name = strdup("sockets"); + hotfix_using_sockets_ = true; + } else { + // verbs mode + hints->fabric_attr->prov_name = strdup("verbs;ofi_rxm"); + } + hints->ep_attr->type = FI_EP_RDM; + hints->caps = FI_TAGGED | FI_RMA | FI_DIRECTED_RECV | additionalFlags; + + if (isServer) { + hints->src_addr = strdup(networkIpHint.c_str()); + } else { + hints->dest_addr = strdup(networkIpHint.c_str()); + } + + // I've deliberately removed the FI_MR_LOCAL flag, which forces the user of the API to pre-register the + // memory that is going to be transferred via RDMA. + // Since performance tests showed that the performance is roughly equal, I've removed it. + hints->domain_attr->mr_mode = FI_MR_ALLOCATED | FI_MR_VIRT_ADDR | FI_MR_PROV_KEY;// | FI_MR_LOCAL; + hints->addr_format = FI_SOCKADDR_IN; + + int ret = fi_getinfo( + kMinExpectedLibFabricVersion, networkIpHint.c_str(), isServer ? std::to_string(serverListenPort).c_str() : nullptr, + additionalFlags, hints, &fabric_info_); + + if (ret) { + if (ret == -FI_ENODATA) { + *error = FabricErrorTemplates::kNoDeviceFoundError.Generate(); + } else { + *error = ErrorFromFabricInternal("fi_getinfo", ret); + } + fi_freeinfo(hints); + return; + } + // fprintf(stderr, fi_tostr(fabric_info_, FI_TYPE_INFO)); // Print the found fabric details + + // We have to reapply the memory mode because they get reset + fabric_info_->domain_attr->mr_mode = hints->domain_attr->mr_mode; + + // total_buffered_recv is a hint to the provider of the total available space that may be needed to buffer messages + // that are received for which there is no matching receive operation. + // fabric_info_->rx_attr->total_buffered_recv = 0; + // If something strange is happening with receive requests, we should set this to 0.
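// Aside (illustrative, not part of this patch): if the fi_getinfo call above fails with -FI_ENODATA (surfaced as
// kNoDeviceFoundError), the same hints can be probed in isolation to list which providers libfabric actually offers.
// A minimal sketch using only calls already present in this file (fi_getinfo, fi_freeinfo) plus fi_strerror from
// <rdma/fi_errno.h>:
//
//     fi_info* probe_result = nullptr;
//     int probe = fi_getinfo(FI_VERSION(1, 9), nullptr, nullptr, 0, hints, &probe_result);
//     if (probe) {
//         fprintf(stderr, "fi_getinfo: %s\n", fi_strerror(-probe));
//     } else {
//         for (fi_info* cur = probe_result; cur != nullptr; cur = cur->next) {
//             fprintf(stderr, "provider: %s, fabric: %s\n", cur->fabric_attr->prov_name, cur->fabric_attr->name);
//         }
//         fi_freeinfo(probe_result);
//     }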
+ + fi_freeinfo(hints); + + FI_OK(fi_fabric(fabric_info_->fabric_attr, &fabric_, nullptr)); + FI_OK(fi_domain(fabric_, fabric_info_, &domain_, nullptr)); + + fi_av_attr av_attr{}; + FI_OK(fi_av_open(domain_, &av_attr, &address_vector_, nullptr)); + + fi_cq_attr cq_attr{}; + if (serverListenPort) { + // The server must know where the data is coming from(FI_SOURCE) and what the MessageId(TAG) is. + cq_attr.format = FI_CQ_FORMAT_TAGGED; + } + cq_attr.wait_obj = FI_WAIT_UNSPEC; // Allow the wait of querying the cq + FI_OK(fi_cq_open(domain_, &cq_attr, &completion_queue_, nullptr)); + + FI_OK(fi_endpoint(domain_, fabric_info_, &endpoint_, nullptr)); + FI_OK(fi_ep_bind(endpoint_, &address_vector_->fid, 0)); + FI_OK(fi_ep_bind(endpoint_, &completion_queue_->fid, FI_RECV | FI_SEND)); + + FI_OK(fi_enable(endpoint_)); + + StartBackgroundThreads(); +} + +std::string FabricContextImpl::GetAddress() const { + sockaddr_in sin{}; + size_t sin_size = sizeof(sin); + fi_getname(&(endpoint_->fid), &sin, &sin_size); + + // TODO Maybe expose such a function to io__ + switch(sin.sin_family) { + case AF_INET: + return std::string(inet_ntoa(sin.sin_addr)) + ":" + std::to_string(ntohs(sin.sin_port)); + default: + throw std::runtime_error("Unknown addr family: " + std::to_string(sin.sin_family)); + } +} + +std::unique_ptr<FabricMemoryRegion> FabricContextImpl::ShareMemoryRegion(void* src, size_t size, Error* error) { + fid_mr* mr{}; + auto region = std::unique_ptr<FabricMemoryRegionImpl>(new FabricMemoryRegionImpl()); + int ret = fi_mr_reg(domain_, src, size, + FI_REMOTE_READ | FI_REMOTE_WRITE | FI_SEND | FI_RECV, + 0, 0, 0, &mr, region.get()); + + if (ret != 0) { + *error = ErrorFromFabricInternal("fi_mr_reg", ret); + return nullptr; + } + + region->SetArguments(mr, (uint64_t)src, size); + return std::unique_ptr<FabricMemoryRegion>(region.release()); +} + +void FabricContextImpl::Send(FabricAddress dstAddress, FabricMessageId messageId, const void* src, size_t size, + Error* error) { + HandleFiCommandWithBasicTaskAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, error, + fi_tsend, src, size, nullptr, dstAddress, messageId); +} + +void FabricContextImpl::Recv(FabricAddress srcAddress, FabricMessageId messageId, void* dst, size_t size, + Error* error) { + HandleFiCommandWithBasicTaskAndWait(srcAddress, error, + fi_trecv, dst, size, nullptr, srcAddress, messageId, kRecvTaggedExactMatch); +} + +void FabricContextImpl::RawSend(FabricAddress dstAddress, const void* src, size_t size, Error* error) { + HandleFiCommandWithBasicTaskAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, error, + fi_send, src, size, nullptr, dstAddress); +} + +void FabricContextImpl::RawRecv(FabricAddress srcAddress, void* dst, size_t size, Error* error) { + HandleFiCommandWithBasicTaskAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, error, + fi_recv, dst, size, nullptr, srcAddress); +} + +void +FabricContextImpl::RdmaWrite(FabricAddress dstAddress, const MemoryRegionDetails* details, const void* buffer, + size_t size, + Error* error) { + HandleFiCommandWithBasicTaskAndWait(dstAddress, error, + fi_write, buffer, size, nullptr, dstAddress, details->addr, details->key); + +} + +void FabricContextImpl::StartBackgroundThreads() { + background_threads_running_ = true; + + completion_thread_ = io__->NewThread("ASAPO/FI/CQ", [this]() { + CompletionThread(); + }); + + alive_check_response_task_.Start(); +} + +void FabricContextImpl::StopBackgroundThreads() { + alive_check_response_task_.Stop(); // This has to be done before we kill the completion thread + + background_threads_running_ = 
false; + if (completion_thread_) { + completion_thread_->join(); + completion_thread_ = nullptr; + } +} + +void FabricContextImpl::CompletionThread() { + Error error; + fi_cq_tagged_entry entry{}; + FabricAddress tmpAddress; + while(background_threads_running_ && !error) { + ssize_t ret; + ret = fi_cq_sreadfrom(completion_queue_, &entry, 1, &tmpAddress, nullptr, 10 /*ms*/); + + switch (ret) { + case -FI_EAGAIN: // No data + std::this_thread::yield(); + break; + case -FI_EAVAIL: // An error is in the queue + CompletionThreadHandleErrorAvailable(&error); + break; + case 1: { // We got 1 data entry back + auto task = (FabricWaitableTask*)(entry.op_context); + if (task) { + task->HandleCompletion(&entry, tmpAddress); + } else { + error = FabricErrorTemplates::kInternalError.Generate("nullptr context from fi_cq_sreadfrom"); + } + break; + } + default: + error = ErrorFromFabricInternal("Unknown error while fi_cq_readfrom", ret); + break; + } + } + + if (error) { + throw std::runtime_error("ASAPO Fabric CompletionThread exited with error: " + error->Explain()); + } +} + +void FabricContextImpl::CompletionThreadHandleErrorAvailable(Error* error) { + fi_cq_err_entry errEntry{}; + ssize_t ret = fi_cq_readerr(completion_queue_, &errEntry, 0); + if (ret != 1) { + *error = ErrorFromFabricInternal("Unknown error while fi_cq_readerr", ret); + } else { + auto task = (FabricWaitableTask*)(errEntry.op_context); + if (task) { + task->HandleErrorCompletion(&errEntry); + } else if (hotfix_using_sockets_) { + printf("[Known Sockets bug libfabric/#5795] Ignoring nullptr task!\n"); + } else { + *error = FabricErrorTemplates::kInternalError.Generate("nullptr context from fi_cq_readerr"); + } + } +} + +bool FabricContextImpl::TargetIsAliveCheck(FabricAddress address) { + Error error; + + HandleFiCommandWithBasicTaskAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, &error, + fi_tsend, nullptr, 0, nullptr, address, FI_ASAPO_TAG_ALIVE_CHECK); + // If the send was successful, then we are still able to communicate with the peer + return !(error != nullptr); +} + +void FabricContextImpl::InternalWait(FabricAddress targetAddress, FabricWaitableTask* task, Error* error) { + + // Check if we simply can wait for our task + task->Wait(requestTimeoutMs_, error); + + if (*error == FabricErrorTemplates::kTimeout) { + if (targetAddress == FI_ASAPO_ADDR_NO_ALIVE_CHECK) { + CancelTask(task, error); + // We expect the task to fail with 'Operation canceled' + if (*error == FabricErrorTemplates::kInternalOperationCanceledError) { + // Switch it to a timeout so its more clearly what happened + *error = FabricErrorTemplates::kTimeout.Generate(); + } + } else { + InternalWaitWithAliveCheck(targetAddress, task, error); + } + } +} + +void FabricContextImpl::InternalWaitWithAliveCheck(FabricAddress targetAddress, FabricWaitableTask* task, + Error* error) {// Handle advanced alive check + bool aliveCheckFailed = false; + for (uint32_t i = 0; i < maxTimeoutRetires_ && *error == FabricErrorTemplates::kTimeout; i++) { + *error = nullptr; + printf("HandleFiCommandAndWait - Tries: %d\n", i); + if (!TargetIsAliveCheck(targetAddress)) { + aliveCheckFailed = true; + break; + } + task->Wait(requestTimeoutMs_, error); + } + + CancelTask(task, error); + + if (aliveCheckFailed) { + *error = FabricErrorTemplates::kInternalConnectionError.Generate(); + } else if(*error == FabricErrorTemplates::kInternalOperationCanceledError) { + *error = FabricErrorTemplates::kTimeout.Generate(); + } +} + +void FabricContextImpl::CancelTask(FabricWaitableTask* task, Error* error) { 
+ *error = nullptr; + fi_cancel(&endpoint_->fid, task); + task->Wait(0, error); // You can probably expect a kInternalOperationCanceledError +} diff --git a/common/cpp/src/asapo_fabric/common/fabric_context_impl.h b/common/cpp/src/asapo_fabric/common/fabric_context_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..8d51c4cb18a57f62509faff4a19d9538de508ed3 --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/fabric_context_impl.h @@ -0,0 +1,171 @@ +#ifndef ASAPO_FABRIC_CONTEXT_IMPL_H +#define ASAPO_FABRIC_CONTEXT_IMPL_H + +#include <io/io.h> +#include <rdma/fabric.h> +#include <rdma/fi_endpoint.h> +#include <memory> +#include <asapo_fabric/asapo_fabric.h> +#include <thread> +#include "task/fabric_waitable_task.h" +#include "../fabric_internal_error.h" +#include "task/fabric_alive_check_response_task.h" + +namespace asapo { +namespace fabric { + +#define FI_ASAPO_ADDR_NO_ALIVE_CHECK FI_ADDR_NOTAVAIL +#define FI_ASAPO_TAG_ALIVE_CHECK ((uint64_t) -1) + +/** + * TODO: State of the bandages used in asapo to use RXM + * If you read this _in the future_, there are hopefully fixes for the following topics: + * Since RXM is connectionless, we do not know when a disconnect occurs. + * - Therefore, when we try to receive data, we have added a targetAddress to HandleFiCommandAndWait, + * which might check if the peer is still responding to pings when a timeout occurs. + * + * Another issue is that in order to send data all addresses have to be added to an addressVector, + * unfortunately, this is also required to respond to a request. + * - So we added a handshake procedure that sends the local address of the client with a handshake to the server. + * This could be fixed by FI_SOURCE_ERR, which automatically + * adds new connections to the AV, which would make the handshake obsolete. + * At the time of writing this, FI_SOURCE_ERR is not supported with verbs;ofi_rxm + */ + + +const static uint64_t kRecvTaggedAnyMatch = ~0ULL; +const static uint64_t kRecvTaggedExactMatch = 0; + +// TODO Use a serialization framework +struct FabricHandshakePayload { + // Hostnames can be up to 256 bytes long. We also need to store the port number. + char hostnameAndPort[512]; +}; + +class FabricContextImpl : public FabricContext { + friend class FabricSelfRequeuingTask; + friend class FabricAliveCheckResponseTask; + public: + std::unique_ptr<IO> io__; + + protected: + FabricAliveCheckResponseTask alive_check_response_task_; + + fi_info* fabric_info_{}; + fid_fabric* fabric_{}; + fid_domain* domain_{}; + fid_cq* completion_queue_{}; + fid_av* address_vector_{}; + fid_ep* endpoint_{}; + + uint64_t requestEnqueueTimeoutMs_ = 10000; // 10 sec for queuing a task + uint64_t requestTimeoutMs_ = 20000; // 20 sec to complete a task, otherwise a ping will be sent + uint32_t maxTimeoutRetires_ = 5; // Timeout retries; if one of them fails, the task will fail with a timeout + + std::unique_ptr<std::thread> completion_thread_; + bool background_threads_running_ = false; + private: + // Unfortunately when a client disconnects on sockets, a weird completion is generated.
See libfabric/#5795 + bool hotfix_using_sockets_ = false; + public: + explicit FabricContextImpl(); + virtual ~FabricContextImpl(); + + static const uint32_t kMinExpectedLibFabricVersion; + + std::string GetAddress() const override; + + /// The memory will be shared until the result is freed + std::unique_ptr<FabricMemoryRegion> ShareMemoryRegion(void* src, size_t size, Error* error) override; + + /// With message id + void Send(FabricAddress dstAddress, FabricMessageId messageId, + const void* src, size_t size, Error* error) override; + void Recv(FabricAddress srcAddress, FabricMessageId messageId, + void* dst, size_t size, Error* error) override; + + /// Without message id - No alive check! + void RawSend(FabricAddress dstAddress, + const void* src, size_t size, Error* error); + void RawRecv(FabricAddress srcAddress, + void* dst, size_t size, Error* error); + + /// Rdma + void RdmaWrite(FabricAddress dstAddress, + const MemoryRegionDetails* details, const void* buffer, size_t size, + Error* error) override; + + protected: + /// If client serverListenPort == 0 + void InitCommon(const std::string& networkIpHint, uint16_t serverListenPort, Error* error); + + void StartBackgroundThreads(); + void StopBackgroundThreads(); + + // If the targetAddress is FI_ASAPO_ADDR_NO_ALIVE_CHECK and a timeout occurs, no further ping is being done. + // Alive check is generally only necessary if you are trying to receive data or RDMA send. + template<class FuncType, class... ArgTypes> + inline void HandleFiCommandWithBasicTaskAndWait(FabricAddress targetAddress, Error* error, + FuncType func, ArgTypes... args) { + FabricWaitableTask task; + HandleFiCommandAndWait(targetAddress, &task, error, func, args...); + } + + template<class FuncType, class... ArgTypes> + inline void HandleFiCommandAndWait(FabricAddress targetAddress, FabricWaitableTask* task, Error* error, + FuncType func, ArgTypes... args) { + HandleRawFiCommand(task, error, func, args...); + if (!(*error)) { // We successfully queued our request + InternalWait(targetAddress, task, error); + } + } + + template<class FuncType, class... ArgTypes> + inline void HandleRawFiCommand(void* context, Error* error, FuncType func, ArgTypes... 
args) { + ssize_t ret; + // Since handling timeouts is an overhead, we first try to send the data regularly + ret = func(endpoint_, args..., context); + if (ret == -FI_EAGAIN) { + using namespace std::chrono; + using clock = std::chrono::high_resolution_clock; + auto maxTime = clock::now() + milliseconds(requestEnqueueTimeoutMs_); + + do { + std::this_thread::sleep_for(milliseconds(3)); + ret = func(endpoint_, args..., context); + } while (ret == -FI_EAGAIN && maxTime >= clock::now()); + } + + switch (-ret) { + case FI_SUCCESS: + // Success + break; + case FI_EAGAIN: // We felt trough our own timeout loop + *error = FabricErrorTemplates::kTimeout.Generate(); + break; + case FI_ENOENT: + *error = FabricErrorTemplates::kConnectionRefusedError.Generate(); + break; + default: + *error = ErrorFromFabricInternal("HandleRawFiCommand", ret); + break; + } + } + + private: + bool TargetIsAliveCheck(FabricAddress address); + void CompletionThread(); + + void InternalWait(FabricAddress targetAddress, FabricWaitableTask* task, Error* error); + + void InternalWaitWithAliveCheck(FabricAddress targetAddress, FabricWaitableTask* task, Error* error); + + void CompletionThreadHandleErrorAvailable(Error* error); + + void CancelTask(FabricWaitableTask* task, Error* error); +}; + +} +} + +#endif //ASAPO_FABRIC_CONTEXT_IMPL_H diff --git a/common/cpp/src/asapo_fabric/common/fabric_memory_region_impl.cpp b/common/cpp/src/asapo_fabric/common/fabric_memory_region_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ac1a264b23507a48af95ef72e38302d6b9369c4d --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/fabric_memory_region_impl.cpp @@ -0,0 +1,21 @@ +#include "fabric_memory_region_impl.h" + +using namespace asapo; +using namespace fabric; + +FabricMemoryRegionImpl::~FabricMemoryRegionImpl() { + if (mr_) { + fi_close(&mr_->fid); + } +} + +void FabricMemoryRegionImpl::SetArguments(fid_mr* mr, uint64_t address, uint64_t length) { + mr_ = mr; + details_.addr = address; + details_.length = length; + details_.key = fi_mr_key(mr_); +} + +const MemoryRegionDetails* FabricMemoryRegionImpl::GetDetails() const { + return &details_; +} diff --git a/common/cpp/src/asapo_fabric/common/fabric_memory_region_impl.h b/common/cpp/src/asapo_fabric/common/fabric_memory_region_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..44d7c8f6d8c1853706815aab42825c7e93c506dc --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/fabric_memory_region_impl.h @@ -0,0 +1,24 @@ +#ifndef ASAPO_FABRIC_MEMORY_REGION_IMPL_H +#define ASAPO_FABRIC_MEMORY_REGION_IMPL_H + +#include <asapo_fabric/asapo_fabric.h> +#include <rdma/fi_domain.h> + +namespace asapo { +namespace fabric { +class FabricMemoryRegionImpl : public FabricMemoryRegion { + private: + fid_mr* mr_{}; + MemoryRegionDetails details_{}; + public: + ~FabricMemoryRegionImpl() override; + + void SetArguments(fid_mr* mr, uint64_t address, uint64_t length); + + const MemoryRegionDetails* GetDetails() const override; +}; +} +} + + +#endif //ASAPO_FABRIC_MEMORY_REGION_IMPL_H diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.cpp b/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2875a1a289c230842be85a2b12ef9905fa78c152 --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.cpp @@ -0,0 +1,28 @@ +#include <rdma/fi_tagged.h> +#include "fabric_alive_check_response_task.h" +#include 
"../fabric_context_impl.h" + +using namespace asapo; +using namespace fabric; + +void FabricAliveCheckResponseTask::RequeueSelf() { + Error tmpError = nullptr; + + ParentContext()->HandleRawFiCommand(this, &tmpError, + fi_trecv, nullptr, 0, nullptr, FI_ADDR_UNSPEC, FI_ASAPO_TAG_ALIVE_CHECK, kRecvTaggedExactMatch); + + // Error is ignored +} + +void FabricAliveCheckResponseTask::OnCompletion(const fi_cq_tagged_entry*, FabricAddress) { + // We received a ping, LibFabric will automatically notify the sender about the completion. +} + +void FabricAliveCheckResponseTask::OnErrorCompletion(const fi_cq_err_entry*) { + // Error is ignored +} + +FabricAliveCheckResponseTask::FabricAliveCheckResponseTask(FabricContextImpl* parentContext) + : FabricSelfRequeuingTask(parentContext) { + +} diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.h new file mode 100644 index 0000000000000000000000000000000000000000..e3f55abf9256948747988e7e8e5df570e90c6970 --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_alive_check_response_task.h @@ -0,0 +1,25 @@ +#ifndef ASAPO_FABRIC_ALIVE_CHECK_RESPONSE_TASK_H +#define ASAPO_FABRIC_ALIVE_CHECK_RESPONSE_TASK_H + +#include "fabric_self_requeuing_task.h" + +namespace asapo { +namespace fabric { + +/** + * This is the counter part of FabricContextImpl.TargetIsAliveCheck + */ +class FabricAliveCheckResponseTask : public FabricSelfRequeuingTask { + public: + explicit FabricAliveCheckResponseTask(FabricContextImpl* parentContext); + protected: + void RequeueSelf() override; + + void OnCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) override; + + void OnErrorCompletion(const fi_cq_err_entry* errEntry) override; +}; +} +} + +#endif //ASAPO_FABRIC_ALIVE_CHECK_RESPONSE_TASK_H diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.cpp b/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bef89058e090ccb013600b435b5f6230c6d01498 --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.cpp @@ -0,0 +1,13 @@ +#include "fabric_self_deleting_task.h" + +void asapo::fabric::FabricSelfDeletingTask::HandleCompletion(const fi_cq_tagged_entry*, FabricAddress) { + OnDone(); +} + +void asapo::fabric::FabricSelfDeletingTask::HandleErrorCompletion(const fi_cq_err_entry*) { + OnDone(); +} + +void asapo::fabric::FabricSelfDeletingTask::OnDone() { + delete this; +} diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.h new file mode 100644 index 0000000000000000000000000000000000000000..59d5f627a8fc75fc2be8a67d1d069cfa967e1c0a --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_self_deleting_task.h @@ -0,0 +1,22 @@ +#ifndef ASAPO_FABRIC_SELF_DELETING_TASK_H +#define ASAPO_FABRIC_SELF_DELETING_TASK_H + +#include "fabric_task.h" + +namespace asapo { +namespace fabric { + +class FabricSelfDeletingTask : FabricTask { + + void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) final; + void HandleErrorCompletion(const fi_cq_err_entry* errEntry) final; + + private: + virtual ~FabricSelfDeletingTask() = default; + void OnDone(); +}; + +} +} + +#endif //ASAPO_FABRIC_SELF_DELETING_TASK_H diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.cpp 
b/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.cpp new file mode 100644 index 0000000000000000000000000000000000000000..75b18f1c440b5fc7afe5c868ff6789d1e9a5d088 --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.cpp @@ -0,0 +1,54 @@ +#include "fabric_self_requeuing_task.h" +#include "../fabric_context_impl.h" + +using namespace asapo; +using namespace fabric; + +FabricSelfRequeuingTask::~FabricSelfRequeuingTask() { + Stop(); +} + +FabricSelfRequeuingTask::FabricSelfRequeuingTask(FabricContextImpl* parentContext) : stop_response_future_{stop_response_.get_future()} { + parent_context_ = parentContext; +} + +void FabricSelfRequeuingTask::Start() { + if (was_queued_already_) { + throw std::runtime_error("FabricSelfRequeuingTask can only be queued once"); + } + was_queued_already_ = true; + RequeueSelf(); +} + +void FabricSelfRequeuingTask::Stop() { + if (was_queued_already_ && still_running_) { + still_running_ = false; + fi_cancel(&parent_context_->endpoint_->fid, this); + stop_response_future_.wait(); + } +} + +void FabricSelfRequeuingTask::HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) { + OnCompletion(entry, source); + AfterCompletion(); +} + +void FabricSelfRequeuingTask::HandleErrorCompletion(const fi_cq_err_entry* errEntry) { + // If we are not running and got a FI_ECANCELED its probably expected. + if (still_running_ || errEntry->err != FI_ECANCELED) { + OnErrorCompletion(errEntry); + } + AfterCompletion(); +} + +void FabricSelfRequeuingTask::AfterCompletion() { + if (still_running_) { + RequeueSelf(); + } else { + stop_response_.set_value(); + } +} + +FabricContextImpl* FabricSelfRequeuingTask::ParentContext() { + return parent_context_; +} diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.h new file mode 100644 index 0000000000000000000000000000000000000000..905b6f1dbb00a9669722711d898c70d970702424 --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_self_requeuing_task.h @@ -0,0 +1,40 @@ +#ifndef ASAPO_FABRIC_SELF_REQUEUING_TASK_H +#define ASAPO_FABRIC_SELF_REQUEUING_TASK_H + +#include <future> +#include "fabric_task.h" + +namespace asapo { +namespace fabric { +class FabricContextImpl; + +class FabricSelfRequeuingTask : public FabricTask { + private: + FabricContextImpl* parent_context_; + volatile bool still_running_ = true; + bool was_queued_already_ = false; + std::promise<void> stop_response_; + std::future<void> stop_response_future_; + public: + ~FabricSelfRequeuingTask(); + explicit FabricSelfRequeuingTask(FabricContextImpl* parentContext); + + void Start(); + void Stop(); + public: + void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) final; + void HandleErrorCompletion(const fi_cq_err_entry* errEntry) final; + protected: + FabricContextImpl* ParentContext(); + + virtual void RequeueSelf() = 0; + virtual void OnCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) = 0; + virtual void OnErrorCompletion(const fi_cq_err_entry* errEntry) = 0; + private: + void AfterCompletion(); +}; +} +} + + +#endif //ASAPO_FABRIC_SELF_REQUEUING_TASK_H diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_task.h new file mode 100644 index 0000000000000000000000000000000000000000..3802a558d4a4309fde71ce2504a478d149aa2e45 --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_task.h @@ -0,0 
+1,17 @@ +#ifndef ASAPO_FABRIC_TASK_H +#define ASAPO_FABRIC_TASK_H + +#include <asapo_fabric/asapo_fabric.h> +#include <rdma/fi_eq.h> + +namespace asapo { +namespace fabric { +class FabricTask { + public: + virtual void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) = 0; + virtual void HandleErrorCompletion(const fi_cq_err_entry* errEntry) = 0; +}; +} +} + +#endif //ASAPO_FABRIC_TASK_H diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.cpp b/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.cpp new file mode 100644 index 0000000000000000000000000000000000000000..47efa2fe8f558d934cbdf7969d44ad40710548fb --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.cpp @@ -0,0 +1,35 @@ +#include "fabric_waitable_task.h" +#include "../../fabric_internal_error.h" + +using namespace asapo; +using namespace fabric; + +FabricWaitableTask::FabricWaitableTask() : future_{promise_.get_future()}, source_{FI_ADDR_NOTAVAIL} { + +} + +void FabricWaitableTask::HandleCompletion(const fi_cq_tagged_entry*, FabricAddress source) { + source_ = source; + promise_.set_value(); +} + +void FabricWaitableTask::HandleErrorCompletion(const fi_cq_err_entry* errEntry) { + error_ = ErrorFromFabricInternal("FabricWaitableTask", -errEntry->err); + promise_.set_value(); +} + +void FabricWaitableTask::Wait(uint32_t sleepInMs, Error* error) { + if (sleepInMs) { + if (future_.wait_for(std::chrono::milliseconds(sleepInMs)) == std::future_status::timeout) { + *error = FabricErrorTemplates::kTimeout.Generate(); + return; + } + } else { + future_.wait(); + } + error->swap(error_); +} + +FabricAddress FabricWaitableTask::GetSource() const { + return source_; +} diff --git a/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.h b/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.h new file mode 100644 index 0000000000000000000000000000000000000000..24a6b565969f48de74774c8745c8da5186d7e92e --- /dev/null +++ b/common/cpp/src/asapo_fabric/common/task/fabric_waitable_task.h @@ -0,0 +1,32 @@ +#ifndef ASAPO_FABRIC_WAITABLE_TASK_H +#define ASAPO_FABRIC_WAITABLE_TASK_H + +#include <common/error.h> +#include <asapo_fabric/asapo_fabric.h> +#include <future> +#include "fabric_task.h" + +namespace asapo { +namespace fabric { +class FabricWaitableTask : FabricTask { + private: + std::promise<void> promise_; + std::future<void> future_; + + Error error_; + FabricAddress source_; + public: + explicit FabricWaitableTask(); + + void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) override; + void HandleErrorCompletion(const fi_cq_err_entry* errEntry) override; + + void Wait(uint32_t sleepInMs, Error* error); + + FabricAddress GetSource() const; + +}; +} +} + +#endif //ASAPO_FABRIC_WAITABLE_TASK_H diff --git a/common/cpp/src/asapo_fabric/fabric_factory_impl.cpp b/common/cpp/src/asapo_fabric/fabric_factory_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ae10ed3bb131a3cefb198341dd1d11f52fe6ab1 --- /dev/null +++ b/common/cpp/src/asapo_fabric/fabric_factory_impl.cpp @@ -0,0 +1,49 @@ +#include "fabric_factory_impl.h" +#include "fabric_internal_error.h" +#include "client/fabric_client_impl.h" +#include "server/fabric_server_impl.h" +#include <rdma/fabric.h> + +using namespace asapo::fabric; + +std::string fi_version_string(uint32_t version) { + return std::to_string(FI_MAJOR(version)) + "." 
+ std::to_string(FI_MINOR(version)); +} + +bool FabricFactoryImpl::HasValidVersion(Error* error) const { + auto current_version = fi_version(); + + if (FI_VERSION_LT(current_version, FabricContextImpl::kMinExpectedLibFabricVersion)) { + std::string found_version_str = fi_version_string(current_version); + std::string expected_version_str = fi_version_string(FabricContextImpl::kMinExpectedLibFabricVersion); + + std::string errorText = "Found " + found_version_str + " but expected at least " + expected_version_str; + *error = FabricErrorTemplates::kOutdatedLibraryError.Generate(errorText); + return false; + } + + return true; +} + +std::unique_ptr<FabricServer> +FabricFactoryImpl::CreateAndBindServer(const AbstractLogger* logger, const std::string& host, uint16_t port, + Error* error) const { + if (!HasValidVersion(error)) { + return nullptr; + } + + auto server = new FabricServerImpl(logger); + + server->InitAndStartServer(host, port, error); + + return std::unique_ptr<FabricServer>(server); +} + +std::unique_ptr<FabricClient> +FabricFactoryImpl::CreateClient(Error* error) const { + if (!HasValidVersion(error)) { + return nullptr; + } + + return std::unique_ptr<FabricClient>(new FabricClientImpl()); +} diff --git a/common/cpp/src/asapo_fabric/fabric_factory_impl.h b/common/cpp/src/asapo_fabric/fabric_factory_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..ce0ec84eeb0d4e5a61a0af811c71ba8f50846c6c --- /dev/null +++ b/common/cpp/src/asapo_fabric/fabric_factory_impl.h @@ -0,0 +1,21 @@ +#include <asapo_fabric/asapo_fabric.h> + +#ifndef ASAPO_FABRIC_FACTORY_IMPL_H +#define ASAPO_FABRIC_FACTORY_IMPL_H + +namespace asapo { +namespace fabric { +class FabricFactoryImpl : public FabricFactory { + public: + bool HasValidVersion(Error* error) const; + + std::unique_ptr<FabricServer> + CreateAndBindServer(const AbstractLogger* logger, + const std::string& host, uint16_t port, Error* error) const override; + + std::unique_ptr<FabricClient> CreateClient(Error* error) const override; +}; +} +} + +#endif //ASAPO_FABRIC_FACTORY_IMPL_H diff --git a/common/cpp/src/asapo_fabric/fabric_factory_not_supported.cpp b/common/cpp/src/asapo_fabric/fabric_factory_not_supported.cpp new file mode 100644 index 0000000000000000000000000000000000000000..09e33cd8cb71bee6c1740a6c53b97e375ac641d9 --- /dev/null +++ b/common/cpp/src/asapo_fabric/fabric_factory_not_supported.cpp @@ -0,0 +1,16 @@ +#include "fabric_factory_not_supported.h" +#include "fabric_internal_error.h" + +using namespace asapo::fabric; + +std::unique_ptr<FabricServer> asapo::fabric::FabricFactoryNotSupported::CreateAndBindServer( + const AbstractLogger* logger, const std::string& host, uint16_t port, + Error* error) const { + *error = FabricErrorTemplates::kNotSupportedOnBuildError.Generate(); + return nullptr; +} + +std::unique_ptr<FabricClient> asapo::fabric::FabricFactoryNotSupported::CreateClient(Error* error) const { + *error = FabricErrorTemplates::kNotSupportedOnBuildError.Generate(); + return nullptr; +} diff --git a/common/cpp/src/asapo_fabric/fabric_factory_not_supported.h b/common/cpp/src/asapo_fabric/fabric_factory_not_supported.h new file mode 100644 index 0000000000000000000000000000000000000000..789fe4e031eda5d096deeb300db654019d9b4400 --- /dev/null +++ b/common/cpp/src/asapo_fabric/fabric_factory_not_supported.h @@ -0,0 +1,17 @@ +#ifndef ASAPO_FABRIC_FACTORY_NOT_SUPPORTED_H +#define ASAPO_FABRIC_FACTORY_NOT_SUPPORTED_H + +#include <asapo_fabric/asapo_fabric.h> + +namespace asapo { +namespace fabric { +class 
FabricFactoryNotSupported : public FabricFactory { + std::unique_ptr<FabricServer> CreateAndBindServer( + const AbstractLogger* logger, const std::string& host, uint16_t port, Error* error) const override; + + std::unique_ptr<FabricClient> CreateClient(Error* error) const override; +}; +} +} + +#endif //ASAPO_FABRIC_FACTORY_NOT_SUPPORTED_H diff --git a/common/cpp/src/asapo_fabric/fabric_internal_error.cpp b/common/cpp/src/asapo_fabric/fabric_internal_error.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb8629e09f447a836f28496a09b2dfd8f8dfeb4b --- /dev/null +++ b/common/cpp/src/asapo_fabric/fabric_internal_error.cpp @@ -0,0 +1,14 @@ +#include "fabric_internal_error.h" +#include <rdma/fi_errno.h> +#include <asapo_fabric/fabric_error.h> + +asapo::Error asapo::fabric::ErrorFromFabricInternal(const std::string& where, int internalStatusCode) { + std::string errText = where + ": " + fi_strerror(-internalStatusCode); + switch (-internalStatusCode) { + case FI_ECANCELED: + return FabricErrorTemplates::kInternalOperationCanceledError.Generate(errText); + case FI_EIO: + return FabricErrorTemplates::kInternalConnectionError.Generate(errText); + } + return FabricErrorTemplates::kInternalError.Generate(errText); +} diff --git a/common/cpp/src/asapo_fabric/fabric_internal_error.h b/common/cpp/src/asapo_fabric/fabric_internal_error.h new file mode 100644 index 0000000000000000000000000000000000000000..f057872cd26ffd8f4ffa47b00334504004048c1c --- /dev/null +++ b/common/cpp/src/asapo_fabric/fabric_internal_error.h @@ -0,0 +1,18 @@ +#ifndef ASAPO_FABRICERRORCONVERTER_H +#define ASAPO_FABRICERRORCONVERTER_H + +#include <common/error.h> + +namespace asapo { +namespace fabric { + +/** + * internalStatusCode must be a negative number + * (Which all libfabric api calls usually return in an error case + */ +Error ErrorFromFabricInternal(const std::string& where, int internalStatusCode); + +} +} + +#endif //ASAPO_FABRICERRORCONVERTER_H diff --git a/common/cpp/src/asapo_fabric/server/fabric_server_impl.cpp b/common/cpp/src/asapo_fabric/server/fabric_server_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cabf224c25d4ca28ced74b32aac7562daa483a34 --- /dev/null +++ b/common/cpp/src/asapo_fabric/server/fabric_server_impl.cpp @@ -0,0 +1,61 @@ +#include "fabric_server_impl.h" +#include "task/fabric_recv_any_task.h" +#include <rdma/fi_tagged.h> + +using namespace asapo; +using namespace fabric; + +FabricServerImpl::~FabricServerImpl() { + accepting_task_.Stop(); +} + +FabricServerImpl::FabricServerImpl(const AbstractLogger* logger) + : log__{logger}, accepting_task_(this) { +} + +std::string FabricServerImpl::GetAddress() const { + return FabricContextImpl::GetAddress(); +} + +std::unique_ptr<FabricMemoryRegion> FabricServerImpl::ShareMemoryRegion(void* src, size_t size, Error* error) { + return FabricContextImpl::ShareMemoryRegion(src, size, error); +} + +void FabricServerImpl::Send(FabricAddress dstAddress, FabricMessageId messageId, const void* src, size_t size, + Error* error) { + FabricContextImpl::Send(dstAddress, messageId, src, size, error); +} + +void FabricServerImpl::Recv(FabricAddress srcAddress, FabricMessageId messageId, void* dst, size_t size, Error* error) { + FabricContextImpl::Recv(srcAddress, messageId, dst, size, error); +} + +void +FabricServerImpl::RdmaWrite(FabricAddress dstAddress, const MemoryRegionDetails* details, const void* buffer, + size_t size, + Error* error) { + FabricContextImpl::RdmaWrite(dstAddress, details, buffer, size, error); +} 
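// Aside (illustrative, not part of this patch): a minimal sketch of how the public API from asapo_fabric/asapo_fabric.h
// is meant to be wired together. The host, port, message id and buffer sizes are placeholders, the logger is assumed to
// come from the existing logger/logger.h facilities, and error handling is reduced to early returns.
#include <asapo_fabric/asapo_fabric.h>
#include <string>

using namespace asapo::fabric;

// Server side: bind, wait for any incoming message and echo it back under the same message id.
void RunEchoServer(const asapo::AbstractLogger* logger) {
    asapo::Error err;
    auto factory = GenerateDefaultFabricFactory();
    auto server = factory->CreateAndBindServer(logger, "127.0.0.1", 8400 /* placeholder */, &err);
    if (err) {
        return;
    }
    char buffer[1024];
    FabricAddress client;
    FabricMessageId id;
    server->RecvAny(&client, &id, buffer, sizeof(buffer), &err);
    if (!err) {
        server->Send(client, id, buffer, sizeof(buffer), &err);
    }
}

// Client side: the first AddServerAddress() initializes the domain and performs the handshake described above.
void RunEchoClient(const std::string& serverEndpoint /* "host:port" */) {
    asapo::Error err;
    auto factory = GenerateDefaultFabricFactory();
    auto client = factory->CreateClient(&err);
    if (err) {
        return;
    }
    FabricAddress server = client->AddServerAddress(serverEndpoint, &err);
    if (err) {
        return;
    }
    char request[1024] = "ping";
    char reply[1024];
    FabricMessageId id = 1;
    client->Send(server, id, request, sizeof(request), &err);
    if (!err) {
        client->Recv(server, id, reply, sizeof(reply), &err);
    }
}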
+ +void +FabricServerImpl::RecvAny(FabricAddress* srcAddress, FabricMessageId* messageId, void* dst, size_t size, Error* error) { + FabricRecvAnyTask anyTask; + HandleFiCommandAndWait(FI_ASAPO_ADDR_NO_ALIVE_CHECK, &anyTask, error, + fi_trecv, dst, size, nullptr, FI_ADDR_UNSPEC, 0, kRecvTaggedAnyMatch); + + if (!(*error)) { + if (anyTask.GetSource() == FI_ADDR_NOTAVAIL) { + *error = FabricErrorTemplates::kInternalError.Generate("Source address is unavailable"); + } + *messageId = anyTask.GetMessageId(); + *srcAddress = anyTask.GetSource(); + } +} + +void FabricServerImpl::InitAndStartServer(const std::string& host, uint16_t port, Error* error) { + InitCommon(host, port, error); + + if (!(*error)) { + accepting_task_.Start(); + } +} diff --git a/common/cpp/src/asapo_fabric/server/fabric_server_impl.h b/common/cpp/src/asapo_fabric/server/fabric_server_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..0e18da62ee2bffc7d3e5504600a93091347aa9b3 --- /dev/null +++ b/common/cpp/src/asapo_fabric/server/fabric_server_impl.h @@ -0,0 +1,44 @@ +#ifndef ASAPO_FABRIC_SERVER_IMPL_H +#define ASAPO_FABRIC_SERVER_IMPL_H + +#include <asapo_fabric/asapo_fabric.h> +#include "../common/fabric_context_impl.h" +#include "../fabric_factory_impl.h" +#include "task/fabric_handshake_accepting_task.h" + +namespace asapo { +namespace fabric { + +class FabricServerImpl : public FabricServer, public FabricContextImpl { + friend FabricFactoryImpl; + friend class FabricHandshakeAcceptingTask; + + private: + const AbstractLogger* log__; + FabricHandshakeAcceptingTask accepting_task_; + void InitAndStartServer(const std::string& host, uint16_t port, Error* error); + public: + ~FabricServerImpl() override; + explicit FabricServerImpl(const AbstractLogger* logger); + public: // Link to FabricContext + std::string GetAddress() const override; + + std::unique_ptr<FabricMemoryRegion> ShareMemoryRegion(void* src, size_t size, Error* error) override; + + void Send(FabricAddress dstAddress, FabricMessageId messageId, + const void* src, size_t size, Error* error) override; + + void Recv(FabricAddress srcAddress, FabricMessageId messageId, + void* dst, size_t size, Error* error) override; + + void RdmaWrite(FabricAddress dstAddress, + const MemoryRegionDetails* details, const void* buffer, size_t size, + Error* error) override; + public: + void RecvAny(FabricAddress* srcAddress, FabricMessageId* messageId, void* dst, size_t size, Error* error) override; +}; + +} +} + +#endif //ASAPO_FABRIC_SERVER_IMPL_H diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.cpp b/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.cpp new file mode 100644 index 0000000000000000000000000000000000000000..44ed14d06149e743b8c8723fdbb29362539ac364 --- /dev/null +++ b/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.cpp @@ -0,0 +1,66 @@ +#include <rdma/fi_endpoint.h> +#include "fabric_handshake_accepting_task.h" +#include "../fabric_server_impl.h" +#include "../../common/task/fabric_self_deleting_task.h" + +using namespace asapo; +using namespace fabric; + +FabricHandshakeAcceptingTask::FabricHandshakeAcceptingTask(FabricServerImpl* parentServerContext) + : FabricSelfRequeuingTask(parentServerContext) { +} + +FabricServerImpl* FabricHandshakeAcceptingTask::ServerContext() { + return dynamic_cast<FabricServerImpl*>(ParentContext()); +} + +void FabricHandshakeAcceptingTask::RequeueSelf() { + Error ignored; + ServerContext()->HandleRawFiCommand(this, &ignored, + 
fi_recv, &handshake_payload_, sizeof(handshake_payload_), nullptr, FI_ADDR_UNSPEC); +} + +void FabricHandshakeAcceptingTask::OnCompletion(const fi_cq_tagged_entry*, FabricAddress) { + Error error; + HandleAccept(&error); + if (error) { + OnError(&error); + return; + } +} + +void FabricHandshakeAcceptingTask::OnErrorCompletion(const fi_cq_err_entry* errEntry) { + Error error; + error = ErrorFromFabricInternal("FabricWaitableTask", -errEntry->err); + OnError(&error); +} + +void FabricHandshakeAcceptingTask::HandleAccept(Error* error) { + auto server = ServerContext(); + std::string hostname; + uint16_t port; + std::tie(hostname, port) = + *(server->io__->SplitAddressToHostnameAndPort(handshake_payload_.hostnameAndPort)); + FabricAddress tmpAddr; + int ret = fi_av_insertsvc( + server->address_vector_, + hostname.c_str(), + std::to_string(port).c_str(), + &tmpAddr, + 0, + nullptr); + if (ret != 1) { + *error = ErrorFromFabricInternal("fi_av_insertsvc", ret); + return; + } + server->log__->Debug("Got handshake from " + hostname + ":" + std::to_string(port)); + + // TODO: This could slow down the whole complete queue process, maybe use another thread? + // Send and forget + server->HandleRawFiCommand(new FabricSelfDeletingTask(), error, + fi_send, nullptr, 0, nullptr, tmpAddr); +} + +void FabricHandshakeAcceptingTask::OnError(const Error* error) { + ServerContext()->log__->Warning("AsapoFabric FabricHandshakeAcceptingTask: " + (*error)->Explain()); +} diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.h b/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.h new file mode 100644 index 0000000000000000000000000000000000000000..74ffd3742fdc8435ae927da7f953ffde8eec3832 --- /dev/null +++ b/common/cpp/src/asapo_fabric/server/task/fabric_handshake_accepting_task.h @@ -0,0 +1,38 @@ +#ifndef ASAPO_FABRIC_HANDSHAKE_ACCEPTING_TASK_H +#define ASAPO_FABRIC_HANDSHAKE_ACCEPTING_TASK_H + +#include "../../common/task/fabric_task.h" +#include "../../common/fabric_context_impl.h" + +namespace asapo { +namespace fabric { + +// Need forward declaration for reference inside the task +class FabricServerImpl; + +class FabricHandshakeAcceptingTask : public FabricSelfRequeuingTask { + private: + FabricHandshakePayload handshake_payload_{}; + + public: + explicit FabricHandshakeAcceptingTask(FabricServerImpl* server); + + private: + FabricServerImpl* ServerContext(); + + protected: // override FabricSelfRequeuingTask + void RequeueSelf() override; + + void OnCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) override; + + void OnErrorCompletion(const fi_cq_err_entry* errEntry) override; + + private: + void HandleAccept(Error* error); + void OnError(const Error* error); +}; + +} +} + +#endif //ASAPO_FABRIC_HANDSHAKE_ACCEPTING_TASK_H diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.cpp b/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e703a43e6b49de8cbe602251f0d58ab48da599a --- /dev/null +++ b/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.cpp @@ -0,0 +1,18 @@ +#include "fabric_recv_any_task.h" + +using namespace asapo; +using namespace fabric; + +void FabricRecvAnyTask::HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) { + messageId_ = entry->tag; + FabricWaitableTask::HandleCompletion(entry, source); +} + +void FabricRecvAnyTask::HandleErrorCompletion(const fi_cq_err_entry* errEntry) { + messageId_ = 
errEntry->tag; + FabricWaitableTask::HandleErrorCompletion(errEntry); +} + +FabricMessageId FabricRecvAnyTask::GetMessageId() const { + return messageId_; +} diff --git a/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.h b/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.h new file mode 100644 index 0000000000000000000000000000000000000000..06631082425891eb1d23a36ab9a45d42d0d7812f --- /dev/null +++ b/common/cpp/src/asapo_fabric/server/task/fabric_recv_any_task.h @@ -0,0 +1,25 @@ +#ifndef ASAPO_FABRIC_RECV_ANY_TASK_H +#define ASAPO_FABRIC_RECV_ANY_TASK_H + +#include <asapo_fabric/asapo_fabric.h> +#include <rdma/fi_eq.h> +#include "../../common/task/fabric_waitable_task.h" + +namespace asapo { +namespace fabric { + +class FabricRecvAnyTask : public FabricWaitableTask { + private: + FabricMessageId messageId_; + public: + void HandleCompletion(const fi_cq_tagged_entry* entry, FabricAddress source) override; + void HandleErrorCompletion(const fi_cq_err_entry* errEntry) override; + + FabricMessageId GetMessageId() const; +}; + +} +} + + +#endif //ASAPO_FABRIC_RECV_ANY_TASK_H diff --git a/common/cpp/src/system_io/system_io.cpp b/common/cpp/src/system_io/system_io.cpp index 3da71da9de4cc1a2ff66e09e8311955a27eadaf9..29c2130170aa3a17b521caabd88dc1dde4969fee 100644 --- a/common/cpp/src/system_io/system_io.cpp +++ b/common/cpp/src/system_io/system_io.cpp @@ -59,7 +59,8 @@ void AssignIDs(FileInfos* file_list) { } } -std::unique_ptr<std::tuple<std::string, uint16_t>> SystemIO::SplitAddressToHostnameAndPort(std::string address) const { +std::unique_ptr<std::tuple<std::string, uint16_t>> SystemIO::SplitAddressToHostnameAndPort( +const std::string& address) const { try { std::string host = address.substr(0, address.find(':')); diff --git a/common/cpp/src/system_io/system_io.h b/common/cpp/src/system_io/system_io.h index a97bb1152d465b1d40dcc02b1e4d73ee32d327c3..accbbeee0242cb13fc94df6ec7e0648e0de74d09 100644 --- a/common/cpp/src/system_io/system_io.h +++ b/common/cpp/src/system_io/system_io.h @@ -64,8 +64,6 @@ class SystemIO final : public IO { SocketDescriptor _accept(SocketDescriptor socket_fd, void* address, size_t* address_length) const; bool _close_socket(SocketDescriptor socket_fd) const; - std::unique_ptr<std::tuple<std::string, uint16_t>> SplitAddressToHostnameAndPort(std::string address) const; - std::unique_ptr<sockaddr_in> BuildSockaddrIn(const std::string& address, Error* err) const; /* @@ -130,6 +128,8 @@ class SystemIO final : public IO { void Skip(SocketDescriptor socket_fd, size_t length, Error* err) const override; void CloseSocket(SocketDescriptor socket_fd, Error* err) const override; std::string GetHostName(Error* err) const noexcept override; + std::unique_ptr<std::tuple<std::string, uint16_t>> SplitAddressToHostnameAndPort(const std::string& address) const + override; /* * Filesystem diff --git a/receiver/src/receiver_data_server/net_server.h b/receiver/src/receiver_data_server/net_server.h deleted file mode 100644 index ffdaaaf003a058127e41e03e69523d539f7761d8..0000000000000000000000000000000000000000 --- a/receiver/src/receiver_data_server/net_server.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef ASAPO_NET_SERVER_H -#define ASAPO_NET_SERVER_H - -#include "common/error.h" - -#include "request/request.h" - -namespace asapo { - -class NetServer { - public: - virtual GenericRequests GetNewRequests(Error* err) const noexcept = 0; - virtual Error SendData(uint64_t source_id, void* buf, uint64_t size) const noexcept = 0; - virtual void HandleAfterError(uint64_t 
source_id) const noexcept = 0; - virtual ~NetServer() = default; -}; - -} - -#endif //ASAPO_NET_SERVER_H diff --git a/receiver/src/receiver_data_server/rds_net_server.h b/receiver/src/receiver_data_server/rds_net_server.h new file mode 100644 index 0000000000000000000000000000000000000000..ee49cc4af4bedd97ef9d1416ed39143cb55859de --- /dev/null +++ b/receiver/src/receiver_data_server/rds_net_server.h @@ -0,0 +1,28 @@ +#ifndef ASAPO_RDS_NET_SERVER_H +#define ASAPO_RDS_NET_SERVER_H + +#include "../data_cache.h" +#include "common/error.h" +#include "receiver_data_server_request.h" + +namespace asapo { + +class RdsNetServer { + public: + /** + * It is very important that this function is successfully called before any other call is made! + */ + virtual Error Initialize() = 0; + virtual GenericRequests GetNewRequests(Error* err) = 0; + virtual Error SendResponse(const ReceiverDataServerRequest* request, + const GenericNetworkResponse* response) = 0; + virtual Error + SendResponseAndSlotData(const ReceiverDataServerRequest* request, const GenericNetworkResponse* response, + const CacheMeta* cache_slot) = 0; + virtual void HandleAfterError(uint64_t source_id) = 0; + virtual ~RdsNetServer() = default; +}; + +} + +#endif //ASAPO_RDS_NET_SERVER_H diff --git a/receiver/src/receiver_data_server/receiver_data_server.cpp b/receiver/src/receiver_data_server/receiver_data_server.cpp index eca610743ea0eaf3efc71e54b180095e1cc32b6a..81566167b717499b7a3d76c4bc79db6ae437f800 100644 --- a/receiver/src/receiver_data_server/receiver_data_server.cpp +++ b/receiver/src/receiver_data_server/receiver_data_server.cpp @@ -17,6 +17,13 @@ config_{config}, statistics__{new Statistics()} { } void ReceiverDataServer::Run() { + { + Error startError = net__->Initialize(); + if (startError) { + log__->Error(std::string("Error starting rds net server: ") + startError->Explain()); + return; + } + } while (true) { statistics__->SendIfNeeded(); Error err; @@ -34,4 +41,4 @@ void ReceiverDataServer::Run() { } } -} \ No newline at end of file +} diff --git a/receiver/src/receiver_data_server/receiver_data_server.h b/receiver/src/receiver_data_server/receiver_data_server.h index c7edda3f3ab16f1c0376f489d943bd96574ea252..889ce357c5fa43ab19a3c5874e3bc3c8a81a8eae 100644 --- a/receiver/src/receiver_data_server/receiver_data_server.h +++ b/receiver/src/receiver_data_server/receiver_data_server.h @@ -3,7 +3,7 @@ #include <memory> -#include "net_server.h" +#include "rds_net_server.h" #include "request/request_pool.h" #include "logger/logger.h" #include "../data_cache.h" @@ -21,7 +21,7 @@ class ReceiverDataServer { explicit ReceiverDataServer(std::string address, LogLevel log_level, SharedCache data_cache, const ReceiverDataCenterConfig& config); std::unique_ptr<RequestPool> request_pool__; - std::unique_ptr<NetServer> net__; + std::unique_ptr<RdsNetServer> net__; const AbstractLogger* log__; void Run(); private: diff --git a/receiver/src/receiver_data_server/receiver_data_server_request.h b/receiver/src/receiver_data_server/receiver_data_server_request.h index 0a6d24f63787594009b3b9fefa4ebdf452577780..541e08dddf45cf259861699a97e6f941bf5d9b6b 100644 --- a/receiver/src/receiver_data_server/receiver_data_server_request.h +++ b/receiver/src/receiver_data_server/receiver_data_server_request.h @@ -7,7 +7,7 @@ namespace asapo { -class NetServer; +class RdsNetServer; class ReceiverDataServerRequest : public GenericRequest { public: diff --git a/receiver/src/receiver_data_server/receiver_data_server_request_handler.cpp
b/receiver/src/receiver_data_server/receiver_data_server_request_handler.cpp index 03d8ae764549afee64c6e3e5f4a555bcf2f67ac6..bc45e229eea4b6f21a9963488f7c2447cc91cf40 100644 --- a/receiver/src/receiver_data_server/receiver_data_server_request_handler.cpp +++ b/receiver/src/receiver_data_server/receiver_data_server_request_handler.cpp @@ -4,7 +4,7 @@ namespace asapo { -ReceiverDataServerRequestHandler::ReceiverDataServerRequestHandler(const NetServer* server, +ReceiverDataServerRequestHandler::ReceiverDataServerRequestHandler(RdsNetServer* server, DataCache* data_cache, Statistics* statistics): log__{GetDefaultReceiverDataServerLogger()}, statistics__{statistics}, server_{server}, data_cache_{data_cache} { @@ -15,59 +15,53 @@ bool ReceiverDataServerRequestHandler::CheckRequest(const ReceiverDataServerRequ return request->header.op_code == kOpcodeGetBufferData; } -Error ReceiverDataServerRequestHandler::SendData(const ReceiverDataServerRequest* request, - void* data, - CacheMeta* meta) { - auto err = SendResponce(request, kNetErrorNoError); - if (err) { - data_cache_->UnlockSlot(meta); - return err; - } - err = server_->SendData(request->source_id, data, request->header.data_size); - log__->Debug("sending data from memory cache, id:" + std::to_string(request->header.data_id)); - data_cache_->UnlockSlot(meta); - return err; +Error ReceiverDataServerRequestHandler::SendResponse(const ReceiverDataServerRequest* request, NetworkErrorCode code) { + GenericNetworkResponse response{}; + response.op_code = kOpcodeGetBufferData; + response.error_code = code; + return server_->SendResponse(request, &response); +} + +Error ReceiverDataServerRequestHandler::SendResponseAndSlotData(const ReceiverDataServerRequest* request, + const CacheMeta* meta) { + GenericNetworkResponse response{}; + response.op_code = kOpcodeGetBufferData; + response.error_code = kNetErrorNoError; + return server_->SendResponseAndSlotData(request, &response, + meta); } -void* ReceiverDataServerRequestHandler::GetSlot(const ReceiverDataServerRequest* request, CacheMeta** meta) { - void* buf = nullptr; +CacheMeta* ReceiverDataServerRequestHandler::GetSlotAndLock(const ReceiverDataServerRequest* request) { + CacheMeta* meta = nullptr; if (data_cache_) { - buf = data_cache_->GetSlotToReadAndLock(request->header.data_id, request->header.data_size, - meta); - if (!buf) { + data_cache_->GetSlotToReadAndLock(request->header.data_id, request->header.data_size, &meta); + if (!meta) { log__->Debug("data not found in memory cache, id:" + std::to_string(request->header.data_id)); } - - } - if (buf == nullptr) { - SendResponce(request, kNetErrorNoData); } - return buf; + return meta; } - bool ReceiverDataServerRequestHandler::ProcessRequestUnlocked(GenericRequest* request, bool* retry) { *retry = false; auto receiver_request = dynamic_cast<ReceiverDataServerRequest*>(request); if (!CheckRequest(receiver_request)) { - SendResponce(receiver_request, kNetErrorWrongRequest); - server_->HandleAfterError(receiver_request->source_id); - log__->Error("wrong request, code:" + std::to_string(receiver_request->header.op_code)); + HandleInvalidRequest(receiver_request); return true; } - CacheMeta* meta; - auto buf = GetSlot(receiver_request, &meta); - if (buf == nullptr) { + CacheMeta* meta = GetSlotAndLock(receiver_request); + if (!meta) { + SendResponse(receiver_request, kNetErrorNoData); return true; } - SendData(receiver_request, buf, meta); - statistics__->IncreaseRequestCounter(); - 
statistics__->IncreaseRequestDataVolume(receiver_request->header.data_size); + HandleValidRequest(receiver_request, meta); + data_cache_->UnlockSlot(meta); return true; } + bool ReceiverDataServerRequestHandler::ReadyProcessRequest() { return true; // always ready } @@ -76,19 +70,30 @@ void ReceiverDataServerRequestHandler::PrepareProcessingRequestLocked() { // do nothing } -void ReceiverDataServerRequestHandler::TearDownProcessingRequestLocked(bool processing_succeeded) { +void ReceiverDataServerRequestHandler::TearDownProcessingRequestLocked(bool /*processing_succeeded*/) { // do nothing } -Error ReceiverDataServerRequestHandler::SendResponce(const ReceiverDataServerRequest* request, NetworkErrorCode code) { - GenericNetworkResponse responce; - responce.op_code = kOpcodeGetBufferData; - responce.error_code = code; - return server_->SendData(request->source_id, &responce, sizeof(GenericNetworkResponse)); +void ReceiverDataServerRequestHandler::ProcessRequestTimeout(GenericRequest* /*request*/) { +// do nothing } -void ReceiverDataServerRequestHandler::ProcessRequestTimeout(GenericRequest* request) { -// do nothing +void ReceiverDataServerRequestHandler::HandleInvalidRequest(const ReceiverDataServerRequest* receiver_request) { + SendResponse(receiver_request, kNetErrorWrongRequest); + server_->HandleAfterError(receiver_request->source_id); + log__->Error("wrong request, code:" + std::to_string(receiver_request->header.op_code)); } -} \ No newline at end of file +void ReceiverDataServerRequestHandler::HandleValidRequest(const ReceiverDataServerRequest* receiver_request, + const CacheMeta* meta) { + auto err = SendResponseAndSlotData(receiver_request, meta); + if (err) { + log__->Error("failed to send slot:" + err->Explain()); + server_->HandleAfterError(receiver_request->source_id); + } else { + statistics__->IncreaseRequestCounter(); + statistics__->IncreaseRequestDataVolume(receiver_request->header.data_size); + } +} + +} diff --git a/receiver/src/receiver_data_server/receiver_data_server_request_handler.h b/receiver/src/receiver_data_server/receiver_data_server_request_handler.h index 4bae2ee45e590dc1d3e699d72e90675c04cc2b9a..44abf4ea54f772b50ffac3fff13757b5a9977481 100644 --- a/receiver/src/receiver_data_server/receiver_data_server_request_handler.h +++ b/receiver/src/receiver_data_server/receiver_data_server_request_handler.h @@ -2,7 +2,7 @@ #define ASAPO_RECEIVER_DATA_SERVER_REQUEST_HANDLER_H #include "request/request_handler.h" -#include "net_server.h" +#include "rds_net_server.h" #include "../data_cache.h" #include "receiver_data_server_request.h" #include "receiver_data_server_logger.h" @@ -12,7 +12,7 @@ namespace asapo { class ReceiverDataServerRequestHandler: public RequestHandler { public: - explicit ReceiverDataServerRequestHandler(const NetServer* server, DataCache* data_cache, Statistics* statistics); + explicit ReceiverDataServerRequestHandler(RdsNetServer* server, DataCache* data_cache, Statistics* statistics); bool ProcessRequestUnlocked(GenericRequest* request, bool* retry) override; bool ReadyProcessRequest() override; void PrepareProcessingRequestLocked() override; @@ -22,12 +22,16 @@ class ReceiverDataServerRequestHandler: public RequestHandler { const AbstractLogger* log__; Statistics* statistics__; private: - const NetServer* server_; + RdsNetServer* server_; DataCache* data_cache_; bool CheckRequest(const ReceiverDataServerRequest* request); - Error SendResponce(const ReceiverDataServerRequest* request, NetworkErrorCode code); - Error SendData(const 
ReceiverDataServerRequest* request, void* data, CacheMeta* meta); - void* GetSlot(const ReceiverDataServerRequest* request, CacheMeta** meta); + Error SendResponse(const ReceiverDataServerRequest* request, NetworkErrorCode code); + Error SendResponseAndSlotData(const ReceiverDataServerRequest* request, const CacheMeta* meta); + CacheMeta* GetSlotAndLock(const ReceiverDataServerRequest* request); + + void HandleInvalidRequest(const ReceiverDataServerRequest* receiver_request); + + void HandleValidRequest(const ReceiverDataServerRequest* receiver_request, const CacheMeta* meta); }; } diff --git a/receiver/src/receiver_data_server/receiver_data_server_request_handler_factory.cpp b/receiver/src/receiver_data_server/receiver_data_server_request_handler_factory.cpp index d3e259ff1ffa0940a3887eb6b1e2159fd69a33e2..d08f13a8c05089475e887f7c9ca3038ca848e28c 100644 --- a/receiver/src/receiver_data_server/receiver_data_server_request_handler_factory.cpp +++ b/receiver/src/receiver_data_server/receiver_data_server_request_handler_factory.cpp @@ -8,9 +8,9 @@ std::unique_ptr<RequestHandler> ReceiverDataServerRequestHandlerFactory::NewRequ uint64_t* shared_counter) { return std::unique_ptr<RequestHandler> {new ReceiverDataServerRequestHandler(server_, data_cache_, statistics_)}; } -ReceiverDataServerRequestHandlerFactory::ReceiverDataServerRequestHandlerFactory(const NetServer* server, +ReceiverDataServerRequestHandlerFactory::ReceiverDataServerRequestHandlerFactory(RdsNetServer* server, DataCache* data_cache, Statistics* statistics) : server_{server}, data_cache_{data_cache}, statistics_{statistics} { } -} \ No newline at end of file +} diff --git a/receiver/src/receiver_data_server/receiver_data_server_request_handler_factory.h b/receiver/src/receiver_data_server/receiver_data_server_request_handler_factory.h index b84f80fc8b562041b78cd97f8fee5897fc8d2807..29b26dbcefafb27f74786d3371a3c6e6a60b7ca1 100644 --- a/receiver/src/receiver_data_server/receiver_data_server_request_handler_factory.h +++ b/receiver/src/receiver_data_server/receiver_data_server_request_handler_factory.h @@ -5,7 +5,7 @@ #include "request/request_handler.h" #include "preprocessor/definitions.h" -#include "net_server.h" +#include "rds_net_server.h" #include "../data_cache.h" #include "../statistics.h" @@ -13,10 +13,10 @@ namespace asapo { class ReceiverDataServerRequestHandlerFactory : public RequestHandlerFactory { public: - ReceiverDataServerRequestHandlerFactory (const NetServer* server, DataCache* data_cache, Statistics* statistics); + ReceiverDataServerRequestHandlerFactory(RdsNetServer* server, DataCache* data_cache, Statistics* statistics); VIRTUAL std::unique_ptr<RequestHandler> NewRequestHandler(uint64_t thread_id, uint64_t* shared_counter) override; private: - const NetServer* server_; + RdsNetServer* server_; DataCache* data_cache_; Statistics* statistics_; }; diff --git a/receiver/src/receiver_data_server/tcp_server.cpp b/receiver/src/receiver_data_server/tcp_server.cpp index 4a924745f9cab11081790b2b52b2fcb6b6e88776..bd8bb807c387151e60e5c25eabdbe70edbe680fa 100644 --- a/receiver/src/receiver_data_server/tcp_server.cpp +++ b/receiver/src/receiver_data_server/tcp_server.cpp @@ -9,7 +9,7 @@ namespace asapo { TcpServer::TcpServer(std::string address) : io__{GenerateDefaultIO()}, log__{GetDefaultReceiverDataServerLogger()}, address_{std::move(address)} {} -Error TcpServer::InitializeMasterSocketIfNeeded() const noexcept { +Error TcpServer::Initialize() { Error err; if (master_socket_ == kDisconnectedSocketDescriptor) { 
master_socket_ = io__->CreateAndBindIPTCPSocketListener(address_, kMaxPendingConnections, &err); @@ -18,11 +18,13 @@ Error TcpServer::InitializeMasterSocketIfNeeded() const noexcept { } else { log__->Error("dataserver cannot listen on " + address_ + ": " + err->Explain()); } + } else { + err = TextError("Server was already initialized"); } return err; } -ListSocketDescriptors TcpServer::GetActiveSockets(Error* err) const noexcept { +ListSocketDescriptors TcpServer::GetActiveSockets(Error* err) { std::vector<std::string> new_connections; auto sockets = io__->WaitSocketsActivity(master_socket_, &sockets_to_listen_, &new_connections, err); for (auto& connection : new_connections) { @@ -31,14 +33,14 @@ ListSocketDescriptors TcpServer::GetActiveSockets(Error* err) const noexcept { return sockets; } -void TcpServer::CloseSocket(SocketDescriptor socket) const noexcept { +void TcpServer::CloseSocket(SocketDescriptor socket) { sockets_to_listen_.erase(std::remove(sockets_to_listen_.begin(), sockets_to_listen_.end(), socket), sockets_to_listen_.end()); log__->Debug("connection " + io__->AddressFromSocket(socket) + " closed"); io__->CloseSocket(socket, nullptr); } -ReceiverDataServerRequestPtr TcpServer::ReadRequest(SocketDescriptor socket, Error* err) const noexcept { +ReceiverDataServerRequestPtr TcpServer::ReadRequest(SocketDescriptor socket, Error* err) { GenericRequestHeader header; io__->Receive(socket, &header, sizeof(GenericRequestHeader), err); @@ -51,10 +53,10 @@ ReceiverDataServerRequestPtr TcpServer::ReadRequest(SocketDescriptor socket, Err ); return nullptr; } - return ReceiverDataServerRequestPtr{new ReceiverDataServerRequest{std::move(header), (uint64_t) socket}}; + return ReceiverDataServerRequestPtr{new ReceiverDataServerRequest{header, (uint64_t) socket}}; } -GenericRequests TcpServer::ReadRequests(const ListSocketDescriptors& sockets) const noexcept { +GenericRequests TcpServer::ReadRequests(const ListSocketDescriptors& sockets) { GenericRequests requests; for (auto client : sockets) { Error err; @@ -69,11 +71,7 @@ GenericRequests TcpServer::ReadRequests(const ListSocketDescriptors& sockets) co return requests; } -GenericRequests TcpServer::GetNewRequests(Error* err) const noexcept { - if ( (*err = InitializeMasterSocketIfNeeded()) ) { - return {}; - } - +GenericRequests TcpServer::GetNewRequests(Error* err) { auto sockets = GetActiveSockets(err); if (*err) { return {}; @@ -90,18 +88,34 @@ TcpServer::~TcpServer() { io__->CloseSocket(master_socket_, nullptr); } +void TcpServer::HandleAfterError(uint64_t source_id) { + CloseSocket(source_id); +} -Error TcpServer::SendData(uint64_t source_id, void* buf, uint64_t size) const noexcept { +Error TcpServer::SendResponse(const ReceiverDataServerRequest* request, const GenericNetworkResponse* response) { Error err; - io__->Send(source_id, buf, size, &err); + io__->Send(request->source_id, response, sizeof(*response), &err); if (err) { log__->Error("cannot send to consumer" + err->Explain()); } return err; } -void TcpServer::HandleAfterError(uint64_t source_id) const noexcept { - CloseSocket(source_id); +Error +TcpServer::SendResponseAndSlotData(const ReceiverDataServerRequest* request, const GenericNetworkResponse* response, + const CacheMeta* cache_slot) { + Error err; + + err = SendResponse(request, response); + if (err) { + return err; + } + + io__->Send(request->source_id, cache_slot->addr, cache_slot->size, &err); + if (err) { + log__->Error("cannot send slot to worker" + err->Explain()); + } + return err; } -} \ No newline at end of 
file +} diff --git a/receiver/src/receiver_data_server/tcp_server.h b/receiver/src/receiver_data_server/tcp_server.h index af4f9c6579a85e28ff8aae946f2c0c37a503147f..bf27acafc1d50ac9219cad6bd66f96c3a314352a 100644 --- a/receiver/src/receiver_data_server/tcp_server.h +++ b/receiver/src/receiver_data_server/tcp_server.h @@ -1,7 +1,7 @@ -#ifndef ASAPO_TCP_SERVER_H -#define ASAPO_TCP_SERVER_H +#ifndef ASAPO_RDS_TCP_SERVER_H +#define ASAPO_RDS_TCP_SERVER_H -#include "net_server.h" +#include "rds_net_server.h" #include "io/io.h" #include "logger/logger.h" #include "receiver_data_server_request.h" @@ -9,26 +9,31 @@ namespace asapo { const int kMaxPendingConnections = 5; -class TcpServer : public NetServer { +class TcpServer : public RdsNetServer { public: - TcpServer(std::string address); - ~TcpServer(); - GenericRequests GetNewRequests(Error* err) const noexcept override ; - Error SendData(uint64_t source_id, void* buf, uint64_t size) const noexcept override; - void HandleAfterError(uint64_t source_id) const noexcept override; + explicit TcpServer(std::string address); + ~TcpServer() override; + + Error Initialize() override; + + GenericRequests GetNewRequests(Error* err) override; + Error SendResponse(const ReceiverDataServerRequest* request, + const GenericNetworkResponse* response) override; + Error SendResponseAndSlotData(const ReceiverDataServerRequest* request, const GenericNetworkResponse* response, + const CacheMeta* cache_slot) override; + void HandleAfterError(uint64_t source_id) override; std::unique_ptr<IO> io__; const AbstractLogger* log__; private: - void CloseSocket(SocketDescriptor socket) const noexcept; - ListSocketDescriptors GetActiveSockets(Error* err) const noexcept; - Error InitializeMasterSocketIfNeeded() const noexcept; - ReceiverDataServerRequestPtr ReadRequest(SocketDescriptor socket, Error* err) const noexcept; - GenericRequests ReadRequests(const ListSocketDescriptors& sockets) const noexcept; - mutable SocketDescriptor master_socket_{kDisconnectedSocketDescriptor}; - mutable ListSocketDescriptors sockets_to_listen_; + void CloseSocket(SocketDescriptor socket); + ListSocketDescriptors GetActiveSockets(Error* err); + ReceiverDataServerRequestPtr ReadRequest(SocketDescriptor socket, Error* err) ; + GenericRequests ReadRequests(const ListSocketDescriptors& sockets); + SocketDescriptor master_socket_{kDisconnectedSocketDescriptor}; + ListSocketDescriptors sockets_to_listen_; std::string address_; }; } -#endif //ASAPO_TCP_SERVER_H +#endif //ASAPO_RDS_TCP_SERVER_H diff --git a/receiver/unittests/receiver_data_server/receiver_dataserver_mocking.h b/receiver/unittests/receiver_data_server/receiver_dataserver_mocking.h index 6eacdb6c706bd8c47e13b757b2704930fee179ed..5db8ab20591ab3b28368aae72e317d326b2154e3 100644 --- a/receiver/unittests/receiver_data_server/receiver_dataserver_mocking.h +++ b/receiver/unittests/receiver_data_server/receiver_dataserver_mocking.h @@ -4,15 +4,20 @@ #include <gtest/gtest.h> #include <gmock/gmock.h> -#include "../../src/receiver_data_server/net_server.h" +#include "../../src/receiver_data_server/rds_net_server.h" #include "request/request_pool.h" #include "../../src/receiver_data_server/receiver_data_server_request.h" namespace asapo { -class MockNetServer : public NetServer { +class MockNetServer : public RdsNetServer { public: - GenericRequests GetNewRequests(Error* err) const noexcept override { + Error Initialize() override { + return Error{Initialize_t()}; + }; + MOCK_METHOD0(Initialize_t, ErrorInterface * ()); + + GenericRequests 
GetNewRequests(Error* err) override { ErrorInterface* error = nullptr; auto reqs = GetNewRequests_t(&error); err->reset(error); @@ -24,21 +29,29 @@ class MockNetServer : public NetServer { return res; } - MOCK_CONST_METHOD1(GetNewRequests_t, std::vector<ReceiverDataServerRequest> (ErrorInterface** - error)); - - Error SendData(uint64_t source_id, void* buf, uint64_t size) const noexcept override { - return Error{SendData_t(source_id, buf, size)}; + MOCK_METHOD1(GetNewRequests_t, std::vector<ReceiverDataServerRequest> (ErrorInterface** + error)); + Error SendResponse(const ReceiverDataServerRequest* request, + const GenericNetworkResponse* response) override { + return Error{SendResponse_t(request, response)}; }; + MOCK_METHOD2(SendResponse_t, ErrorInterface * (const ReceiverDataServerRequest* request, + const GenericNetworkResponse* response)); - MOCK_CONST_METHOD3(SendData_t, ErrorInterface * (uint64_t source_id, void* buf, uint64_t size)); + Error SendResponseAndSlotData(const ReceiverDataServerRequest* request, const GenericNetworkResponse* response, + const CacheMeta* cache_slot) override { + return Error{SendResponseAndSlotData_t(request, response, cache_slot)}; + }; + MOCK_METHOD3(SendResponseAndSlotData_t, ErrorInterface * (const ReceiverDataServerRequest* request, + const GenericNetworkResponse* response, + const CacheMeta* cache_slot)); - void HandleAfterError(uint64_t source_id) const noexcept override { + void HandleAfterError(uint64_t source_id) override { HandleAfterError_t(source_id); } - MOCK_CONST_METHOD1(HandleAfterError_t, void (uint64_t source_id)); + MOCK_METHOD1(HandleAfterError_t, void (uint64_t source_id)); }; class MockPool : public RequestPool { @@ -47,7 +60,7 @@ class MockPool : public RequestPool { Error AddRequests(GenericRequests requests) noexcept override { std::vector<GenericRequest> reqs; for (const auto& preq : requests) { - reqs.push_back(GenericRequest{preq->header, 0}); + reqs.emplace_back(preq->header, 0); } return Error(AddRequests_t(std::move(reqs))); diff --git a/receiver/unittests/receiver_data_server/test_receiver_data_server.cpp b/receiver/unittests/receiver_data_server/test_receiver_data_server.cpp index 50017308ce4856b558270470ff3a65d9a1431a5d..17f798369b240d4bff4a2fa0c39f6b758b71eabb 100644 --- a/receiver/unittests/receiver_data_server/test_receiver_data_server.cpp +++ b/receiver/unittests/receiver_data_server/test_receiver_data_server.cpp @@ -59,7 +59,7 @@ class ReceiverDataServerTests : public Test { NiceMock<asapo::MockLogger> mock_logger; NiceMock<asapo::MockStatistics> mock_statistics; void SetUp() override { - data_server.net__ = std::unique_ptr<asapo::NetServer> {&mock_net}; + data_server.net__ = std::unique_ptr<asapo::RdsNetServer> {&mock_net}; data_server.request_pool__ = std::unique_ptr<asapo::RequestPool> {&mock_pool}; data_server.log__ = &mock_logger; data_server.statistics__ = std::unique_ptr<asapo::Statistics> {&mock_statistics};; diff --git a/receiver/unittests/receiver_data_server/test_request_handler.cpp b/receiver/unittests/receiver_data_server/test_request_handler.cpp index fa5b8d8dd1dd9e38d42eb20eb39d982e68b75564..804999e1b79723968d549e57d2ddc3abafa01c9b 100644 --- a/receiver/unittests/receiver_data_server/test_request_handler.cpp +++ b/receiver/unittests/receiver_data_server/test_request_handler.cpp @@ -23,15 +23,14 @@ using ::testing::_; using ::testing::SetArgPointee; using ::testing::NiceMock; using ::testing::HasSubstr; - +using ::testing::DoAll; using asapo::ReceiverDataServer; using 
asapo::ReceiverDataServerRequestHandler; - namespace { -MATCHER_P3(M_CheckResponce, op_code, error_code, message, +MATCHER_P3(M_CheckResponse, op_code, error_code, message, "Checks if a valid GenericNetworkResponse was used") { return ((asapo::GenericNetworkResponse*)arg)->op_code == op_code && ((asapo::GenericNetworkResponse*)arg)->error_code == uint64_t(error_code); @@ -57,6 +56,7 @@ class RequestHandlerTests : public Test { uint64_t expected_meta_size = 100; uint64_t expected_buf_id = 12345; uint64_t expected_source_id = 11; + asapo::CacheMeta expected_meta; bool retry; asapo::GenericRequestHeader header{asapo::kOpcodeGetBufferData, expected_buf_id, expected_data_size, expected_meta_size, ""}; @@ -67,22 +67,41 @@ class RequestHandlerTests : public Test { } void TearDown() override { } - void MockGetSlot(bool ok = true); - void MockSendResponce(asapo::NetworkErrorCode err_code, bool ok = true); + void MockGetSlotAndUnlockIt(bool return_without_error = true); + void MockSendResponse(asapo::NetworkErrorCode expected_response_code, bool return_without_error); + void MockSendResponseAndSlotData(asapo::NetworkErrorCode expected_response_code, bool return_without_error); + }; -void RequestHandlerTests::MockGetSlot(bool ok) { - EXPECT_CALL(mock_cache, GetSlotToReadAndLock(expected_buf_id, expected_data_size, _)).WillOnce( - Return(ok ? &tmp : nullptr) - ); +void RequestHandlerTests::MockGetSlotAndUnlockIt(bool return_without_error) { + EXPECT_CALL(mock_cache, GetSlotToReadAndLock(expected_buf_id, expected_data_size, _)).WillOnce(DoAll( + SetArgPointee<2>(return_without_error ? &expected_meta : nullptr), + Return(return_without_error ? &tmp : nullptr) + )); + if (return_without_error) { + EXPECT_CALL(mock_cache, UnlockSlot(_)); + } } -void RequestHandlerTests::MockSendResponce(asapo::NetworkErrorCode err_code, bool ok) { - EXPECT_CALL(mock_net, SendData_t(expected_source_id, - M_CheckResponce(asapo::kOpcodeGetBufferData, err_code, ""), sizeof(asapo::GenericNetworkResponse))).WillOnce( - Return(ok ? nullptr : asapo::IOErrorTemplates::kUnknownIOError.Generate().release()) - ); +void RequestHandlerTests::MockSendResponse(asapo::NetworkErrorCode expected_response_code, bool return_without_error) { + EXPECT_CALL(mock_net, SendResponse_t( + &request, + M_CheckResponse(asapo::kOpcodeGetBufferData, expected_response_code, "") + )).WillOnce( + Return(return_without_error ? nullptr : asapo::IOErrorTemplates::kUnknownIOError.Generate().release()) + ); +} + +void RequestHandlerTests::MockSendResponseAndSlotData(asapo::NetworkErrorCode expected_response_code, + bool return_without_error) { + EXPECT_CALL(mock_net, SendResponseAndSlotData_t( + &request, + M_CheckResponse(asapo::kOpcodeGetBufferData, expected_response_code, ""), + &expected_meta + )).WillOnce( + Return(return_without_error ? 
nullptr : asapo::IOErrorTemplates::kUnknownIOError.Generate().release()) + ); } TEST_F(RequestHandlerTests, RequestAlwaysReady) { @@ -91,9 +110,9 @@ TEST_F(RequestHandlerTests, RequestAlwaysReady) { ASSERT_THAT(res, Eq(true)); } -TEST_F(RequestHandlerTests, ProcessRequest_WronOpCode) { +TEST_F(RequestHandlerTests, ProcessRequest_WrongOpCode) { request.header.op_code = asapo::kOpcodeUnknownOp; - MockSendResponce(asapo::kNetErrorWrongRequest, false); + MockSendResponse(asapo::kNetErrorWrongRequest, false); EXPECT_CALL(mock_net, HandleAfterError_t(expected_source_id)); EXPECT_CALL(mock_logger, Error(HasSubstr("wrong request"))); @@ -103,8 +122,8 @@ TEST_F(RequestHandlerTests, ProcessRequest_WronOpCode) { ASSERT_THAT(success, Eq(true)); } -TEST_F(RequestHandlerTests, ProcessRequestReturnsNoDataWhenCacheNotUsed) { - MockSendResponce(asapo::kNetErrorNoData, true); +TEST_F(RequestHandlerTests, ProcessRequest_ReturnsNoDataWhenCacheNotUsed) { + MockSendResponse(asapo::kNetErrorNoData, true); auto success = handler_no_cache.ProcessRequestUnlocked(&request, &retry); EXPECT_CALL(mock_logger, Debug(_)).Times(0); @@ -112,9 +131,9 @@ TEST_F(RequestHandlerTests, ProcessRequestReturnsNoDataWhenCacheNotUsed) { ASSERT_THAT(success, Eq(true)); } -TEST_F(RequestHandlerTests, ProcessRequestReadSlotReturnsNull) { - MockGetSlot(false); - MockSendResponce(asapo::kNetErrorNoData, true); +TEST_F(RequestHandlerTests, ProcessRequest_ReadSlotReturnsNull) { + MockGetSlotAndUnlockIt(false); + MockSendResponse(asapo::kNetErrorNoData, true); EXPECT_CALL(mock_logger, Debug(HasSubstr("not found"))); auto success = handler.ProcessRequestUnlocked(&request, &retry); @@ -122,28 +141,19 @@ TEST_F(RequestHandlerTests, ProcessRequestReadSlotReturnsNull) { ASSERT_THAT(success, Eq(true)); } - -TEST_F(RequestHandlerTests, ProcessRequestReadSlotErrorSendingResponce) { - MockGetSlot(true); - MockSendResponce(asapo::kNetErrorNoError, false); - EXPECT_CALL(mock_net, SendData_t(expected_source_id, &tmp, expected_data_size)).Times(0); - EXPECT_CALL(mock_cache, UnlockSlot(_)); +TEST_F(RequestHandlerTests, ProcessRequest_ReadSlotErrorSendingResponse) { + MockGetSlotAndUnlockIt(true); + MockSendResponseAndSlotData(asapo::kNetErrorNoError, false); + EXPECT_CALL(mock_net, HandleAfterError_t(_)); auto success = handler.ProcessRequestUnlocked(&request, &retry); ASSERT_THAT(success, Eq(true)); } - - -TEST_F(RequestHandlerTests, ProcessRequestOk) { - MockGetSlot(true); - MockSendResponce(asapo::kNetErrorNoError, true); - EXPECT_CALL(mock_net, SendData_t(expected_source_id, &tmp, expected_data_size)).WillOnce( - Return(nullptr) - ); - EXPECT_CALL(mock_cache, UnlockSlot(_)); - EXPECT_CALL(mock_logger, Debug(HasSubstr("sending"))); +TEST_F(RequestHandlerTests, ProcessRequest_Ok) { + MockGetSlotAndUnlockIt(true); + MockSendResponseAndSlotData(asapo::kNetErrorNoError, true); EXPECT_CALL(mock_stat, IncreaseRequestCounter_t()); EXPECT_CALL(mock_stat, IncreaseRequestDataVolume_t(expected_data_size)); auto success = handler.ProcessRequestUnlocked(&request, &retry); diff --git a/receiver/unittests/receiver_data_server/test_request_handler_factory.cpp b/receiver/unittests/receiver_data_server/test_request_handler_factory.cpp index 770804c8f908d0d5dfbeb9074984df07a94862bb..44ecc6462c46854fc698e0d54be6cc4f77807c79 100644 --- a/receiver/unittests/receiver_data_server/test_request_handler_factory.cpp +++ b/receiver/unittests/receiver_data_server/test_request_handler_factory.cpp @@ -35,7 +35,7 @@ TEST(ReceiverDataServerRequestHandlerFactory, Constructor) { 
config.nthreads = 4; ReceiverDataServer data_server{"", asapo::LogLevel::Debug, nullptr, config}; asapo::Statistics stat; - ReceiverDataServerRequestHandlerFactory factory((asapo::NetServer*)&data_server, nullptr, &stat); + ReceiverDataServerRequestHandlerFactory factory((asapo::RdsNetServer*)&data_server, nullptr, &stat); auto handler = factory.NewRequestHandler(1, nullptr); ASSERT_THAT(dynamic_cast<const asapo::ReceiverDataServerRequestHandler*>(handler.get()), Ne(nullptr)); } diff --git a/receiver/unittests/receiver_data_server/test_tcp_server.cpp b/receiver/unittests/receiver_data_server/test_tcp_server.cpp index 9b8985ef3f03f68a85344fa01c398864e5f29b86..76c6caf5eee3070caeaaa600c7771490143d8bb4 100644 --- a/receiver/unittests/receiver_data_server/test_tcp_server.cpp +++ b/receiver/unittests/receiver_data_server/test_tcp_server.cpp @@ -59,14 +59,15 @@ class TCPServerTests : public Test { void TearDown() override { tcp_server.io__.release(); } - void ExpectListenMaster(bool ok); + void ExpectTcpBind(bool ok); void WaitSockets(bool ok, ListSocketDescriptors clients = {}); - void MockReceiveRequest(bool ok ); + void MockReceiveRequest(bool ok); + void InitMasterServer(); void ExpectReceiveOk(); void ExpectReceiveRequestEof(); }; -void TCPServerTests::ExpectListenMaster(bool ok) { +void TCPServerTests::ExpectTcpBind(bool ok) { EXPECT_CALL(mock_io, CreateAndBindIPTCPSocketListener_t(expected_address, asapo::kMaxPendingConnections, _)) .WillOnce(DoAll( SetArgPointee<2>(ok ? nullptr : asapo::IOErrorTemplates::kUnknownIOError.Generate().release()), @@ -93,24 +94,28 @@ void TCPServerTests::WaitSockets(bool ok, ListSocketDescriptors clients) { } } -TEST_F(TCPServerTests, GetNewRequestsInitializesSocket) { - Error err; - ExpectListenMaster(false); +void TCPServerTests::InitMasterServer() { + ExpectTcpBind(true); + ASSERT_THAT(tcp_server.Initialize(), Eq(nullptr)); +} - auto requests = tcp_server.GetNewRequests(&err); +TEST_F(TCPServerTests, Initialize_Error) { + ExpectTcpBind(false); + + Error err = tcp_server.Initialize(); ASSERT_THAT(err, Ne(nullptr)); - ASSERT_THAT(requests, IsEmpty()); } -TEST_F(TCPServerTests, GetNewRequestsInitializesSocketOnlyOnce) { +TEST_F(TCPServerTests, Initialize_ErrorDoubleInitialize) { Error err; - ExpectListenMaster(false); - tcp_server.GetNewRequests(&err); - tcp_server.GetNewRequests(&err); + ExpectTcpBind(true); + err = tcp_server.Initialize(); + ASSERT_THAT(err, Eq(nullptr)); -// ASSERT_THAT(err, Ne(nullptr)); + err = tcp_server.Initialize(); + ASSERT_THAT(err, Ne(nullptr)); } void TCPServerTests::MockReceiveRequest(bool ok ) { @@ -162,10 +167,9 @@ void TCPServerTests::ExpectReceiveOk() { } } - TEST_F(TCPServerTests, GetNewRequestsWaitsSocketActivitiesError) { Error err; - ExpectListenMaster(true); + InitMasterServer(); WaitSockets(false); auto requests = tcp_server.GetNewRequests(&err); @@ -176,7 +180,7 @@ TEST_F(TCPServerTests, GetNewRequestsWaitsSocketActivitiesError) { TEST_F(TCPServerTests, GetNewRequestsWaitsSocketReceiveFailure) { Error err; - ExpectListenMaster(true); + InitMasterServer(); WaitSockets(true); MockReceiveRequest(false); @@ -194,7 +198,7 @@ TEST_F(TCPServerTests, GetNewRequestsWaitsSocketReceiveFailure) { TEST_F(TCPServerTests, GetNewRequestsReadEof) { Error err; - ExpectListenMaster(true); + InitMasterServer(); WaitSockets(true); ExpectReceiveRequestEof(); @@ -212,9 +216,8 @@ TEST_F(TCPServerTests, GetNewRequestsReadEof) { } TEST_F(TCPServerTests, GetNewRequestsReadOk) { - Error - err; - ExpectListenMaster(true); + Error err; + 
InitMasterServer(); WaitSockets(true); ExpectReceiveOk(); @@ -234,10 +237,11 @@ TEST_F(TCPServerTests, GetNewRequestsReadOk) { } -TEST_F(TCPServerTests, SendData) { - uint8_t tmp; +TEST_F(TCPServerTests, SendResponse) { + asapo::GenericNetworkResponse tmp {}; + asapo::ReceiverDataServerRequest expectedRequest {{}, 30}; - EXPECT_CALL(mock_io, Send_t(1, &tmp, 10, _)) + EXPECT_CALL(mock_io, Send_t(30, &tmp, sizeof(asapo::GenericNetworkResponse), _)) .WillOnce( DoAll( testing::SetArgPointee<3>(asapo::IOErrorTemplates::kUnknownIOError.Generate().release()), @@ -246,11 +250,80 @@ TEST_F(TCPServerTests, SendData) { EXPECT_CALL(mock_logger, Error(HasSubstr("cannot send"))); - auto err = tcp_server.SendData(1, &tmp, 10); + auto err = tcp_server.SendResponse(&expectedRequest, &tmp); + + ASSERT_THAT(err, Ne(nullptr)); +} + +TEST_F(TCPServerTests, SendResponseAndSlotData_SendResponseError) { + asapo::GenericNetworkResponse tmp {}; + + + asapo::ReceiverDataServerRequest expectedRequest {{}, 30}; + asapo::CacheMeta expectedMeta {}; + expectedMeta.id = 20; + expectedMeta.addr = (void*)0x9234; + expectedMeta.size = 50; + expectedMeta.lock = 123; + + EXPECT_CALL(mock_io, Send_t(30, &tmp, sizeof(asapo::GenericNetworkResponse), _)) + .WillOnce(DoAll( + testing::SetArgPointee<3>(asapo::IOErrorTemplates::kUnknownIOError.Generate().release()), + Return(0) + )); + EXPECT_CALL(mock_logger, Error(HasSubstr("cannot send"))); + + auto err = tcp_server.SendResponseAndSlotData(&expectedRequest, &tmp, &expectedMeta); + + ASSERT_THAT(err, Ne(nullptr)); +} + +TEST_F(TCPServerTests, SendResponseAndSlotData_SendDataError) { + asapo::GenericNetworkResponse tmp {}; + + asapo::ReceiverDataServerRequest expectedRequest {{}, 30}; + asapo::CacheMeta expectedMeta {}; + expectedMeta.id = 20; + expectedMeta.addr = (void*)0x9234; + expectedMeta.size = 50; + expectedMeta.lock = 123; + + EXPECT_CALL(mock_io, Send_t(30, &tmp, sizeof(asapo::GenericNetworkResponse), _)) + .WillOnce(Return(1)); + EXPECT_CALL(mock_io, Send_t(30, expectedMeta.addr, expectedMeta.size, _)) + .WillOnce( + DoAll( + testing::SetArgPointee<3>(asapo::IOErrorTemplates::kUnknownIOError.Generate().release()), + Return(0) + )); + + EXPECT_CALL(mock_logger, Error(HasSubstr("cannot send"))); + + auto err = tcp_server.SendResponseAndSlotData(&expectedRequest, &tmp, &expectedMeta); ASSERT_THAT(err, Ne(nullptr)); } +TEST_F(TCPServerTests, SendResponseAndSlotData_Ok) { + asapo::GenericNetworkResponse tmp {}; + + asapo::ReceiverDataServerRequest expectedRequest {{}, 30}; + asapo::CacheMeta expectedMeta {}; + expectedMeta.id = 20; + expectedMeta.addr = (void*)0x9234; + expectedMeta.size = 50; + expectedMeta.lock = 123; + + EXPECT_CALL(mock_io, Send_t(30, &tmp, sizeof(asapo::GenericNetworkResponse), _)) + .WillOnce(Return(1)); + EXPECT_CALL(mock_io, Send_t(30, expectedMeta.addr, expectedMeta.size, _)) + .WillOnce(Return(expectedMeta.size)); + + auto err = tcp_server.SendResponseAndSlotData(&expectedRequest, &tmp, &expectedMeta); + + ASSERT_THAT(err, Eq(nullptr)); +} + TEST_F(TCPServerTests, HandleAfterError) { EXPECT_CALL(mock_io, CloseSocket_t(expected_client_sockets[0], _)); tcp_server.HandleAfterError(expected_client_sockets[0]); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 30598b059e866cf2fb886e6f3e9ce32167f2b2cd..11b4622c6528a19a79daa5807874d7f8093862ea 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -4,3 +4,5 @@ add_subdirectory(automatic) configure_files(${CMAKE_CURRENT_SOURCE_DIR}/manual/tests_via_nomad 
${CMAKE_CURRENT_BINARY_DIR}/manual/tests_via_nomad @ONLY) add_subdirectory(manual/performance_broker_receiver) + +add_subdirectory(manual/asapo_fabric) diff --git a/tests/automatic/CMakeLists.txt b/tests/automatic/CMakeLists.txt index 7e926f7928815a5d64b056fb3bd22442a6ce462e..9df59ca928b4dcae80fdf225a5f317419b3dbdc0 100644 --- a/tests/automatic/CMakeLists.txt +++ b/tests/automatic/CMakeLists.txt @@ -37,3 +37,7 @@ if (UNIX) endif() add_subdirectory(bug_fixes) + +if (ENABLE_LIBFABRIC) + add_subdirectory(asapo_fabric) +endif() diff --git a/tests/automatic/asapo_fabric/CMakeLists.txt b/tests/automatic/asapo_fabric/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c8da6e1c73e90e79bf3fdcd4674ef1a2042cf5aa --- /dev/null +++ b/tests/automatic/asapo_fabric/CMakeLists.txt @@ -0,0 +1,19 @@ +# Automatically add all files to the tests +file(GLOB files "*.cpp") +foreach(file ${files}) + # File name without extension + get_filename_component(file_we "${file}" NAME_WE) + set(TARGET_NAME test-auto-asapo_fabric-${file_we}) + set(SOURCE_FILES ${file}) + + GET_PROPERTY(ASAPO_COMMON_FABRIC_LIBRARIES GLOBAL PROPERTY ASAPO_COMMON_FABRIC_LIBRARIES) + + # Executable and link + add_executable(${TARGET_NAME} ${SOURCE_FILES} $<TARGET_OBJECTS:logger> $<TARGET_OBJECTS:curl_http_client>) + target_link_libraries(${TARGET_NAME} test_common asapo-fabric ${CURL_LIBRARIES} ${ASAPO_COMMON_FABRIC_LIBRARIES}) + target_include_directories(${TARGET_NAME} PUBLIC ${ASAPO_CXX_COMMON_INCLUDE_DIR}) + set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE CXX) + + # Add test + add_integration_test(${TARGET_NAME} ${TARGET_NAME} "") +endforeach() diff --git a/tests/automatic/asapo_fabric/client_lazy_initialization.cpp b/tests/automatic/asapo_fabric/client_lazy_initialization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..751385618ff5cdb6d5869d05ea90134431c11120 --- /dev/null +++ b/tests/automatic/asapo_fabric/client_lazy_initialization.cpp @@ -0,0 +1,25 @@ +#include <common/error.h> +#include <asapo_fabric/asapo_fabric.h> +#include <testing.h> + +using namespace asapo; +using namespace fabric; + +int main(int argc, char* argv[]) { + Error err; + auto factory = GenerateDefaultFabricFactory(); + + auto client = factory->CreateClient(&err); + M_AssertEq(nullptr, err, "factory->CreateClient"); + + M_AssertEq("", client->GetAddress()); + + int dummyBuffer = 0; + auto mr = client->ShareMemoryRegion(&dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(FabricErrorTemplates::kClientNotInitializedError, err, "client->ShareMemoryRegion"); + err = nullptr; + + // Other methods require a serverAddress, which initializes the client + + return 0; +} diff --git a/tests/automatic/asapo_fabric/parallel_data_transfer.cpp b/tests/automatic/asapo_fabric/parallel_data_transfer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc99295efc270c7b364c28946a1b19363bf1f2ef --- /dev/null +++ b/tests/automatic/asapo_fabric/parallel_data_transfer.cpp @@ -0,0 +1,164 @@ +#include <common/error.h> +#include <asapo_fabric/asapo_fabric.h> +#include <testing.h> +#include <thread> +#include <iostream> +#include <cstring> +#include <future> +#include <request/request.h> + +using namespace asapo; +using namespace fabric; + +std::promise<void> clientIsDone; +std::future<void> clientIsDoneFuture = clientIsDone.get_future(); + +std::promise<void> serverIsDone; +std::future<void> serverIsDoneFuture = serverIsDone.get_future(); + +constexpr size_t kRdmaSize = 5 * 1024 * 1024; +constexpr int
kServerThreads = 2; +constexpr int kEachInstanceRuns = 10; +constexpr int kClientThreads = 4; + +void ServerChildThread(FabricServer* server, std::atomic<int>* serverTotalRequests, char* expectedRdmaBuffer) { + constexpr int maxRuns = kClientThreads * kEachInstanceRuns; + Error err; + + while ((*serverTotalRequests)++ < maxRuns) { + GenericRequestHeader request{}; + + FabricAddress clientAddress; + FabricMessageId messageId; + server->RecvAny(&clientAddress, &messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "server->RecvAny"); + M_AssertEq("Hello World", request.message); + M_AssertEq(messageId / kEachInstanceRuns, request.data_id); // is client index + M_AssertEq(messageId % kEachInstanceRuns, request.data_size); // is client run + + server->RdmaWrite(clientAddress, (MemoryRegionDetails*)&request.substream, expectedRdmaBuffer, kRdmaSize, &err); + M_AssertEq(nullptr, err, "server->RdmaWrite"); + + GenericNetworkResponse response{}; + strcpy(response.message, "Hey, I am the Server"); + server->Send(clientAddress, messageId, &response, sizeof(response), &err); + M_AssertEq(nullptr, err, "server->Send"); + } + + std::cerr << "A Server is done" << std::endl; +} + +void ServerMasterThread(const std::string& hostname, uint16_t port, char* expectedRdmaBuffer) { + Error err; + auto log = CreateDefaultLoggerBin("AutomaticTesting"); + + auto factory = GenerateDefaultFabricFactory(); + auto server = factory->CreateAndBindServer(log.get(), hostname, port, &err); + M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); + std::atomic<int> serverTotalRequests(0); + + std::thread threads[kServerThreads]; + for (auto& thread : threads) { + thread = std::thread(ServerChildThread, server.get(), &serverTotalRequests, expectedRdmaBuffer); + } + + for (auto& thread : threads) { + thread.join(); + } + + std::cerr << "[SERVER] Waiting for all client to finish" << std::endl; + clientIsDoneFuture.get(); + serverIsDone.set_value(); +} + +void ClientChildThread(const std::string& hostname, uint16_t port, int index, char* expectedRdmaBuffer) { + auto factory = GenerateDefaultFabricFactory(); + Error err; + + auto client = factory->CreateClient(&err); + M_AssertEq(nullptr, err, "factory->CreateClient"); + + auto serverAddress = client->AddServerAddress(hostname + ":" + std::to_string(port), &err); + M_AssertEq(nullptr, err, "client->AddServerAddress"); + + auto actualRdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); + + auto mr = client->ShareMemoryRegion(actualRdmaBuffer.get(), kRdmaSize, &err); + M_AssertEq(nullptr, err, "client->ShareMemoryRegion"); + + for (int run = 0; run < kEachInstanceRuns; run++) { + std::cerr << "Client run: " << run << std::endl; + + GenericRequestHeader request{}; + strcpy(request.message, "Hello World"); + memcpy(request.substream, mr->GetDetails(), sizeof(MemoryRegionDetails)); + request.data_id = index; + request.data_size = run; + FabricMessageId messageId = (index * kEachInstanceRuns) + run; + client->Send(serverAddress, messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "client->Send"); + + GenericNetworkResponse response{}; + client->Recv(serverAddress, messageId, &response, sizeof(response), &err); + M_AssertEq(nullptr, err, "client->Recv"); + M_AssertEq("Hey, I am the Server", response.message); + + for (size_t i = 0; i < kRdmaSize; i++) { + // Check to reduce log spam + if (expectedRdmaBuffer[i] != actualRdmaBuffer[i]) { + M_AssertEq(expectedRdmaBuffer[i], actualRdmaBuffer[i], + "Expect rdma[i] == acutal[i], i = " + 
std::to_string(i)); + } + } + } + std::cout << "A Client is done" << std::endl; +} + +void ClientMasterThread(const std::string& hostname, uint16_t port, char* expectedRdmaBuffer) { + std::thread threads[kClientThreads]; + for (int i = 0; i < kClientThreads; i++) { + threads[i] = std::thread(ClientChildThread, hostname, port, i, expectedRdmaBuffer); + } + + for (auto& thread : threads) { + thread.join(); + } + + clientIsDone.set_value(); + std::cout << "[Client] Waiting for server to finish" << std::endl; + serverIsDoneFuture.get(); +} + +int main(int argc, char* argv[]) { + std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 10); + } + + std::cout << "Client is writing to std::cout" << std::endl; + std::cerr << "Server is writing to std::cerr" << std::endl; + + auto expectedRdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); + for (size_t i = 0; i < kRdmaSize; i++) { + expectedRdmaBuffer[i] = (char)i; + } + + std::thread serverMasterThread(ServerMasterThread, hostname, port, expectedRdmaBuffer.get()); + + std::this_thread::sleep_for(std::chrono::seconds(2)); + ClientMasterThread(hostname, port, expectedRdmaBuffer.get()); + + std::cout << "Done testing. Joining server" << std::endl; + serverMasterThread.join(); + + return 0; +} diff --git a/tests/automatic/asapo_fabric/server_not_running.cpp b/tests/automatic/asapo_fabric/server_not_running.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa11d69365ca5fff380bea6fa77b8700bcafd720 --- /dev/null +++ b/tests/automatic/asapo_fabric/server_not_running.cpp @@ -0,0 +1,35 @@ +#include <common/error.h> +#include <asapo_fabric/asapo_fabric.h> +#include <testing.h> +#include <iostream> + +using namespace asapo; +using namespace fabric; + +int main(int argc, char* argv[]) { + std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 10); + } + + Error err; + auto factory = GenerateDefaultFabricFactory(); + + auto client = factory->CreateClient(&err); + M_AssertEq(nullptr, err, "factory->CreateClient"); + + auto serverAddress = client->AddServerAddress(hostname + ":" + std::to_string(port), &err); + M_AssertEq(FabricErrorTemplates::kConnectionRefusedError, err, "client->AddServerAddress"); + err = nullptr; + + return 0; +} diff --git a/tests/automatic/asapo_fabric/simple_data_transfer.cpp b/tests/automatic/asapo_fabric/simple_data_transfer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5c35f5a2cadc7a021e20a606906c288698ed0871 --- /dev/null +++ b/tests/automatic/asapo_fabric/simple_data_transfer.cpp @@ -0,0 +1,135 @@ +#include <common/error.h> +#include <asapo_fabric/asapo_fabric.h> +#include <testing.h> +#include <thread> +#include <iostream> +#include <cstring> +#include <future> +#include <request/request.h> + +using namespace asapo; +using namespace fabric; + +std::promise<void> clientIsDone; +std::future<void> clientIsDoneFuture = clientIsDone.get_future(); + +std::promise<void> serverIsDone; +std::future<void> serverIsDoneFuture = serverIsDone.get_future(); + +constexpr int kTotalRuns = 3; +constexpr int kEachInstanceRuns = 5; +constexpr size_t 
kRdmaSize = 5 * 1024 * 1024; + +void ServerMasterThread(const std::string& hostname, uint16_t port, char* expectedRdmaBuffer) { + { + Error err; + auto log = CreateDefaultLoggerBin("AutomaticTesting"); + + auto factory = GenerateDefaultFabricFactory(); + auto server = factory->CreateAndBindServer(log.get(), hostname, port, &err); + M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); + + for (int run = 0; run < kTotalRuns; run++) { + for (int instanceRuns = 0; instanceRuns < kEachInstanceRuns; instanceRuns++) { + GenericRequestHeader request{}; + + FabricAddress clientAddress; + FabricMessageId messageId; + server->RecvAny(&clientAddress, &messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "server->RecvAny"); + M_AssertEq(123 + instanceRuns, messageId); + M_AssertEq("Hello World", request.message); + + server->RdmaWrite(clientAddress, (MemoryRegionDetails*) &request.substream, expectedRdmaBuffer, kRdmaSize, + &err); + M_AssertEq(nullptr, err, "server->RdmaWrite"); + + GenericNetworkResponse response{}; + strcpy(response.message, "Hey, I am the Server"); + server->Send(clientAddress, messageId, &response, sizeof(response), &err); + M_AssertEq(nullptr, err, "server->Send"); + } + } + + std::cout << "[SERVER] Waiting for client to finish" << std::endl; + clientIsDoneFuture.get(); + } + std::cout << "[SERVER] Server is done" << std::endl; + serverIsDone.set_value(); +} + +void ClientThread(const std::string& hostname, uint16_t port, char* expectedRdmaBuffer) { + Error err; + + for (int run = 0; run < kTotalRuns; run++) { + std::cout << "Running client " << run << std::endl; + for (int instanceRuns = 0; instanceRuns < kEachInstanceRuns; instanceRuns++) { + auto factory = GenerateDefaultFabricFactory(); + + auto client = factory->CreateClient(&err); + M_AssertEq(nullptr, err, "factory->CreateClient"); + + auto serverAddress = client->AddServerAddress(hostname + ":" + std::to_string(port), &err); + M_AssertEq(nullptr, err, "client->AddServerAddress"); + + auto actualRdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); + + auto mr = client->ShareMemoryRegion(actualRdmaBuffer.get(), kRdmaSize, &err); + M_AssertEq(nullptr, err, "client->ShareMemoryRegion"); + + GenericRequestHeader request{}; + strcpy(request.message, "Hello World"); + memcpy(request.substream, mr->GetDetails(), sizeof(MemoryRegionDetails)); + FabricMessageId messageId = 123 + instanceRuns; + client->Send(serverAddress, messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "client->Send"); + + GenericNetworkResponse response{}; + client->Recv(serverAddress, messageId, &response, sizeof(response), &err); + M_AssertEq(nullptr, err, "client->Recv"); + M_AssertEq("Hey, I am the Server", response.message); + + for (size_t i = 0; i < kRdmaSize; i++) { + // Check to reduce log spam + if (expectedRdmaBuffer[i] != actualRdmaBuffer[i]) { + M_AssertEq(expectedRdmaBuffer[i], actualRdmaBuffer[i], + "Expect rdma[i] == actual[i], i = " + std::to_string(i)); + } + } + } + } + + clientIsDone.set_value(); + serverIsDoneFuture.get(); +} + +int main(int argc, char* argv[]) { + std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 10); + } + + auto expectedRdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); + for (size_t i = 0; i < kRdmaSize; i++) { + expectedRdmaBuffer[i] = 
(char)i; + } + + std::thread serverThread(ServerMasterThread, hostname, port, expectedRdmaBuffer.get()); + + std::this_thread::sleep_for(std::chrono::seconds(2)); + ClientThread(hostname, port, expectedRdmaBuffer.get()); + + std::cout << "Done testing. Joining server" << std::endl; + serverThread.join(); + + return 0; +} diff --git a/tests/automatic/asapo_fabric/timeout_test.cpp b/tests/automatic/asapo_fabric/timeout_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..42abd640883fb8304d00bc4fd982d63acf304d1e --- /dev/null +++ b/tests/automatic/asapo_fabric/timeout_test.cpp @@ -0,0 +1,107 @@ +#include <iostream> +#include <future> +#include <common/error.h> +#include <logger/logger.h> +#include <testing.h> +#include <asapo_fabric/asapo_fabric.h> + +using namespace asapo; +using namespace fabric; + +std::promise<void> serverShutdown; +std::future<void> serverShutdown_future = serverShutdown.get_future(); + +std::promise<void> serverShutdownAck; +std::future<void> serverShutdownAck_future = serverShutdownAck.get_future(); + +void ServerMasterThread(const std::string& hostname, uint16_t port) { + { + Error err; + auto log = CreateDefaultLoggerBin("AutomaticTesting"); + + auto factory = GenerateDefaultFabricFactory(); + + auto server = factory->CreateAndBindServer(log.get(), hostname, port, &err); + M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); + + // Wait for the client to send a request and then shut down the server + int dummyBuffer; + FabricAddress clientAddress; + FabricMessageId messageId; + server->RecvAny(&clientAddress, &messageId, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(nullptr, err, "server->RecvAny"); + + server->Send(clientAddress, messageId, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(nullptr, err, "server->Send"); + + serverShutdown_future.wait(); + } + + printf("Server is now down!\n"); + serverShutdownAck.set_value(); +} + +void ClientThread(const std::string& hostname, uint16_t port) { + Error err; + + auto factory = GenerateDefaultFabricFactory(); + + auto client = factory->CreateClient(&err); + M_AssertEq(nullptr, err, "factory->CreateClient"); + + auto serverAddress = client->AddServerAddress(hostname + ":" + std::to_string(port), &err); + M_AssertEq(nullptr, err, "client->AddServerAddress"); + + int dummyBuffer = 0; + client->Send(serverAddress, 0, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(nullptr, err, "client->Send"); + + client->Recv(serverAddress, 0, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(nullptr, err, "client->Recv"); + + // The server will not respond to this message, so the Recv should time out + std::cout << + "The following call might take a while since it is able to reach the server but the server is not responding" + << std::endl; + client->Recv(serverAddress, 0, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(FabricErrorTemplates::kTimeout, err, "client->Recv"); + err = nullptr; + + serverShutdown.set_value(); + serverShutdownAck_future.wait(); + + // Server is now down + client->Recv(serverAddress, 1, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(FabricErrorTemplates::kInternalConnectionError, err, "client->Recv"); + err = nullptr; + + client->Send(serverAddress, 2, &dummyBuffer, sizeof(dummyBuffer), &err); + M_AssertEq(FabricErrorTemplates::kInternalConnectionError, err, "client->Send"); + err = nullptr; +} + +int main(int argc, char* argv[]) { + std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << 
std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 10); + } + + std::thread serverThread(ServerMasterThread, hostname, port); + + std::this_thread::sleep_for(std::chrono::seconds(2)); + ClientThread(hostname, port); + + std::cout << "Done testing. Joining server" << std::endl; + serverThread.join(); + + return 0; +} diff --git a/tests/automatic/asapo_fabric/wrong_memory_info.cpp b/tests/automatic/asapo_fabric/wrong_memory_info.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f94fc6c658136570b9eba5d650facfd10f4d5886 --- /dev/null +++ b/tests/automatic/asapo_fabric/wrong_memory_info.cpp @@ -0,0 +1,146 @@ +#include <common/error.h> +#include <asapo_fabric/asapo_fabric.h> +#include <testing.h> +#include <thread> +#include <iostream> +#include <cstring> +#include <future> +#include <request/request.h> + +using namespace asapo; +using namespace fabric; + +std::promise<void> clientIsDone; +std::future<void> clientIsDoneFuture = clientIsDone.get_future(); + +std::promise<void> serverIsDone; +std::future<void> serverIsDoneFuture = serverIsDone.get_future(); + +constexpr size_t kRdmaSize = 5 * 1024; +constexpr size_t kDummyDataSize = 512; + +void ServerMasterThread(const std::string& hostname, uint16_t port) { + Error err; + auto log = CreateDefaultLoggerBin("AutomaticTesting"); + + auto factory = GenerateDefaultFabricFactory(); + auto server = factory->CreateAndBindServer(log.get(), hostname, port, &err); + M_AssertEq(nullptr, err, "factory->CreateAndBindServer"); + + GenericRequestHeader request{}; + + auto rdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); + auto dummyData = std::unique_ptr<char[]>(new char[kDummyDataSize]); + + FabricAddress clientAddress; + FabricMessageId messageId; + + // Simulate faulty memory details + server->RecvAny(&clientAddress, &messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "server->RecvAny(1)"); + M_AssertEq(1, messageId); + M_AssertEq("Hello World", request.message); + server->RdmaWrite(clientAddress, (MemoryRegionDetails*)&request.substream, rdmaBuffer.get(), kRdmaSize, &err); + M_AssertEq(FabricErrorTemplates::kInternalError, err, "server->RdmaWrite(1)"); + err = nullptr; // We have to reset the error by ourselves + server->Send(clientAddress, messageId, dummyData.get(), kDummyDataSize, &err); + M_AssertEq(nullptr, err, "server->Send(1)"); + + // Simulate correct memory details + server->RecvAny(&clientAddress, &messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "server->RecvAny(2)"); + M_AssertEq(2, messageId); + server->RdmaWrite(clientAddress, (MemoryRegionDetails*)&request.substream, rdmaBuffer.get(), kRdmaSize, &err); + M_AssertEq(nullptr, err, "server->RdmaWrite(2)"); + server->Send(clientAddress, messageId, dummyData.get(), kDummyDataSize, &err); + M_AssertEq(nullptr, err, "server->Send(2)"); + + // Simulate old (unregistered) memory details + GenericRequestHeader request2{}; + server->RecvAny(&clientAddress, &messageId, &request2, sizeof(request2), &err); + M_AssertEq(nullptr, err, "server->RecvAny(3)"); + M_AssertEq(3, messageId); + server->RdmaWrite(clientAddress, (MemoryRegionDetails*)&request.substream, rdmaBuffer.get(), kRdmaSize, &err); + M_AssertEq(FabricErrorTemplates::kInternalError, err, "server->RdmaWrite(3)"); + + std::cout << "[SERVER] Waiting for client to finish" << std::endl; + clientIsDoneFuture.get(); + serverIsDone.set_value(); +} + +void ClientThread(const std::string& 
hostname, uint16_t port) { + Error err; + + auto factory = GenerateDefaultFabricFactory(); + + auto client = factory->CreateClient(&err); + M_AssertEq(nullptr, err, "factory->CreateClient"); + + auto serverAddress = client->AddServerAddress(hostname + ":" + std::to_string(port), &err); + M_AssertEq(nullptr, err, "client->AddServerAddress"); + + auto actualRdmaBuffer = std::unique_ptr<char[]>(new char[kRdmaSize]); + auto dummyData = std::unique_ptr<char[]>(new char[kDummyDataSize]); + + GenericRequestHeader request{}; + FabricMessageId messageId = 1; + strcpy(request.message, "Hello World"); + + // Scoped MemoryRegion + { + auto mr = client->ShareMemoryRegion(actualRdmaBuffer.get(), kRdmaSize, &err); + M_AssertEq(nullptr, err, "client->ShareMemoryRegion"); + memcpy(request.substream, mr->GetDetails(), sizeof(MemoryRegionDetails)); + + // Simulate faulty memory details + ((MemoryRegionDetails*)(&request.substream))->key++; + client->Send(serverAddress, messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "client->Send(1)"); + client->Recv(serverAddress, messageId, dummyData.get(), kDummyDataSize, &err); + M_AssertEq(nullptr, err, "client->Recv(1)"); + messageId++; + + // Simulate correct memory details + memcpy(request.substream, mr->GetDetails(), sizeof(MemoryRegionDetails)); + client->Send(serverAddress, messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "client->Send(2)"); + client->Recv(serverAddress, messageId, dummyData.get(), kDummyDataSize, &err); + M_AssertEq(nullptr, err, "client->Recv(2)"); + messageId++; + } + + // Simulate old (unregistered) memory details + // Details are still written from "Simulate correct memory details" + client->Send(serverAddress, messageId, &request, sizeof(request), &err); + M_AssertEq(nullptr, err, "client->Send(3)"); + + clientIsDone.set_value(); + std::cout << "[Client] Waiting for server to finish" << std::endl; + serverIsDoneFuture.get(); +} + +int main(int argc, char* argv[]) { + std::string hostname = "127.0.0.1"; + uint16_t port = 1816; + + if (argc > 3) { + std::cout << "Usage: " << argv[0] << " [<host>] [<port>]" << std::endl; + return 1; + } + if (argc == 2) { + hostname = argv[1]; + } + if (argc == 3) { + port = (uint16_t) strtoul(argv[2], nullptr, 10); + } + + std::thread serverThread(ServerMasterThread, hostname, port); + + std::this_thread::sleep_for(std::chrono::seconds(2)); + ClientThread(hostname, port); + + std::cout << "Done testing. 
Joining server" << std::endl; + serverThread.join(); + + return 0; +} diff --git a/tests/automatic/common/cpp/include/testing.h b/tests/automatic/common/cpp/include/testing.h index d127dfd931f4cafb6f0667d5e841657b2eede329..8b99b81d181c1162fbe9c5eeb4ae4cc45f198851 100644 --- a/tests/automatic/common/cpp/include/testing.h +++ b/tests/automatic/common/cpp/include/testing.h @@ -37,17 +37,17 @@ inline void _M_AssertEq(const ErrorTemplateType& expected, const Error& got, con #define _M_AssertEq_2_ARGS(e, g) \ asapo::_M_AssertEq(e, g, _M_INTERNAL_COMMENT_PREFIX "Expect " # g " to be " # e) #define _M_AssertEq_3_ARGS(e, g, c) \ - asapo::_M_AssertEq(e, g, _M_INTERNAL_COMMENT_PREFIX c) + asapo::_M_AssertEq(e, g, std::string(_M_INTERNAL_COMMENT_PREFIX) + std::string(c)) #define _M_AssertContains_2_ARGS(whole, sub) \ asapo::_M_AssertContains(whole, sub, _M_INTERNAL_COMMENT_PREFIX "Expect " # whole " to contain substring " # sub) #define _M_AssertContains_3_ARGS(whole, sub, c) \ - asapo::_M_AssertContains(whole, sub, _M_INTERNAL_COMMENT_PREFIX c) + asapo::_M_AssertContains(whole, sub, std::string(_M_INTERNAL_COMMENT_PREFIX) + std::string(c)) #define _M_AssertTrue_1_ARGS(value) \ asapo::_M_AssertTrue(value, _M_INTERNAL_COMMENT_PREFIX "Expect " # value " to be true") #define _M_AssertTrue_2_ARGS(value, c) \ - asapo::_M_AssertTrue(value, _M_INTERNAL_COMMENT_PREFIX c) + asapo::_M_AssertTrue(value, std::string(_M_INTERNAL_COMMENT_PREFIX) + std::string(c)) #define _M_GET_4TH_ARG(arg1, arg2, arg3, arg4, ...) arg4 #define _M_MACRO_CHOOSER_2_3(func, ...) _M_GET_4TH_ARG(__VA_ARGS__, _ ## func ## _3_ARGS, _ ## func ## _2_ARGS, ) diff --git a/tests/automatic/producer/python_api/check_windows.bat b/tests/automatic/producer/python_api/check_windows.bat index a354590701ce1694749c6f1b8fab5d879153a3fd..aec0dd80126a6b8a868105c866f68d3396f72b06 100644 --- a/tests/automatic/producer/python_api/check_windows.bat +++ b/tests/automatic/producer/python_api/check_windows.bat @@ -24,7 +24,7 @@ set NUM=0 for /F %%N in ('find /C "successfuly sent" ^< "out"') do set NUM=%%N echo %NUM% | findstr 10 || goto error -for /F %%N in ('find /C "} wrong input: Bad request :already have record with same id" ^< "out"') do set NUM=%%N +for /F %%N in ('find /C "} wrong input: Bad request: already have record with same id" ^< "out"') do set NUM=%%N echo %NUM% | findstr 2 || goto error for /F %%N in ('find /C "} server warning: ignoring duplicate record" ^< "out"') do set NUM=%%N diff --git a/tests/manual/asapo_fabric/CMakeLists.txt b/tests/manual/asapo_fabric/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca48a6f6a30529e2ef79f47e2e25c9fb09d87cd7 --- /dev/null +++ b/tests/manual/asapo_fabric/CMakeLists.txt @@ -0,0 +1,7 @@ +GET_PROPERTY(ASAPO_COMMON_FABRIC_LIBRARIES GLOBAL PROPERTY ASAPO_COMMON_FABRIC_LIBRARIES) + +add_executable(example-fabric-server fabric_server.cpp $<TARGET_OBJECTS:logger> $<TARGET_OBJECTS:curl_http_client>) +target_link_libraries(example-fabric-server asapo-fabric ${CURL_LIBRARIES} ${ASAPO_COMMON_FABRIC_LIBRARIES}) + +add_executable(example-fabric-client fabric_client.cpp) +target_link_libraries(example-fabric-client asapo-fabric ${ASAPO_COMMON_FABRIC_LIBRARIES}) diff --git a/tests/manual/asapo_fabric/fabric_client.cpp b/tests/manual/asapo_fabric/fabric_client.cpp new file mode 100644 index 0000000000000000000000000000000000000000..792a8293ff7d5cad7a9945fb68fcd9b04b28f758 --- /dev/null +++ b/tests/manual/asapo_fabric/fabric_client.cpp @@ -0,0 +1,83 @@ +#include 
<asapo_fabric/asapo_fabric.h> +#include <iostream> +#include <common/data_structs.h> +#include <common/networking.h> + +using namespace asapo; +using namespace asapo::fabric; + +int main(int argc, char* argv[]) { + if (argc != 3) { + std::cout + << "Usage: " << argv[0] << " <serverAddress> <serverPort>" << std::endl + << "If the address is localhost or 127.0.0.1 the verbs connection will be emulated" << std::endl + ; + return 1; + } + + std::string serverAddressString = std::string(argv[1]) + ':' + std::string(argv[2]); + + Error error; + auto factory = GenerateDefaultFabricFactory(); + + auto client = factory->CreateClient(&error); + if (error) { + std::cout << "Client exited with error: " << error << std::endl; + return 1; + } + + size_t dataBufferSize = 1024 * 1024 * 400 /*400 MiByte*/; + FileData dataBuffer = FileData{new uint8_t[dataBufferSize]}; + + auto serverAddress = client->AddServerAddress(serverAddressString, &error); + if (error) { + std::cout << "Client exited with error: " << error << std::endl; + return 1; + } + std::cout << "Added server address" << std::endl; + + auto mr = client->ShareMemoryRegion(dataBuffer.get(), dataBufferSize, &error); + if (error) { + std::cout << "Client exited with error: " << error << std::endl; + return 1; + } + + uint64_t totalTransferSize = 0; + auto start = std::chrono::high_resolution_clock::now(); + + std::cout << "Starting message loop" << std::endl; + for (FabricMessageId messageId = 0; messageId < 10 && !error; messageId++) { + GenericRequestHeader request{}; + memcpy(&request.message, mr->GetDetails(), sizeof(MemoryRegionDetails)); + client->Send(serverAddress, messageId, &request, sizeof(request), &error); + if (error) { + break; + } + + GenericNetworkResponse response{}; + client->Recv(serverAddress, messageId, &response, sizeof(response), &error); + if (error) { + break; + } + + if (strcmp((char*)dataBuffer.get(), "I (the server) wrote into your buffer.") != 0) { + error = TextError("The buffer was not written with the expected text"); + break; + } + memset(dataBuffer.get(), 0, 64); + + totalTransferSize += dataBufferSize; + } + auto end = std::chrono::high_resolution_clock::now(); + + if (error) { + std::cout << "Client exited with error: " << error << std::endl; + return 1; + } + + auto timeTook = end - start; + std::cout << "Transferred " << (((totalTransferSize) / 1024) / 1024) << " MiBytes in " + << std::chrono::duration_cast<std::chrono::milliseconds>(timeTook).count() << "ms" << std::endl; + + return 0; +} diff --git a/tests/manual/asapo_fabric/fabric_server.cpp b/tests/manual/asapo_fabric/fabric_server.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb973398bbc07b93cf75fecc9972235382596e5a --- /dev/null +++ b/tests/manual/asapo_fabric/fabric_server.cpp @@ -0,0 +1,85 @@ +#include <asapo_fabric/asapo_fabric.h> +#include <iostream> +#include <io/io_factory.h> +#include <common/networking.h> + +using namespace asapo; +using namespace asapo::fabric; + +volatile bool running = false; + +void ServerThread(FabricServer* server, size_t bufferSize, FileData* buffer) { + Error error; + while(running && !error) { + FabricAddress clientAddress; + FabricMessageId messageId; + GenericRequestHeader request; + + server->RecvAny(&clientAddress, &messageId, &request, sizeof(request), &error); + if (error == FabricErrorTemplates::kTimeout) { + error = nullptr; + continue; + } + if (error) { + break; + } + + std::cout << "Got a request from " << clientAddress << " id: " << messageId << std::endl; + 
server->RdmaWrite(clientAddress, (MemoryRegionDetails*)&request.message, buffer->get(), bufferSize, &error); + + GenericNetworkResponse response{}; + server->Send(clientAddress, messageId, &response, sizeof(response), &error); + } + + if (error) { + std::cerr << "Server thread exited with an error: " << error << std::endl; + } +} + +int main(int argc, char* argv[]) { + if (argc != 3) { + std::cout + << "Usage: " << argv[0] << " <listenAddress> <listenPort>" << std::endl + << "If the address is localhost or 127.0.0.1 the verbs connection will be emulated" << std::endl + ; + return 1; + } + + Error error; + auto io = GenerateDefaultIO(); + auto factory = GenerateDefaultFabricFactory(); + Logger log = CreateDefaultLoggerBin("FabricTestServer"); + + uint16_t port = (uint16_t)strtoul(argv[2], nullptr, 10); + auto server = factory->CreateAndBindServer(log.get(), argv[1], port, &error); + if (error) { + std::cerr << error << std::endl; + return 1; + } + + std::cout << "Server is listening on " << server->GetAddress() << std::endl; + + size_t dataBufferSize = 1024 * 1024 * 400 /*400 MiByte*/; + FileData dataBuffer = FileData{new uint8_t[dataBufferSize]}; + strcpy((char*)dataBuffer.get(), "I (the server) wrote into your buffer."); + + running = true; + auto thread = io->NewThread("ServerThread", [&server, &dataBufferSize, &dataBuffer]() { + ServerThread(server.get(), dataBufferSize, &dataBuffer); + }); + + std::cout << "Press Enter to stop the server." << std::endl; + + getchar(); + std::cout << "Stopping server... Please wait until the pending RecvAny times out." << std::endl; + + running = false; + thread->join(); + + if (error) { + std::cerr << "Server exited with error: " << error << std::endl; + return 1; + } + + return 0; +} diff --git a/tests/valgrind.suppressions b/tests/valgrind.suppressions index e2ee0f0f241350918e7f7bb919b427e712a9e5b8..2c21d104a09bbf2e4a911e3f38cdbda5e5e16516 100644 --- a/tests/valgrind.suppressions +++ b/tests/valgrind.suppressions @@ -90,3 +90,36 @@ fun:*_M_mutate* ... } +{ + asapo_fabric__sockets__uninitialised_bytes + Memcheck:Param + socketcall.sendto(msg) + fun:send + fun:ofi_send_socket +} +{ + asapo_fabric__sockets__leak + Memcheck:Leak + match-leak-kinds: definite + fun:calloc + fun:sock_rx_new_buffered_entry + fun:sock_pe_process_rx_send + fun:sock_pe_process_recv + fun:sock_pe_progress_rx_pe_entry + fun:sock_pe_progress_rx_ctx + fun:sock_pe_progress_thread + fun:start_thread + fun:clone +} +{ + asapo_fabric__verbs__rdma_accept__uninitialised_bytes + Memcheck:Param + write(buf) + obj:*libpthread* + fun:rdma_accept + fun:vrb_msg_ep_accept + fun:fi_accept + fun:rxm_msg_process_connreq +} + +
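For reference, the client-side pattern that the tests above exercise boils down to the following minimal sketch. It is not part of the change set: host, port, buffer sizes, and the "ping" payload are placeholder values, it assumes a matching fabric server is listening and answers with the same message id, and error handling is reduced to early returns.

#include <asapo_fabric/asapo_fabric.h>
#include <iostream>

using namespace asapo;
using namespace asapo::fabric;

int main() {
    Error err;
    auto factory = GenerateDefaultFabricFactory();

    auto client = factory->CreateClient(&err);
    if (err) { std::cerr << "CreateClient failed: " << err << std::endl; return 1; }

    // The server address format is "hostname:port" (placeholder values here).
    FabricAddress server = client->AddServerAddress("127.0.0.1:1816", &err);
    if (err) { std::cerr << "AddServerAddress failed: " << err << std::endl; return 1; }

    // Send and Recv are matched between the peers via the FabricMessageId.
    FabricMessageId id = 0;
    char request[64] = "ping";
    client->Send(server, id, request, sizeof(request), &err);
    if (err) { std::cerr << "Send failed: " << err << std::endl; return 1; }

    char response[64]{};
    client->Recv(server, id, response, sizeof(response), &err);
    if (err) { std::cerr << "Recv failed: " << err << std::endl; return 1; }

    std::cout << "Server answered: " << response << std::endl;
    return 0;
}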