diff --git a/broker/src/asapo_broker/database/encoding.go b/broker/src/asapo_broker/database/encoding.go
index e1126fa844460bd53d688ea06323378d0755c084..578f08185cb00eea19b1a739dd30b3d76864f7fc 100644
--- a/broker/src/asapo_broker/database/encoding.go
+++ b/broker/src/asapo_broker/database/encoding.go
@@ -1,6 +1,9 @@
 package database
 
-import "net/url"
+import (
+	"asapo_common/utils"
+	"net/url"
+)
 
 func shouldEscape(c byte, db bool) bool {
 	if c == '$' || c == ' ' || c == '%' {
@@ -74,7 +77,19 @@ func encodeStringForColName(original string) (result string) {
 
 func encodeRequest(request *Request) error {
 	request.DbName = encodeStringForDbName(request.DbName)
+	if len(request.DbName) > max_encoded_source_size {
+		return &DBError{utils.StatusWrongInput, "source name is too long"}
+	}
+
 	request.DbCollectionName = encodeStringForColName(request.DbCollectionName)
+	if len(request.DbCollectionName) > max_encoded_stream_size {
+		return &DBError{utils.StatusWrongInput, "stream name is too long"}
+	}
+
 	request.GroupId = encodeStringForColName(request.GroupId)
+	if len(request.GroupId) > max_encoded_group_size {
+		return &DBError{utils.StatusWrongInput, "group id is too long"}
+	}
+
 	return nil
 }
diff --git a/broker/src/asapo_broker/database/encoding_test.go b/broker/src/asapo_broker/database/encoding_test.go
index 5f98641ef0d299133c1304aff476f4aa629ceab1..1def90c99f6a2268883530be39b62fbb01eabb96 100644
--- a/broker/src/asapo_broker/database/encoding_test.go
+++ b/broker/src/asapo_broker/database/encoding_test.go
@@ -1,13 +1,15 @@
 package database
 
 import (
+	"asapo_common/utils"
 	"github.com/stretchr/testify/assert"
+	"math/rand"
 	"testing"
 )
 
 func TestEncoding(t *testing.T) {
-	stream:=`ss$`
-	source :=`ads%&%41.sss`
+	stream := `ss$`
+	source := `ads%&%41.sss`
 	streamEncoded := encodeStringForColName(stream)
 	sourceEncoded := encodeStringForDbName(source)
 	streamDecoded := decodeString(streamEncoded)
@@ -29,5 +31,52 @@ func TestEncoding(t *testing.T) {
 	assert.Equal(t, r.GroupId, streamEncoded)
 	assert.Equal(t, r.DbName, sourceEncoded)
 
-	assert.Nil(t,err)
+	assert.Nil(t, err)
 }
+
+var encodeTests = []struct {
+	streamSize int
+	groupSize  int
+	sourceSize int
+	ok         bool
+	message    string
+}{
+	{max_encoded_stream_size, max_encoded_group_size, max_encoded_source_size, true, "ok"},
+	{max_encoded_stream_size + 1, max_encoded_group_size, max_encoded_source_size, false, "stream"},
+	{max_encoded_stream_size, max_encoded_group_size + 1, max_encoded_source_size, false, "group"},
+	{max_encoded_stream_size, max_encoded_group_size, max_encoded_source_size + 1, false, "source"},
+}
+
+func RandomString(n int) string {
+	var letter = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
+
+	b := make([]rune, n)
+	for i := range b {
+		b[i] = letter[rand.Intn(len(letter))]
+	}
+	return string(b)
+}
+
+func TestEncodingTooLong(t *testing.T) {
+	for _, test := range encodeTests {
+		stream := RandomString(test.streamSize)
+		group := RandomString(test.groupSize)
+		source := RandomString(test.sourceSize)
+		r := Request{
+			DbName:           source,
+			DbCollectionName: stream,
+			GroupId:          group,
+			Op:               "",
+			DatasetOp:        false,
+			MinDatasetSize:   0,
+			ExtraParam:       "",
+		}
+		err := encodeRequest(&r)
+		if test.ok {
+			assert.Nil(t, err, test.message)
+		} else {
+			assert.Equal(t, utils.StatusWrongInput, err.(*DBError).Code)
+			assert.Contains(t, err.Error(), test.message, test.message)
+		}
+	}
+}
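Note on the ordering in encodeRequest: the length checks run on the encoded names on purpose. shouldEscape expands each reserved byte ('$', ' ', '%', plus the db-specific set) into a three-character %XX sequence, so a name that looks short in the API can still overflow the limit once encoded. A minimal standalone sketch of the effect, using Go's net/url escaping as a stand-in for the broker's own escape table (maxEncodedSourceSize here is a hypothetical mirror of max_encoded_source_size from mongodb.go):

package main

import (
	"fmt"
	"net/url"
	"strings"
)

// Hypothetical mirror of max_encoded_source_size in mongodb.go.
const maxEncodedSourceSize = 63

func main() {
	raw := strings.Repeat("a$", 25) // 50 bytes before encoding
	encoded := url.QueryEscape(raw) // each '$' becomes "%24", giving 100 bytes
	fmt.Println(len(raw), len(encoded), len(encoded) > maxEncodedSourceSize)
	// Output: 50 100 true
}

This is why the limits are defined in terms of encoded size: it is the encoded string, not the user-supplied one, that becomes the MongoDB database or collection name.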
diff --git a/broker/src/asapo_broker/database/mongodb.go b/broker/src/asapo_broker/database/mongodb.go
index 75c885f30ee4290bf9d5980e2f23ebd25b4e9000..0e1f5ecc74047b3cd40325e50a62c415b5e10ba2 100644
--- a/broker/src/asapo_broker/database/mongodb.go
+++ b/broker/src/asapo_broker/database/mongodb.go
@@ -74,6 +74,10 @@ const stream_filter_all = "all"
 const stream_filter_finished = "finished"
 const stream_filter_unfinished = "unfinished"
 
+const max_encoded_source_size = 63
+const max_encoded_stream_size = 100
+const max_encoded_group_size = 50
+
 var dbSessionLock sync.Mutex
 
 type SizeRecord struct {
diff --git a/consumer/api/cpp/src/consumer_impl.cpp b/consumer/api/cpp/src/consumer_impl.cpp
index 165e1339a50c9958fede78e9cec12e126ff4aef0..00ca9b96158f167159ff5817927a059a2f862fec 100644
--- a/consumer/api/cpp/src/consumer_impl.cpp
+++ b/consumer/api/cpp/src/consumer_impl.cpp
@@ -744,7 +744,7 @@ RequestInfo ConsumerImpl::GetStreamListRequest(const std::string& from, const StreamFilter& filter) {
     ri.api = UriPrefix("0", "", "streams");
     ri.post = false;
     if (!from.empty()) {
-        ri.extra_params = "&from=" + from;
+        ri.extra_params = "&from=" + httpclient__->UrlEscape(from);
     }
     ri.extra_params += "&filter=" + filterToString(filter);
     return ri;
diff --git a/consumer/api/cpp/unittests/test_consumer_impl.cpp b/consumer/api/cpp/unittests/test_consumer_impl.cpp
index 71f80339cc4ae971746ab09e15a74c83d9125237..862f2f64e3c77921ca8680bf24c55838f9e4c773 100644
--- a/consumer/api/cpp/unittests/test_consumer_impl.cpp
+++ b/consumer/api/cpp/unittests/test_consumer_impl.cpp
@@ -1114,14 +1114,14 @@ TEST_F(ConsumerImplTests, GetStreamListUsesCorrectUri) {
        R"({"lastId":124,"name":"test1","timestampCreated":2000000,"timestampLast":2000,"finished":true,"nextStream":"next"}]})";
    EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/0/streams"
-                                       + "?token=" + expected_token + "&from=stream_from&filter=all", _,
+                                       + "?token=" + expected_token + "&from=" + expected_stream_encoded + "&filter=all", _,
                                        _)).WillOnce(DoAll(
            SetArgPointee<1>(HttpCode::OK),
            SetArgPointee<2>(nullptr),
            Return(return_streams)));
 
    asapo::Error err;
-   auto streams = consumer->GetStreamList("stream_from", asapo::StreamFilter::kAllStreams, &err);
+   auto streams = consumer->GetStreamList(expected_stream, asapo::StreamFilter::kAllStreams, &err);
    ASSERT_THAT(err, Eq(nullptr));
    ASSERT_THAT(streams.size(), Eq(2));
    ASSERT_THAT(streams.size(), 2);
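The consumer-side change is the mirror image of the broker one: once stream names may contain characters like '/', ' ' and '$', the client has to URL-escape the "from" name before building the /streams query, otherwise the name corrupts the URL path and query syntax. A small Go sketch (not the C++ consumer code) of what the escaped query should look like:

package main

import (
	"fmt"
	"net/url"
)

func main() {
	query := url.Values{}
	query.Set("from", "stream/test $") // a stream name that is now legal
	query.Set("filter", "all")
	// '/' -> %2F, ' ' -> '+', '$' -> %24; Encode() sorts the keys
	fmt.Println(query.Encode()) // filter=all&from=stream%2Ftest+%24
}

This is also why the unit test above compares the request URI against expected_stream_encoded rather than the raw stream name.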
auto db_name = R"(data_/ \."$)"; auto stream_name = R"(bla/test_/\ ."$)"; @@ -118,5 +132,18 @@ int main(int argc, char* argv[]) { M_AssertTrue(err == nullptr); } + // long names + + asapo::MongoDBClient db1; + auto long_db_name = GenRandomString(64); + err = db1.Connect("127.0.0.1", long_db_name); + M_AssertTrue(err == asapo::DBErrorTemplates::kWrongInput); + + db1.Connect("127.0.0.1", db_name); + auto long_stream_name = GenRandomString(120); + err = db1.Insert(long_stream_name, fi, true); + M_AssertTrue(err == asapo::DBErrorTemplates::kWrongInput); + + return 0; } diff --git a/tests/automatic/producer/python_api/producer_api.py b/tests/automatic/producer/python_api/producer_api.py index b4061c3cf08e8127721437b49a2772c886f5c8a5..64e30bc439739f6e32d9481271c002739d9f717a 100644 --- a/tests/automatic/producer/python_api/producer_api.py +++ b/tests/automatic/producer/python_api/producer_api.py @@ -119,7 +119,7 @@ producer.wait_requests_finished(50000) # send to another stream producer.send(1, "processed/" + data_source + "/" + "file9", None, - ingest_mode=asapo_producer.INGEST_MODE_TRANSFER_METADATA_ONLY, stream="stream", callback=callback) + ingest_mode=asapo_producer.INGEST_MODE_TRANSFER_METADATA_ONLY, stream="stream/test $", callback=callback) # wait normal requests finished before sending duplicates @@ -149,7 +149,7 @@ assert_eq(n, 0, "requests in queue") # send another data to stream stream producer.send(2, "processed/" + data_source + "/" + "file10", None, - ingest_mode=asapo_producer.INGEST_MODE_TRANSFER_METADATA_ONLY, stream="stream", callback=callback) + ingest_mode=asapo_producer.INGEST_MODE_TRANSFER_METADATA_ONLY, stream="stream/test $", callback=callback) producer.wait_requests_finished(50000) n = producer.get_requests_queue_size() @@ -168,9 +168,9 @@ else: #stream_finished producer.wait_requests_finished(10000) -producer.send_stream_finished_flag("stream", 2, next_stream = "next_stream", callback = callback) +producer.send_stream_finished_flag("stream/test $", 2, next_stream = "next_stream", callback = callback) # check callback_object.callback works, will be duplicated request -producer.send_stream_finished_flag("stream", 2, next_stream = "next_stream", callback = callback_object.callback) +producer.send_stream_finished_flag("stream/test $", 2, next_stream = "next_stream", callback = callback_object.callback) producer.wait_requests_finished(10000) @@ -185,7 +185,7 @@ assert_eq(info['timestampLast']/1000000000>time.time()-10,True , "stream_info ti print("created: ",datetime.utcfromtimestamp(info['timestampCreated']/1000000000).strftime('%Y-%m-%d %H:%M:%S.%f')) print("last record: ",datetime.utcfromtimestamp(info['timestampLast']/1000000000).strftime('%Y-%m-%d %H:%M:%S.%f')) -info = producer.stream_info('stream') +info = producer.stream_info('stream/test $') assert_eq(info['lastId'], 3, "last id from different stream") assert_eq(info['finished'], True, "stream finished") @@ -199,12 +199,12 @@ assert_eq(info['lastId'], 0, "last id from non existing stream") info_last = producer.last_stream() print(info_last) -assert_eq(info_last['name'], "stream", "last stream") +assert_eq(info_last['name'], "stream/test $", "last stream") assert_eq(info_last['timestampCreated'] <= info_last['timestampLast'], True, "last is later than first") #delete_streams -producer.delete_stream('stream') -producer.stream_info('stream') +producer.delete_stream('stream/test $') +producer.stream_info('stream/test $') assert_eq(info['lastId'], 0, "last id from non deleted stream")