From 3e2a612fb3ce235b80d5ab4321a6db10a196174f Mon Sep 17 00:00:00 2001
From: Sergey Yakubov <sergey.yakubov@desy.de>
Date: Mon, 3 May 2021 14:52:36 +0200
Subject: [PATCH] fix segfault with callback as class member in
 send_stream_finished_flag

---
 common/cpp/include/asapo/request/request.h    |   3 +
 common/cpp/src/request/request_pool.cpp       |  21 +++-
 .../unittests/request/test_request_pool.cpp   | 106 ++++++++++--------
 producer/api/cpp/src/producer_request.h       |   3 +
 producer/api/python/asapo_producer.pyx.in     |   2 +
 .../producer/python_api/check_linux.sh        |   8 +-
 .../producer/python_api/check_windows.bat     |   6 +-
 .../producer/python_api/producer_api.py       |  15 ++-
 8 files changed, 104 insertions(+), 60 deletions(-)

diff --git a/common/cpp/include/asapo/request/request.h b/common/cpp/include/asapo/request/request.h
index 594c03c85..83774bc6e 100644
--- a/common/cpp/include/asapo/request/request.h
+++ b/common/cpp/include/asapo/request/request.h
@@ -15,6 +15,9 @@ class GenericRequest {
     GenericRequest(GenericRequestHeader h, uint64_t timeout_ms): header{std::move(h)}, timeout_ms_{timeout_ms} {};
     GenericRequestHeader header;
     virtual ~GenericRequest() = default;
+    virtual bool ContainsData() {
+        return true;
+    };
     uint64_t GetRetryCounter() {
         return retry_counter_;
     }
diff --git a/common/cpp/src/request/request_pool.cpp b/common/cpp/src/request/request_pool.cpp
index 5df251934..dc2573b67 100644
--- a/common/cpp/src/request/request_pool.cpp
+++ b/common/cpp/src/request/request_pool.cpp
@@ -28,10 +28,11 @@ Error RequestPool::CanAddRequests(const GenericRequests &requests) {
 
     uint64_t total_size = 0;
     for (auto &request : requests) {
-        total_size += request->header.data_size;
+        if (request->ContainsData()) {
+            total_size += request->header.data_size;
+        }
     }
 
-
     if (memory_used_ + total_size > limits_.max_memory_mb * 1000000) {
         return IOErrorTemplates::kNoSpaceLeft.Generate(
             "reached maximum memory capacity of " + std::to_string(limits_.max_memory_mb) + " MB");
@@ -49,6 +50,10 @@ Error RequestPool::CanAddRequest(const GenericRequestPtr &request, bool top_prio
             "reached maximum number of " + std::to_string(limits_.max_requests) + " requests");
     }
 
+    if (!request->ContainsData()) {
+        return nullptr;
+    }
+
     if (limits_.max_memory_mb > 0 && memory_used_ + request->header.data_size > limits_.max_memory_mb * 1000000) {
         return IOErrorTemplates::kNoSpaceLeft.Generate(
             "reached maximum memory capacity of " + std::to_string(limits_.max_memory_mb) + " MB");
@@ -66,7 +71,9 @@ Error RequestPool::AddRequest(GenericRequestPtr request, bool top_priority) {
         return err;
     }
 
-    memory_used_ += request->header.data_size;
+    if (request->ContainsData()) {
+        memory_used_ += request->header.data_size;
+    }
 
     if (top_priority) {
         request_queue_.emplace_front(std::move(request));
@@ -118,7 +125,9 @@ void RequestPool::ProcessRequest(const std::unique_ptr<RequestHandler> &request_
         std::this_thread::sleep_for(std::chrono::milliseconds(1000));
         thread_info->lock.lock();
     } else {
-        memory_used_ -= request->header.data_size;
+        if (request->ContainsData()) {
+            memory_used_ -= request->header.data_size;
+        }
     }
 }
 
@@ -157,7 +166,9 @@ Error RequestPool::AddRequests(GenericRequests requests) {
 
     uint64_t total_size = 0;
     for (auto &elem : requests) {
-        total_size += elem->header.data_size;
+        if (elem->ContainsData()) {
+            total_size += elem->header.data_size;
+        }
         request_queue_.emplace_front(std::move(elem));
     }
     memory_used_ += total_size;
diff --git a/common/cpp/unittests/request/test_request_pool.cpp b/common/cpp/unittests/request/test_request_pool.cpp
index ccda29729..44bfbb8c4 100644
--- a/common/cpp/unittests/request/test_request_pool.cpp
+++ b/common/cpp/unittests/request/test_request_pool.cpp
@@ -37,42 +37,45 @@ using asapo::ErrorInterface;
 using asapo::GenericRequest;
 using asapo::GenericRequestHeader;
 
-
-
 class MockRequestHandlerFactory : public asapo::RequestHandlerFactory {
-  public:
-    MockRequestHandlerFactory(RequestHandler* request_handler):
-        RequestHandlerFactory() {
-        request_handler_ = request_handler;
-    }
-    std::unique_ptr<RequestHandler> NewRequestHandler(uint64_t thread_id, uint64_t* shared_counter) override {
-        return std::unique_ptr<RequestHandler> {request_handler_};
-    }
-  private:
-    RequestHandler* request_handler_;
+ public:
+  MockRequestHandlerFactory(RequestHandler* request_handler) :
+      RequestHandlerFactory() {
+      request_handler_ = request_handler;
+  }
+  std::unique_ptr<RequestHandler> NewRequestHandler(uint64_t thread_id, uint64_t* shared_counter) override {
+      return std::unique_ptr<RequestHandler>{request_handler_};
+  }
+ private:
+  RequestHandler* request_handler_;
 };
 
 class TestRequest : public GenericRequest {
-  public:
-    TestRequest(GenericRequestHeader header, uint64_t timeout): GenericRequest(header, timeout) {};
+ public:
+  TestRequest(GenericRequestHeader header, uint64_t timeout, bool contains_data = true) : GenericRequest(header,
+                                                                                                         timeout) {
+      contains_data_ = contains_data;
+  };
+  bool ContainsData() { return contains_data_; };
+ private:
+  bool contains_data_;
 };
 
-
 class RequestPoolTests : public testing::Test {
-  public:
-    NiceMock<MockRequestHandler>* mock_request_handler = new testing::NiceMock<MockRequestHandler>;
-    NiceMock<asapo::MockLogger> mock_logger;
-    MockRequestHandlerFactory request_handler_factory{mock_request_handler};
-    const uint8_t nthreads = 1;
-    asapo::RequestPool pool {nthreads, &request_handler_factory, &mock_logger};
-    std::unique_ptr<GenericRequest> request{new TestRequest{GenericRequestHeader{asapo::kOpcodeUnknownOp,0,1000000}, 0}};
-    void SetUp() override {
-    }
-    void TearDown() override {
-    }
+ public:
+  NiceMock<MockRequestHandler>* mock_request_handler = new testing::NiceMock<MockRequestHandler>;
+  NiceMock<asapo::MockLogger> mock_logger;
+  MockRequestHandlerFactory request_handler_factory{mock_request_handler};
+  const uint8_t nthreads = 1;
+  asapo::RequestPool pool{nthreads, &request_handler_factory, &mock_logger};
+  std::unique_ptr<GenericRequest>
+      request{new TestRequest{GenericRequestHeader{asapo::kOpcodeUnknownOp, 0, 1000000}, 0}};
+  void SetUp() override {
+  }
+  void TearDown() override {
+  }
 };
 
-
 TEST(RequestPool, Constructor) {
     NiceMock<asapo::MockLogger> mock_logger;
     MockRequestHandlerFactory factory(nullptr);
@@ -118,9 +121,9 @@ void ExpectSend(MockRequestHandler* mock_handler, int ntimes = 1) {
     EXPECT_CALL(*mock_handler, ReadyProcessRequest()).Times(ntimes).WillRepeatedly(Return(true));
     EXPECT_CALL(*mock_handler, PrepareProcessingRequestLocked()).Times(ntimes);
     EXPECT_CALL(*mock_handler, ProcessRequestUnlocked_t(_, _)).Times(ntimes).WillRepeatedly(
-        DoAll(            testing::SetArgPointee<1>(false),
-                          Return(true)
-             ));
+        DoAll(testing::SetArgPointee<1>(false),
+              Return(true)
+        ));
     EXPECT_CALL(*mock_handler, TearDownProcessingRequestLocked(true)).Times(ntimes);
 }
 
@@ -128,14 +131,12 @@ void ExpectFailProcessRequest(MockRequestHandler* mock_handler) {
     EXPECT_CALL(*mock_handler, ReadyProcessRequest()).Times(AtLeast(1)).WillRepeatedly(Return(true));
     EXPECT_CALL(*mock_handler, PrepareProcessingRequestLocked()).Times(AtLeast(1));
     EXPECT_CALL(*mock_handler, ProcessRequestUnlocked_t(_, _)).Times(AtLeast(1)).WillRepeatedly(
-        DoAll(            testing::SetArgPointee<1>(true),
-                          Return(false)
-             ));
+        DoAll(testing::SetArgPointee<1>(true),
+              Return(false)
+        ));
     EXPECT_CALL(*mock_handler, TearDownProcessingRequestLocked(false)).Times(AtLeast(1));
 }
 
-
-
 TEST_F(RequestPoolTests, AddRequestIncreasesRetryCounter) {
 
     ExpectFailProcessRequest(mock_request_handler);
@@ -147,7 +148,6 @@ TEST_F(RequestPoolTests, AddRequestIncreasesRetryCounter) {
     ASSERT_THAT(mock_request_handler->retry_counter, Gt(0));
 }
 
-
 TEST_F(RequestPoolTests, AddRequestCallsSend) {
 
     ExpectSend(mock_request_handler);
@@ -182,7 +182,6 @@ TEST_F(RequestPoolTests, NRequestsInPoolAccountsForRequestsInProgress) {
     ASSERT_THAT(nreq2, Eq(0));
 }
 
-
 TEST_F(RequestPoolTests, AddRequestCallsSendTwoRequests) {
 
     TestRequest* request2 = new TestRequest{GenericRequestHeader{}, 0};
@@ -201,7 +200,7 @@ TEST_F(RequestPoolTests, AddRequestCallsSendTwoRequests) {
 TEST_F(RequestPoolTests, RefuseAddRequestIfHitSizeLimitation) {
     TestRequest* request2 = new TestRequest{GenericRequestHeader{}, 0};
 
-    pool.SetLimits(asapo::RequestPoolLimits({1,0}));
+    pool.SetLimits(asapo::RequestPoolLimits({1, 0}));
     pool.AddRequest(std::move(request));
     request.reset(request2);
     auto err = pool.AddRequest(std::move(request));
@@ -219,7 +218,7 @@ TEST_F(RequestPoolTests, RefuseAddRequestIfHitMemoryLimitation) {
     header.data_size = 100;
     TestRequest* request2 = new TestRequest{header, 0};
 
-    pool.SetLimits(asapo::RequestPoolLimits({0,1}));
+    pool.SetLimits(asapo::RequestPoolLimits({0, 1}));
     pool.AddRequest(std::move(request));
     request.reset(request2);
     auto err = pool.AddRequest(std::move(request));
@@ -233,15 +232,31 @@ TEST_F(RequestPoolTests, RefuseAddRequestIfHitMemoryLimitation) {
 
 }
 
+TEST_F(RequestPoolTests, OkAddRequestIfSendingFile) {
+    auto header = GenericRequestHeader{};
+    header.data_size = 100;
+    TestRequest* request2 = new TestRequest{header, 0, false};
+
+    pool.SetLimits(asapo::RequestPoolLimits({0, 1}));
+    pool.AddRequest(std::move(request));
+    request.reset(request2);
+    auto err = pool.AddRequest(std::move(request));
+
+    auto nreq = pool.NRequestsInPool();
+
+    ASSERT_THAT(nreq, Eq(2));
+    ASSERT_THAT(err, Eq(nullptr));
+}
+
 TEST_F(RequestPoolTests, RefuseAddRequestsIfHitSizeLimitation) {
 
     TestRequest* request2 = new TestRequest{GenericRequestHeader{}, 0};
 
     std::vector<std::unique_ptr<GenericRequest>> requests;
     requests.push_back(std::move(request));
-    requests.push_back(std::unique_ptr<GenericRequest> {request2});
+    requests.push_back(std::unique_ptr<GenericRequest>{request2});
 
-    pool.SetLimits(asapo::RequestPoolLimits({1,0}));
+    pool.SetLimits(asapo::RequestPoolLimits({1, 0}));
     auto err = pool.AddRequests(std::move(requests));
     auto nreq = pool.NRequestsInPool();
 
@@ -249,7 +264,6 @@ TEST_F(RequestPoolTests, RefuseAddRequestsIfHitSizeLimitation) {
     ASSERT_THAT(err, Eq(asapo::IOErrorTemplates::kNoSpaceLeft));
 }
 
-
 TEST_F(RequestPoolTests, RefuseAddRequestsIfHitMemoryLimitation) {
 
     auto header = GenericRequestHeader{};
@@ -259,9 +273,9 @@ TEST_F(RequestPoolTests, RefuseAddRequestsIfHitMemoryLimitation) {
 
     std::vector<std::unique_ptr<GenericRequest>> requests;
     requests.push_back(std::move(request));
-    requests.push_back(std::unique_ptr<GenericRequest> {request2});
+    requests.push_back(std::unique_ptr<GenericRequest>{request2});
 
-    pool.SetLimits(asapo::RequestPoolLimits({0,1}));
+    pool.SetLimits(asapo::RequestPoolLimits({0, 1}));
     auto err = pool.AddRequests(std::move(requests));
     auto nreq = pool.NRequestsInPool();
 
@@ -269,7 +283,6 @@ TEST_F(RequestPoolTests, RefuseAddRequestsIfHitMemoryLimitation) {
     ASSERT_THAT(err, Eq(asapo::IOErrorTemplates::kNoSpaceLeft));
 }
 
-
 TEST_F(RequestPoolTests, AddRequestsOk) {
 
     TestRequest* request2 = new TestRequest{GenericRequestHeader{}, 0};
@@ -278,7 +291,7 @@ TEST_F(RequestPoolTests, AddRequestsOk) {
 
     std::vector<std::unique_ptr<GenericRequest>> requests;
     requests.push_back(std::move(request));
-    requests.push_back(std::unique_ptr<GenericRequest> {request2});
+    requests.push_back(std::unique_ptr<GenericRequest>{request2});
 
     auto err = pool.AddRequests(std::move(requests));
 
@@ -328,5 +341,4 @@ TEST_F(RequestPoolTests, StopThreads) {
     Mock::VerifyAndClearExpectations(&mock_logger);
 }
 
-
 }
diff --git a/producer/api/cpp/src/producer_request.h b/producer/api/cpp/src/producer_request.h
index 4d0b73fc1..f15d25103 100644
--- a/producer/api/cpp/src/producer_request.h
+++ b/producer/api/cpp/src/producer_request.h
@@ -18,6 +18,9 @@ class ProducerRequest : public GenericRequest {
                     RequestCallback callback,
                     bool manage_data_memory,
                     uint64_t timeout_ms);
+    virtual bool ContainsData() override {
+      return !DataFromFile();
+    };
     std::string source_credentials;
     std::string metadata;
     MessageData data;
diff --git a/producer/api/python/asapo_producer.pyx.in b/producer/api/python/asapo_producer.pyx.in
index 5391c8ce8..3b3609cc4 100644
--- a/producer/api/python/asapo_producer.pyx.in
+++ b/producer/api/python/asapo_producer.pyx.in
@@ -220,6 +220,8 @@ cdef class PyProducer:
         unwrap_callback(<RequestCallbackCython>self.c_callback, <void*>self,<void*>callback if callback != None else NULL))
         if err:
             throw_exception(err)
+        if callback != None:
+            Py_XINCREF(<PyObject*>callback)
 
     def stream_info(self, stream = 'default', uint64_t timeout_ms = 1000):
         """
diff --git a/tests/automatic/producer/python_api/check_linux.sh b/tests/automatic/producer/python_api/check_linux.sh
index 41564e5e5..957e6ff0d 100644
--- a/tests/automatic/producer/python_api/check_linux.sh
+++ b/tests/automatic/producer/python_api/check_linux.sh
@@ -41,12 +41,12 @@ sleep 10
 
 $1 $3 $data_source $beamtime_id  "127.0.0.1:8400" &> out || cat out
 cat out
-echo count successfully send, expect 13
-cat out | grep "successfuly sent" | wc -l | tee /dev/stderr | grep 13
+echo count successfully send, expect 15
+cat out | grep "successfuly sent" | wc -l | tee /dev/stderr | grep 15
 echo count same id, expect 4
 cat out | grep "already have record with same id" | wc -l | tee /dev/stderr | grep 4
-echo count duplicates, expect 4
-cat out | grep "duplicate" | wc -l | tee /dev/stderr | grep 4
+echo count duplicates, expect 6
+cat out | grep "duplicate" | wc -l | tee /dev/stderr | grep 6
 echo count data in callback, expect 3
 cat out | grep "'data':" | wc -l  | tee /dev/stderr | grep 3
 echo check found local io error
diff --git a/tests/automatic/producer/python_api/check_windows.bat b/tests/automatic/producer/python_api/check_windows.bat
index c115da36f..5d679b0a5 100644
--- a/tests/automatic/producer/python_api/check_windows.bat
+++ b/tests/automatic/producer/python_api/check_windows.bat
@@ -22,16 +22,16 @@ set PYTHONPATH=%2
 type out
 set NUM=0
 for /F %%N in ('find /C "successfuly sent" ^< "out"') do set NUM=%%N
-echo %NUM% | findstr 13 || goto error
+echo %NUM% | findstr 15 || goto error
 
 for /F %%N in ('find /C "} wrong input: Bad request: already have record with same id" ^< "out"') do set NUM=%%N
 echo %NUM% | findstr 2 || goto error
 
 for /F %%N in ('find /C "} server warning: ignoring duplicate record" ^< "out"') do set NUM=%%N
-echo %NUM% | findstr 1 || goto error
+echo %NUM% | findstr 2 || goto error
 
 for /F %%N in ('find /C "} server warning: duplicated request" ^< "out"') do set NUM=%%N
-echo %NUM% | findstr 1 || goto error
+echo %NUM% | findstr 2 || goto error
 
 
 findstr /I /L /C:"Finished successfully" out || goto :error
diff --git a/tests/automatic/producer/python_api/producer_api.py b/tests/automatic/producer/python_api/producer_api.py
index 2b4198650..879053cd9 100644
--- a/tests/automatic/producer/python_api/producer_api.py
+++ b/tests/automatic/producer/python_api/producer_api.py
@@ -25,6 +25,11 @@ def assert_eq(val, expected, name):
         print('val: ', val, ' expected: ', expected)
         sys.exit(1)
 
+class CallBackClass:
+    def callback(self, payload, err):
+        callback(payload,err)
+
+callback_object = CallBackClass()
 
 def callback(payload, err):
     lock.acquire()  # to print
@@ -159,6 +164,13 @@ else:
     print("should be AsapoRequestsPoolIsFull error ")
     sys.exit(1)
 
+#stream_finished
+producer.wait_requests_finished(10000)
+producer.send_stream_finished_flag("stream", 2, next_stream = "next_stream", callback = callback)
+# check callback_object.callback works, will be duplicated request
+producer.send_stream_finished_flag("stream", 2, next_stream = "next_stream", callback = callback_object.callback)
+producer.wait_requests_finished(10000)
+
 
 #stream infos
 info = producer.stream_info()
@@ -172,7 +184,8 @@ print("created: ",datetime.utcfromtimestamp(info['timestampCreated']/1000000000)
 print("last record: ",datetime.utcfromtimestamp(info['timestampLast']/1000000000).strftime('%Y-%m-%d %H:%M:%S.%f'))
 
 info = producer.stream_info('stream')
-assert_eq(info['lastId'], 2, "last id from different stream")
+assert_eq(info['lastId'], 3, "last id from different stream")
+assert_eq(info['finished'], True, "stream finished")
 
 info_last = producer.last_stream()
 assert_eq(info_last['name'], "stream", "last stream")
-- 
GitLab