diff --git a/common/cpp/include/unittests/MockFabric.h b/common/cpp/include/unittests/MockFabric.h index b677df0baf15da1445a92103dbf167eb33a8f571..9e393161c38a0edca4e4b78921ddff40704f9570 100644 --- a/common/cpp/include/unittests/MockFabric.h +++ b/common/cpp/include/unittests/MockFabric.h @@ -2,11 +2,18 @@ #define ASAPO_MOCKFABRIC_H #include <asapo_fabric/asapo_fabric.h> +#include <gmock/gmock.h> namespace asapo { namespace fabric { class MockFabricMemoryRegion : public FabricMemoryRegion { + public: + MockFabricMemoryRegion() = default; + ~MockFabricMemoryRegion() override { + Destructor(); + } + MOCK_METHOD0(Destructor, void()); MOCK_CONST_METHOD0(GetDetails, const MemoryRegionDetails * ()); }; @@ -52,6 +59,7 @@ class MockFabricContext : public FabricContext { }; class MockFabricClient : public MockFabricContext, public FabricClient { + public: FabricAddress AddServerAddress(const std::string& serverAddress, Error* error) override { ErrorInterface* err = nullptr; auto data = AddServerAddress_t(serverAddress, &err); @@ -59,6 +67,30 @@ class MockFabricClient : public MockFabricContext, public FabricClient { return data; } MOCK_METHOD2(AddServerAddress_t, FabricAddress (const std::string& serverAddress, ErrorInterface** err)); + public: // Link to FabricContext + std::string GetAddress() const override { + return MockFabricContext::GetAddress(); + } + + std::unique_ptr<FabricMemoryRegion> ShareMemoryRegion(void* src, size_t size, Error* error) override { + return MockFabricContext::ShareMemoryRegion(src, size, error); + } + + void Send(FabricAddress dstAddress, FabricMessageId messageId, + const void* src, size_t size, Error* error) override { + MockFabricContext::Send(dstAddress, messageId, src, size, error); + } + + void Recv(FabricAddress srcAddress, FabricMessageId messageId, + void* dst, size_t size, Error* error) override { + MockFabricContext::Recv(srcAddress, messageId, dst, size, error); + } + + void RdmaWrite(FabricAddress dstAddress, + const MemoryRegionDetails* details, const void* buffer, size_t size, + Error* error) override { + MockFabricContext::RdmaWrite(dstAddress, details, buffer, size, error); + } }; class MockFabricServer : public MockFabricContext, public FabricServer { diff --git a/consumer/api/cpp/CMakeLists.txt b/consumer/api/cpp/CMakeLists.txt index 12f123bc55292894253fe732cbca0c36a487b051..ca518f319f726527db3c8d6698d8ba5ffed3ddf9 100644 --- a/consumer/api/cpp/CMakeLists.txt +++ b/consumer/api/cpp/CMakeLists.txt @@ -5,7 +5,7 @@ set(SOURCE_FILES src/server_data_broker.cpp src/tcp_client.cpp src/tcp_connection_pool.cpp - src/fabric_client.cpp) + src/fabric_consumer_client.cpp) ################################ @@ -21,17 +21,21 @@ IF(CMAKE_C_COMPILER_ID STREQUAL "GNU") ENDIF() -target_link_libraries(${TARGET_NAME} ${CURL_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) +GET_PROPERTY(ASAPO_COMMON_FABRIC_LIBRARIES GLOBAL PROPERTY ASAPO_COMMON_FABRIC_LIBRARIES) +target_link_libraries(${TARGET_NAME} ${CURL_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT} + asapo-fabric ${ASAPO_COMMON_FABRIC_LIBRARIES}) ################################ # Testing ################################ -set(TEST_SOURCE_FILES unittests/test_consumer_api.cpp - unittests/test_server_broker.cpp - unittests/test_tcp_client.cpp - unittests/test_tcp_connection_pool.cpp - ) +set(TEST_SOURCE_FILES + unittests/test_consumer_api.cpp + unittests/test_server_broker.cpp + unittests/test_tcp_client.cpp + unittests/test_tcp_connection_pool.cpp + unittests/test_fabric_consumer_client.cpp + ) set(TEST_LIBRARIES "${TARGET_NAME}") diff --git a/consumer/api/cpp/src/fabric_client.cpp b/consumer/api/cpp/src/fabric_consumer_client.cpp similarity index 58% rename from consumer/api/cpp/src/fabric_client.cpp rename to consumer/api/cpp/src/fabric_consumer_client.cpp index dde6e0dd836389b79de86bf2f26418f544347414..50c9eee6c18f192f58f42f2c34b43a68886756a5 100644 --- a/consumer/api/cpp/src/fabric_client.cpp +++ b/consumer/api/cpp/src/fabric_consumer_client.cpp @@ -1,14 +1,15 @@ #include <common/networking.h> #include <io/io_factory.h> -#include "fabric_client.h" +#include <iostream> +#include "fabric_consumer_client.h" using namespace asapo; -FabricClient::FabricClient(): factory__(fabric::GenerateDefaultFabricFactory()), io__{GenerateDefaultIO()} { +FabricConsumerClient::FabricConsumerClient(): factory__(fabric::GenerateDefaultFabricFactory()) { } -Error FabricClient::GetData(const FileInfo* info, FileData* data) { +Error FabricConsumerClient::GetData(const FileInfo* info, FileData* data) { Error err; if (!client__) { client__ = factory__->CreateClient(&err); @@ -22,25 +23,19 @@ Error FabricClient::GetData(const FileInfo* info, FileData* data) { return err; } + FileData tempData{new uint8_t[info->size]}; + /* MemoryRegion will be released when out of scope */ - auto mr = client__->ShareMemoryRegion(data->get(), info->size, &err); + auto mr = client__->ShareMemoryRegion(tempData.get(), info->size, &err); if (err) { return err; } GenericRequestHeader request_header{kOpcodeGetBufferData, info->buf_id, info->size}; memcpy(request_header.message, mr->GetDetails(), sizeof(fabric::MemoryRegionDetails)); - - auto currentMessageId = global_message_id_++; - client__->Send(address, currentMessageId, &request_header, sizeof(request_header), &err); - if (err) { - return err; - } - - /* The server is sending us the data over RDMA, and then sending us a confirmation */ - GenericNetworkResponse response{}; - client__->Recv(address, currentMessageId, &response, sizeof(response), &err); + + PerformNetworkTransfer(address, &request_header, &response, &err); if (err) { return err; } @@ -49,10 +44,12 @@ Error FabricClient::GetData(const FileInfo* info, FileData* data) { return TextError("Response NetworkErrorCode " + std::to_string(response.error_code)); } + data->swap(tempData); + return nullptr; } -fabric::FabricAddress FabricClient::GetAddressOrConnect(const FileInfo* info, Error* error) { +fabric::FabricAddress FabricConsumerClient::GetAddressOrConnect(const FileInfo* info, Error* error) { std::lock_guard<std::mutex> lock(mutex_); auto tableEntry = known_addresses_.find(info->source); @@ -67,3 +64,18 @@ fabric::FabricAddress FabricClient::GetAddressOrConnect(const FileInfo* info, Er return tableEntry->second; } } + +void FabricConsumerClient::PerformNetworkTransfer(fabric::FabricAddress address, + const GenericRequestHeader* request_header, + GenericNetworkResponse* response, Error* err) { + auto currentMessageId = global_message_id_++; + client__->Send(address, currentMessageId, request_header, sizeof(*request_header), err); + if (*err) { + return; + } + + /* The server is sending us the data over RDMA, and then sending us a confirmation */ + + client__->Recv(address, currentMessageId, response, sizeof(*response), err); + // if (*err) ... +} diff --git a/consumer/api/cpp/src/fabric_client.h b/consumer/api/cpp/src/fabric_consumer_client.h similarity index 75% rename from consumer/api/cpp/src/fabric_client.h rename to consumer/api/cpp/src/fabric_consumer_client.h index bb71ba42b4110c6b227f8ce3eb1d24b5ed64f228..1f98c591323c64b1213fe3bf4e7adea8b15e0f53 100644 --- a/consumer/api/cpp/src/fabric_client.h +++ b/consumer/api/cpp/src/fabric_consumer_client.h @@ -10,13 +10,12 @@ namespace asapo { -class FabricClient : NetClient { +class FabricConsumerClient : public NetClient { public: - explicit FabricClient(); + explicit FabricConsumerClient(); // modified in testings to mock system calls, otherwise do not touch std::unique_ptr<asapo::fabric::FabricFactory> factory__; - std::unique_ptr<IO> io__; std::unique_ptr<fabric::FabricClient> client__; private: @@ -29,6 +28,8 @@ class FabricClient : NetClient { private: fabric::FabricAddress GetAddressOrConnect(const FileInfo* info, Error* error); + void PerformNetworkTransfer(fabric::FabricAddress address, const GenericRequestHeader* request_header, + GenericNetworkResponse* response, Error* err); }; } diff --git a/consumer/api/cpp/unittests/test_fabric_consumer_client.cpp b/consumer/api/cpp/unittests/test_fabric_consumer_client.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0624c879749bed3830cb69d4277cfc661e1958ec --- /dev/null +++ b/consumer/api/cpp/unittests/test_fabric_consumer_client.cpp @@ -0,0 +1,299 @@ +#include <gtest/gtest.h> +#include <gmock/gmock.h> +#include <unittests/MockFabric.h> +#include <common/networking.h> +#include "../src/fabric_consumer_client.h" +#include "../../../../common/cpp/src/system_io/system_io.h" + +using namespace asapo; + +using ::testing::Test; +using ::testing::Ne; +using ::testing::Eq; +using ::testing::_; +using ::testing::SetArgPointee; +using ::testing::Return; +using ::testing::DoAll; +using ::testing::SaveArg; +using ::testing::SaveArgPointee; +using ::testing::StrictMock; +using ::testing::Expectation; + +TEST(FabricConsumerClient, Constructor) { + FabricConsumerClient client; + ASSERT_THAT(dynamic_cast<fabric::FabricFactory*>(client.factory__.get()), Ne(nullptr)); + ASSERT_THAT(dynamic_cast<fabric::FabricClient*>(client.client__.get()), Eq(nullptr)); +} + +MATCHER_P6(M_CheckSendDataRequest, op_code, buf_id, data_size, mr_addr, mr_length, mr_key, + "Checks if a valid GenericRequestHeader was Send") { + auto data = (GenericRequestHeader*) arg; + auto mr = (fabric::MemoryRegionDetails*) &data->message; + return data->op_code == op_code + && data->data_id == uint64_t(buf_id) + && data->data_size == uint64_t(data_size) + && mr->addr == uint64_t(mr_addr) + && mr->length == uint64_t(mr_length) + && mr->key == uint64_t(mr_key); +} + +ACTION_P(A_WriteSendDataResponse, error_code) { + ((asapo::SendDataResponse*)arg2)->op_code = asapo::kOpcodeGetBufferData; + ((asapo::SendDataResponse*)arg2)->error_code = error_code; +} + +class FabricConsumerClientTests : public Test { + public: + FabricConsumerClient client; + StrictMock<fabric::MockFabricFactory> mock_fabric_factory; + StrictMock<fabric::MockFabricClient> mock_fabric_client; + + void SetUp() override { + client.factory__ = std::unique_ptr<fabric::FabricFactory> {&mock_fabric_factory}; + } + void TearDown() override { + client.factory__.release(); + client.client__.release(); + } + + public: + void ExpectInit(bool ok); + void ExpectAddedConnection(const std::string& address, bool ok, fabric::FabricAddress result); + void ExpectTransfer(void** outputData, fabric::FabricAddress serverAddr, + fabric::FabricMessageId messageId, bool sendOk, bool recvOk, + NetworkErrorCode serverResponse); +}; + +void FabricConsumerClientTests::ExpectInit(bool ok) { + EXPECT_CALL(mock_fabric_factory, CreateClient_t(_/*err*/)) + .WillOnce(DoAll( + SetArgPointee<0>(ok ? nullptr : fabric::FabricErrorTemplates::kInternalError.Generate().release()), + Return(&mock_fabric_client) + )); +} + +void FabricConsumerClientTests::ExpectAddedConnection(const std::string& address, bool ok, + fabric::FabricAddress result) { + EXPECT_CALL(mock_fabric_client, AddServerAddress_t(address, _/*err*/)) + .WillOnce(DoAll( + SetArgPointee<1>(ok ? nullptr : fabric::FabricErrorTemplates::kInternalError.Generate().release()), + Return(result) + )); +} + +void FabricConsumerClientTests::ExpectTransfer(void** outputData, fabric::FabricAddress serverAddr, + fabric::FabricMessageId messageId, bool sendOk, bool recvOk, + NetworkErrorCode serverResponse) { + static fabric::MemoryRegionDetails mrDetails{}; + mrDetails.addr = 0x124; + mrDetails.length = 4123; + mrDetails.key = 20; + + auto mr = new StrictMock<fabric::MockFabricMemoryRegion>(); + EXPECT_CALL(mock_fabric_client, ShareMemoryRegion_t(_, 4123, _/*err*/)).WillOnce(DoAll( + SaveArg<0>(outputData), + Return(mr) + )); + Expectation getDetailsCall = EXPECT_CALL(*mr, GetDetails()).WillOnce(Return(&mrDetails)); + + + Expectation sendCall = EXPECT_CALL(mock_fabric_client, Send_t(serverAddr, messageId, + M_CheckSendDataRequest(kOpcodeGetBufferData, 78954, 4123, 0x124, 4123, 20), + sizeof(GenericRequestHeader), _)).After(getDetailsCall) + .WillOnce(SetArgPointee<4>(sendOk ? nullptr : fabric::FabricErrorTemplates::kInternalError.Generate().release())); + + if (sendOk) { + Expectation recvCall = EXPECT_CALL(mock_fabric_client, Recv_t(serverAddr, messageId, _, + sizeof(GenericNetworkResponse), _)) + .After(sendCall) + .WillOnce(DoAll( + SetArgPointee<4>(recvOk ? nullptr : fabric::FabricErrorTemplates::kInternalError.Generate().release()), + A_WriteSendDataResponse(serverResponse) + )); + EXPECT_CALL(*mr, Destructor()).After(recvCall); + } else { + EXPECT_CALL(*mr, Destructor()).After(sendCall); + } + +} + +TEST_F(FabricConsumerClientTests, GetData_Error_Init) { + ExpectInit(false); + + FileData expectedFileData; + FileInfo expectedInfo{}; + expectedInfo.source = "host:1234"; + Error err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Eq(fabric::FabricErrorTemplates::kInternalError)); +} + +TEST_F(FabricConsumerClientTests, GetData_Error_AddConnection) { + ExpectInit(true); + ExpectAddedConnection("host:1234", false, -1); + + FileData expectedFileData; + FileInfo expectedInfo{}; + expectedInfo.source = "host:1234"; + Error err = client.GetData(&expectedInfo, &expectedFileData); + ASSERT_THAT(err, Eq(fabric::FabricErrorTemplates::kInternalError)); + + // Make sure that the connection was not saved + ExpectAddedConnection("host:1234", false, -1); + err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Eq(fabric::FabricErrorTemplates::kInternalError)); +} + +TEST_F(FabricConsumerClientTests, GetData_ShareMemoryRegion_Error) { + ExpectInit(true); + ExpectAddedConnection("host:1234", true, 0); + + FileData expectedFileData; + FileInfo expectedInfo{}; + expectedInfo.source = "host:1234"; + expectedInfo.size = 4123; + + EXPECT_CALL(mock_fabric_client, ShareMemoryRegion_t(_, 4123, _/*err*/)) + .WillOnce(DoAll( + SetArgPointee<2>(fabric::FabricErrorTemplates::kInternalError.Generate().release()), + Return(nullptr) + )); + + Error err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Eq(fabric::FabricErrorTemplates::kInternalError)); +} + +TEST_F(FabricConsumerClientTests, GetData_SendFailed) { + ExpectInit(true); + ExpectAddedConnection("host:1234", true, 0); + + FileData expectedFileData; + FileInfo expectedInfo{}; + expectedInfo.source = "host:1234"; + expectedInfo.size = 4123; + expectedInfo.buf_id = 78954; + + void* outData = nullptr; + ExpectTransfer(&outData, 0, 0, false, false, kNetErrorNoError); + + Error err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Ne(nullptr)); + ASSERT_THAT(expectedFileData.get(), Eq(nullptr)); +} + +TEST_F(FabricConsumerClientTests, GetData_RecvFailed) { + ExpectInit(true); + ExpectAddedConnection("host:1234", true, 0); + + FileData expectedFileData; + FileInfo expectedInfo{}; + expectedInfo.source = "host:1234"; + expectedInfo.size = 4123; + expectedInfo.buf_id = 78954; + + void* outData = nullptr; + ExpectTransfer(&outData, 0, 0, true, false, kNetErrorNoError); + + Error err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Ne(nullptr)); + ASSERT_THAT(expectedFileData.get(), Eq(nullptr)); +} + +TEST_F(FabricConsumerClientTests, GetData_ServerError) { + ExpectInit(true); + ExpectAddedConnection("host:1234", true, 0); + + FileData expectedFileData; + FileInfo expectedInfo{}; + expectedInfo.source = "host:1234"; + expectedInfo.size = 4123; + expectedInfo.buf_id = 78954; + + void* outData = nullptr; + ExpectTransfer(&outData, 0, 0, true, true, kNetErrorInternalServerError); + + Error err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Ne(nullptr)); + ASSERT_THAT(expectedFileData.get(), Eq(nullptr)); +} + +TEST_F(FabricConsumerClientTests, GetData_Ok) { + ExpectInit(true); + ExpectAddedConnection("host:1234", true, 0); + + FileData expectedFileData; + FileInfo expectedInfo{}; + expectedInfo.source = "host:1234"; + expectedInfo.size = 4123; + expectedInfo.buf_id = 78954; + + void* outData = nullptr; + ExpectTransfer(&outData, 0, 0, true, true, kNetErrorNoError); + + Error err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Eq(nullptr)); + ASSERT_THAT(expectedFileData.get(), Eq(outData)); +} + +TEST_F(FabricConsumerClientTests, GetData_Ok_UsedCahedConnection) { + ExpectInit(true); + ExpectAddedConnection("host:1234", true, 0); + + FileData expectedFileData; + FileInfo expectedInfo{}; + expectedInfo.source = "host:1234"; + expectedInfo.size = 4123; + expectedInfo.buf_id = 78954; + + void* outData = nullptr; + ExpectTransfer(&outData, 0, 0, true, true, kNetErrorNoError); + + Error err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Eq(nullptr)); + ASSERT_THAT(expectedFileData.get(), Eq(outData)); + + outData = nullptr; + ExpectTransfer(&outData, 0, 1, true, true, kNetErrorNoError); + + err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Eq(nullptr)); + ASSERT_THAT(expectedFileData.get(), Eq(outData)); +} + +TEST_F(FabricConsumerClientTests, GetData_Ok_SecondConnection) { + ExpectInit(true); + ExpectAddedConnection("host:1234", true, 0); + + FileData expectedFileData; + FileInfo expectedInfo{}; + expectedInfo.source = "host:1234"; + expectedInfo.size = 4123; + expectedInfo.buf_id = 78954; + + void* outData = nullptr; + ExpectTransfer(&outData, 0, 0, true, true, kNetErrorNoError); + + Error err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Eq(nullptr)); + ASSERT_THAT(expectedFileData.get(), Eq(outData)); + + ExpectAddedConnection("host:1235", true, 54); + expectedInfo.source = "host:1235"; + + outData = nullptr; + ExpectTransfer(&outData, 54, 1, true, true, kNetErrorNoError); + + err = client.GetData(&expectedInfo, &expectedFileData); + + ASSERT_THAT(err, Eq(nullptr)); + ASSERT_THAT(expectedFileData.get(), Eq(outData)); +} diff --git a/receiver/src/receiver_data_server/net_server/rds_fabric_server.h b/receiver/src/receiver_data_server/net_server/rds_fabric_server.h index 74e30dd39ab0394a4d0bd4b9e950deaf587c465b..04edcb0f2f4804e630441b691fea77568af7f4ed 100644 --- a/receiver/src/receiver_data_server/net_server/rds_fabric_server.h +++ b/receiver/src/receiver_data_server/net_server/rds_fabric_server.h @@ -12,7 +12,7 @@ class RdsFabricServer : public RdsNetServer { ~RdsFabricServer() override; // modified in testings to mock system calls, otherwise do not touch - std::unique_ptr<asapo::fabric::FabricFactory> factory__; + std::unique_ptr<fabric::FabricFactory> factory__; std::unique_ptr<IO> io__; const AbstractLogger* log__; std::unique_ptr<fabric::FabricServer> server__; diff --git a/receiver/unittests/receiver_data_server/net_server/test_rds_fabric_server.cpp b/receiver/unittests/receiver_data_server/net_server/test_rds_fabric_server.cpp index 699ed65d98e6d812a835ee9593ca15f4c05ed6f6..c7580823a0455a0def0760fda4f1e0761446f9b1 100644 --- a/receiver/unittests/receiver_data_server/net_server/test_rds_fabric_server.cpp +++ b/receiver/unittests/receiver_data_server/net_server/test_rds_fabric_server.cpp @@ -6,11 +6,13 @@ #include <unittests/MockFabric.h> #include "../../../src/receiver_data_server/net_server/rds_fabric_server.h" #include "../../../src/receiver_data_server/net_server/fabric_rds_request.h" +#include "../../../../common/cpp/src/system_io/system_io.h" using ::testing::Ne; using ::testing::Eq; using ::testing::Test; using ::testing::NiceMock; +using ::testing::StrictMock; using ::testing::DoAll; using ::testing::SetArgPointee; using ::testing::Return; @@ -22,18 +24,18 @@ std::string expected_address = "somehost:123"; TEST(RdsFabricServer, Constructor) { RdsFabricServer fabric_server(""); - ASSERT_THAT(dynamic_cast<asapo::IO*>(fabric_server.io__.get()), Ne(nullptr)); - ASSERT_THAT(dynamic_cast<const asapo::AbstractLogger*>(fabric_server.log__), Ne(nullptr)); + ASSERT_THAT(dynamic_cast<SystemIO*>(fabric_server.io__.get()), Ne(nullptr)); + ASSERT_THAT(dynamic_cast<fabric::FabricFactory*>(fabric_server.factory__.get()), Ne(nullptr)); + ASSERT_THAT(dynamic_cast<const AbstractLogger*>(fabric_server.log__), Ne(nullptr)); } - class RdsFabricServerTests : public Test { public: RdsFabricServer rds_server{expected_address}; NiceMock<MockLogger> mock_logger; - MockIO mock_io; - fabric::MockFabricFactory mock_fabric_factory; - fabric::MockFabricServer mock_fabric_server; + StrictMock<MockIO> mock_io; + StrictMock<fabric::MockFabricFactory> mock_fabric_factory; + StrictMock<fabric::MockFabricServer> mock_fabric_server; void SetUp() override { rds_server.log__ = &mock_logger; @@ -137,7 +139,6 @@ TEST_F(RdsFabricServerTests, GetNewRequests_Ok) { ASSERT_THAT(req->GetMemoryRegion()->key, Eq(23)); } - TEST_F(RdsFabricServerTests, GetNewRequests_Error_RecvAny_InternalError) { InitServer();