From b4d89302e4f1999603d1a312d087302e65248c7d Mon Sep 17 00:00:00 2001
From: Sergey Yakubov <sergey.yakubov@desy.de>
Date: Wed, 24 Feb 2021 11:32:10 +0100
Subject: [PATCH] implement get_current_dataset_count

---
 CHANGELOG.md                                  |  3 +-
 broker/src/asapo_broker/database/mongodb.go   | 11 +++++-
 .../src/asapo_broker/database/mongodb_test.go | 32 +++++++++++++++
 .../asapo_broker/server/get_commands_test.go  |  4 +-
 broker/src/asapo_broker/server/get_size.go    |  4 +-
 .../api/cpp/include/asapo/consumer/consumer.h | 22 ++++++++---
 consumer/api/cpp/src/consumer_impl.cpp        | 39 ++++++++++++-------
 consumer/api/cpp/src/consumer_impl.h          |  8 +++-
 .../api/cpp/unittests/test_consumer_impl.cpp  | 17 ++++++++
 consumer/api/python/asapo_consumer.pxd        |  1 +
 consumer/api/python/asapo_consumer.pyx.in     | 10 +++++
 .../consumer/consumer_api/consumer_api.cpp    | 13 +++++++
 .../consumer_api_python/consumer_api.py       | 19 +++++++++
 13 files changed, 156 insertions(+), 27 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a42a03c08..45154555a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,8 @@
-## 20.12.1 (in progress)
+## 21.03.0 (in progress)
 
  IMPROVEMENTS
 * Producer API - queue limits in Python, for C++ return original data in error custom data
+* Consumer API - add GetCurrentDatasetCount/get_current_dataset_count function with option to include or exclude incomplete datasets
 
 ## 20.12.0
 
diff --git a/broker/src/asapo_broker/database/mongodb.go b/broker/src/asapo_broker/database/mongodb.go
index 6683be24c..09b17e854 100644
--- a/broker/src/asapo_broker/database/mongodb.go
+++ b/broker/src/asapo_broker/database/mongodb.go
@@ -567,9 +567,18 @@ func (db *Mongodb) getSize(request Request) ([]byte, error) {
 	c := db.client.Database(request.DbName).Collection(data_collection_name_prefix + request.DbCollectionName)
 	var rec SizeRecord
 	var err error
+	filter:=bson.M{}
+	if request.ExtraParam=="false" { // do not return incomplete datasets
+		filter = bson.M{"$expr": bson.M{"$eq": []interface{}{"$size", bson.M{"$size": "$messages"}}}}
+	} else if request.ExtraParam=="true" {
+		filter = bson.M{"$expr": bson.M{"gt": []interface{}{0, bson.M{"$size": "$messages"}}}}
+	}
 
-	size, err := c.CountDocuments(context.TODO(), bson.M{}, options.Count())
+	size, err := c.CountDocuments(context.TODO(), filter, options.Count())
 	if err != nil {
+		if ce, ok := err.(mongo.CommandError); ok && ce.Code == 17124 {
+			return nil,&DBError{utils.StatusWrongInput, "no datasets found"}
+		}
 		return nil, err
 	}
 	rec.Size = int(size)
diff --git a/broker/src/asapo_broker/database/mongodb_test.go b/broker/src/asapo_broker/database/mongodb_test.go
index dbf379375..cc6119e17 100644
--- a/broker/src/asapo_broker/database/mongodb_test.go
+++ b/broker/src/asapo_broker/database/mongodb_test.go
@@ -392,6 +392,38 @@ func TestMongoDBGetSize(t *testing.T) {
 	assert.Equal(t, string(recs1_expect), string(res))
 }
 
+func TestMongoDBGetSizeForDatasets(t *testing.T) {
+	db.Connect(dbaddress)
+	defer cleanup()
+	db.insertRecord(dbname, collection, &rec1)
+
+	_, err := db.ProcessRequest(Request{DbName: dbname, DbCollectionName: collection, Op: "size",ExtraParam: "false"})
+	assert.Equal(t, utils.StatusWrongInput, err.(*DBError).Code)
+
+	_, err1 := db.ProcessRequest(Request{DbName: dbname, DbCollectionName: collection, Op: "size",ExtraParam: "true"})
+	assert.Equal(t, utils.StatusWrongInput, err1.(*DBError).Code)
+}
+
+func TestMongoDBGetSizeDataset(t *testing.T) {
+	db.Connect(dbaddress)
+	defer cleanup()
+
+	db.insertRecord(dbname, collection, &rec_dataset1)
+	db.insertRecord(dbname, collection, &rec_dataset2_incomplete)
+
+	size2_expect, _ := json.Marshal(SizeRecord{2})
+	size1_expect, _ := json.Marshal(SizeRecord{1})
+
+	res, err := db.ProcessRequest(Request{DbName: dbname, DbCollectionName: collection, Op: "size",ExtraParam: "true"})
+	assert.Nil(t, err)
+	assert.Equal(t, string(size2_expect), string(res))
+
+	res, err = db.ProcessRequest(Request{DbName: dbname, DbCollectionName: collection, Op: "size",ExtraParam: "false"})
+	assert.Nil(t, err)
+	assert.Equal(t, string(size1_expect), string(res))
+
+}
+
 func TestMongoDBGetSizeNoRecords(t *testing.T) {
 	db.Connect(dbaddress)
 	defer cleanup()
diff --git a/broker/src/asapo_broker/server/get_commands_test.go b/broker/src/asapo_broker/server/get_commands_test.go
index e4db0514b..9870e6d2d 100644
--- a/broker/src/asapo_broker/server/get_commands_test.go
+++ b/broker/src/asapo_broker/server/get_commands_test.go
@@ -47,10 +47,10 @@ var testsGetCommand = []struct {
 	{"next", expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/next","",""},
 	{"next", expectedStream, expectedGroupID, expectedStream + "/" +
 		expectedGroupID + "/next","&resend_nacks=true&delay_ms=10000&resend_attempts=3","10000_3"},
-	{"size", expectedStream, "", expectedStream  + "/size","","0"},
+	{"size", expectedStream, "", expectedStream  + "/size","",""},
+	{"size", expectedStream, "", expectedStream  + "/size","&incomplete=true","true"},
 	{"streams", "0", "", "0/streams","",""},
 	{"lastack", expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/lastack","",""},
-
 }
 
 
diff --git a/broker/src/asapo_broker/server/get_size.go b/broker/src/asapo_broker/server/get_size.go
index fa4dd2367..1355e955b 100644
--- a/broker/src/asapo_broker/server/get_size.go
+++ b/broker/src/asapo_broker/server/get_size.go
@@ -5,5 +5,7 @@ import (
 )
 
 func routeGetSize(w http.ResponseWriter, r *http.Request) {
-	processRequest(w, r, "size", "0", false)
+	keys := r.URL.Query()
+	incomplete := keys.Get("incomplete")
+	processRequest(w, r, "size", incomplete, false)
 }
diff --git a/consumer/api/cpp/include/asapo/consumer/consumer.h b/consumer/api/cpp/include/asapo/consumer/consumer.h
index 678c7179b..02dcb68e5 100644
--- a/consumer/api/cpp/include/asapo/consumer/consumer.h
+++ b/consumer/api/cpp/include/asapo/consumer/consumer.h
@@ -76,7 +76,7 @@ class Consumer {
     //! Get list of streams, set from to "" to get all streams
     virtual StreamInfos GetStreamList(std::string from, Error* err) = 0;
 
-    //! Get current number of datasets
+    //! Get current number of messages in stream
     /*!
       \param stream - stream to use
       \param err - return nullptr of operation succeed, error otherwise.
@@ -84,11 +84,21 @@ class Consumer {
     */
     virtual uint64_t GetCurrentSize(std::string stream, Error* err) = 0;
 
-    //! Generate new GroupID.
-    /*!
-      \param err - return nullptr of operation succeed, error otherwise.
-      \return group ID.
-    */
+  //! Get current number of datasets in stream
+  /*!
+    \param stream - stream to use
+    \param include_incomplete - flag to count incomplete datasets as well
+    \param err - return nullptr of operation succeed, error otherwise.
+    \return number of datasets.
+  */
+    virtual uint64_t GetCurrentDatasetCount(std::string stream, bool include_incomplete, Error* err) = 0;
+
+  //! Generate new GroupID.
+  /*!
+    \param err - return nullptr of operation succeed, error otherwise.
+    \return group ID.
+  */
+
     virtual std::string GenerateNewGroupId(Error* err) = 0;
 
     //! Get Beamtime metadata.
diff --git a/consumer/api/cpp/src/consumer_impl.cpp b/consumer/api/cpp/src/consumer_impl.cpp
index 7a94ecf6e..4386dd9b1 100644
--- a/consumer/api/cpp/src/consumer_impl.cpp
+++ b/consumer/api/cpp/src/consumer_impl.cpp
@@ -529,20 +529,7 @@ Error ConsumerImpl::SetLastReadMarker(std::string group_id, uint64_t value, std:
 }
 
 uint64_t ConsumerImpl::GetCurrentSize(std::string stream, Error* err) {
-    RequestInfo ri;
-    ri.api = "/database/" + source_credentials_.beamtime_id + "/" + source_credentials_.data_source +
-        +"/" + std::move(stream) + "/size";
-    auto responce = BrokerRequestWithTimeout(ri, err);
-    if (*err) {
-        return 0;
-    }
-
-    JsonStringParser parser(responce);
-    uint64_t size;
-    if ((*err = parser.GetUInt64("size", &size)) != nullptr) {
-        return 0;
-    }
-    return size;
+    return GetCurrentCount(stream,false,false,err);
 }
 
 Error ConsumerImpl::GetById(uint64_t id, MessageMeta* info, MessageData* data, std::string stream) {
@@ -838,4 +825,28 @@ void ConsumerImpl::InterruptCurrentOperation() {
     interrupt_flag_= true;
 }
 
+uint64_t ConsumerImpl::GetCurrentDatasetCount(std::string stream, bool include_incomplete, Error* err) {
+    return GetCurrentCount(stream,true,include_incomplete,err);
+}
+
+uint64_t ConsumerImpl::GetCurrentCount(std::string stream, bool datasets, bool include_incomplete, Error* err) {
+    RequestInfo ri;
+    ri.api = "/database/" + source_credentials_.beamtime_id + "/" + source_credentials_.data_source +
+        +"/" + std::move(stream) + "/size";
+    if (datasets) {
+        ri.extra_params = std::string("&incomplete=")+(include_incomplete?"true":"false");
+    }
+    auto responce = BrokerRequestWithTimeout(ri, err);
+    if (*err) {
+        return 0;
+    }
+
+    JsonStringParser parser(responce);
+    uint64_t size;
+    if ((*err = parser.GetUInt64("size", &size)) != nullptr) {
+        return 0;
+    }
+    return size;
+}
+
 }
diff --git a/consumer/api/cpp/src/consumer_impl.h b/consumer/api/cpp/src/consumer_impl.h
index 0697b5f96..27488d22d 100644
--- a/consumer/api/cpp/src/consumer_impl.h
+++ b/consumer/api/cpp/src/consumer_impl.h
@@ -80,6 +80,7 @@ class ConsumerImpl final : public asapo::Consumer {
     std::string GetBeamtimeMeta(Error* err) override;
 
     uint64_t GetCurrentSize(std::string stream, Error* err) override;
+    uint64_t GetCurrentDatasetCount(std::string stream, bool include_incomplete, Error* err) override;
 
     Error GetById(uint64_t id, MessageMeta* info, MessageData* data, std::string stream) override;
 
@@ -138,11 +139,14 @@ class ConsumerImpl final : public asapo::Consumer {
     Error FtsSizeRequestWithTimeout(MessageMeta* info);
     Error ProcessPostRequest(const RequestInfo& request, RequestOutput* response, HttpCode* code);
     Error ProcessGetRequest(const RequestInfo& request, RequestOutput* response, HttpCode* code);
-
     RequestInfo PrepareRequestInfo(std::string api_url, bool dataset, uint64_t min_size);
     std::string OpToUriCmd(GetMessageServerOperation op);
     Error UpdateFolderTokenIfNeeded(bool ignore_existing);
-    std::string endpoint_;
+
+    uint64_t GetCurrentCount(std::string stream, bool datasets, bool include_incomplete, Error* err);
+
+
+      std::string endpoint_;
     std::string current_broker_uri_;
     std::string current_fts_uri_;
     std::string source_path_;
diff --git a/consumer/api/cpp/unittests/test_consumer_impl.cpp b/consumer/api/cpp/unittests/test_consumer_impl.cpp
index 26020e46e..c252a3ec8 100644
--- a/consumer/api/cpp/unittests/test_consumer_impl.cpp
+++ b/consumer/api/cpp/unittests/test_consumer_impl.cpp
@@ -1307,4 +1307,21 @@ TEST_F(ConsumerImplTests, CanInterruptOperation) {
 
 }
 
+
+TEST_F(ConsumerImplTests, GetCurrentDataSetCounteUsesCorrectUri) {
+    MockGetBrokerUri();
+    consumer->SetTimeout(100);
+
+    EXPECT_CALL(mock_http_client, Get_t(expected_broker_uri + "/database/beamtime_id/" + expected_data_source + "/" +
+        expected_stream + "/size?token="
+                                            + expected_token+"&incomplete=true", _, _)).WillOnce(DoAll(
+        SetArgPointee<1>(HttpCode::OK),
+        SetArgPointee<2>(nullptr),
+        Return("{\"size\":10}")));
+    asapo::Error err;
+    auto size = consumer->GetCurrentDatasetCount(expected_stream,true, &err);
+    ASSERT_THAT(err, Eq(nullptr));
+    ASSERT_THAT(size, Eq(10));
+}
+
 }
diff --git a/consumer/api/python/asapo_consumer.pxd b/consumer/api/python/asapo_consumer.pxd
index cca41075d..cea9988b7 100644
--- a/consumer/api/python/asapo_consumer.pxd
+++ b/consumer/api/python/asapo_consumer.pxd
@@ -66,6 +66,7 @@ cdef extern from "asapo/asapo_consumer.h" namespace "asapo" nogil:
         Error GetLast(MessageMeta* info, MessageData* data, string stream)
         Error GetById(uint64_t id, MessageMeta* info, MessageData* data, string stream)
         uint64_t GetCurrentSize(string stream, Error* err)
+        uint64_t GetCurrentDatasetCount(string stream, bool include_incomplete, Error* err)
         Error SetLastReadMarker(string group_id, uint64_t value, string stream)
         Error ResetLastReadMarker(string group_id, string stream)
         Error Acknowledge(string group_id, uint64_t id, string stream)
diff --git a/consumer/api/python/asapo_consumer.pyx.in b/consumer/api/python/asapo_consumer.pyx.in
index 399b510ed..9e2010040 100644
--- a/consumer/api/python/asapo_consumer.pyx.in
+++ b/consumer/api/python/asapo_consumer.pyx.in
@@ -167,6 +167,16 @@ cdef class PyConsumer:
         if err:
             throw_exception(err)
         return size
+    def get_current_dataset_count(self, stream = "default", bool include_incomplete = False):
+        cdef Error err
+        cdef uint64_t size
+        cdef string b_stream = _bytes(stream)
+        with nogil:
+            size =  self.c_consumer.get().GetCurrentDatasetCount(b_stream,include_incomplete,&err)
+        err_str = _str(GetErrorString(&err))
+        if err:
+            throw_exception(err)
+        return size
     def set_timeout(self,timeout):
         self.c_consumer.get().SetTimeout(timeout)
     def force_no_rdma(self):
diff --git a/tests/automatic/consumer/consumer_api/consumer_api.cpp b/tests/automatic/consumer/consumer_api/consumer_api.cpp
index e451289ed..61f5b38ad 100644
--- a/tests/automatic/consumer/consumer_api/consumer_api.cpp
+++ b/tests/automatic/consumer/consumer_api/consumer_api.cpp
@@ -238,6 +238,11 @@ void TestDataset(const std::unique_ptr<asapo::Consumer>& consumer, const std::st
     M_AssertTrue(err == nullptr, "GetDatasetById error");
     M_AssertTrue(dataset.content[2].name == "8_3", "GetDatasetById filename");
 
+    auto size = consumer->GetCurrentDatasetCount("default", false, &err);
+    M_AssertTrue(err == nullptr, "GetCurrentDatasetCount no error");
+    M_AssertTrue(size == 10, "GetCurrentDatasetCount size");
+
+
 // incomplete datasets without min_size
 
     dataset = consumer->GetNextDataset(group_id, 0, "incomplete", &err);
@@ -271,6 +276,14 @@ void TestDataset(const std::unique_ptr<asapo::Consumer>& consumer, const std::st
     M_AssertTrue(err == nullptr, "GetDatasetById incomplete minsize error");
     M_AssertTrue(dataset.content[0].name == "2_1", "GetDatasetById incomplete minsize filename");
 
+    size = consumer->GetCurrentDatasetCount("incomplete", true, &err);
+    M_AssertTrue(err == nullptr, "GetCurrentDatasetCount including incomplete no error");
+    M_AssertTrue(size == 5, "GetCurrentDatasetCount including incomplete size");
+
+    size = consumer->GetCurrentDatasetCount("incomplete", false, &err);
+    M_AssertTrue(err == nullptr, "GetCurrentDatasetCount excluding incomplete no error");
+    M_AssertTrue(size == 0, "GetCurrentDatasetCount excluding incomplete size");
+
 
 }
 
diff --git a/tests/automatic/consumer/consumer_api_python/consumer_api.py b/tests/automatic/consumer/consumer_api_python/consumer_api.py
index 013ce0516..4e0915726 100644
--- a/tests/automatic/consumer/consumer_api_python/consumer_api.py
+++ b/tests/automatic/consumer/consumer_api_python/consumer_api.py
@@ -79,6 +79,14 @@ def check_single(consumer, group_id):
     size = consumer.get_current_size()
     assert_eq(size, 5, "get_current_size")
 
+    try:
+        size = consumer.get_current_dataset_count(include_incomplete = True)
+    except asapo_consumer.AsapoWrongInputError as err:
+        pass
+    else:
+        exit_on_noerr("get_current_dataset_count for single messages err")
+
+
     consumer.reset_lastread_marker(group_id)
 
     _, meta = consumer.get_next(group_id, meta_only=True)
@@ -269,6 +277,9 @@ def check_dataset(consumer, group_id):
     assert_eq(res['id'], 8, "get_dataset_by_id1 id")
     assert_metaname(res['content'][2], "8_3", "get get_dataset_by_id1 name3")
 
+    size = consumer.get_current_dataset_count()
+    assert_eq(size, 10, "get_current_dataset_count")
+
     # incomplete datesets without min_size given
     try:
         consumer.get_next_dataset(group_id, stream = "incomplete")
@@ -308,6 +319,14 @@ def check_dataset(consumer, group_id):
     res = consumer.get_dataset_by_id(2, min_size=1, stream = "incomplete")
     assert_eq(res['id'], 2, "get_dataset_by_id incomplete with minsize")
 
+    size = consumer.get_current_dataset_count(stream = "incomplete", include_incomplete = False)
+    assert_eq(size, 0, "get_current_dataset_count excluding incomplete")
+
+    size = consumer.get_current_dataset_count(stream = "incomplete", include_incomplete = True)
+    assert_eq(size, 5, "get_current_dataset_count including incomplete")
+
+    size = consumer.get_current_size(stream = "incomplete") # should work as well
+    assert_eq(size, 5, "get_current_size for datasets")
 
 source, path, beamtime, token, mode = sys.argv[1:]
 
-- 
GitLab