diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc index a38c086897d..6a90d39df3b 100644 --- a/src/cluster/cluster.cc +++ b/src/cluster/cluster.cc @@ -824,7 +824,7 @@ bool Cluster::IsWriteForbiddenSlot(int slot) const { } Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector &cmd_tokens, - redis::Connection *conn) { + redis::Connection *conn, lua::ScriptRunCtx *script_run_ctx) { std::vector keys_indexes; // No keys @@ -849,6 +849,21 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons return {Status::RedisClusterDown, "Hash slot not served"}; } + bool cross_slot_ok = false; + if (script_run_ctx) { + if (script_run_ctx->current_slot != -1 && script_run_ctx->current_slot != slot) { + if (getNodeIDBySlot(script_run_ctx->current_slot) != getNodeIDBySlot(slot)) { + return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; + } + if (!(script_run_ctx->flags & lua::ScriptFlagType::kScriptAllowCrossSlotKeys)) { + return {Status::RedisCrossSlot, "Script attempted to access keys that do not hash to the same slot"}; + } + } + + script_run_ctx->current_slot = slot; + cross_slot_ok = true; + } + if (myself_ && myself_ == slots_nodes_[slot]) { // We use central controller to manage the topology of the cluster. // Server can't change the topology directly, so we record the migrated slots @@ -887,7 +902,11 @@ Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes, cons return Status::OK(); // My master is serving this slot } - return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; + if (!cross_slot_ok) { + return {Status::RedisMoved, fmt::format("{} {}:{}", slot, slots_nodes_[slot]->host, slots_nodes_[slot]->port)}; + } + + return Status::OK(); } // Only HARD mode is meaningful to the Kvrocks cluster, diff --git a/src/cluster/cluster.h b/src/cluster/cluster.h index e595666c75e..468c154d4d8 100644 --- a/src/cluster/cluster.h +++ b/src/cluster/cluster.h @@ -35,6 +35,7 @@ #include "redis_slot.h" #include "server/redis_connection.h" #include "status.h" +#include "storage/scripting.h" class ClusterNode { public: @@ -83,7 +84,7 @@ class Cluster { bool IsNotMaster(); bool IsWriteForbiddenSlot(int slot) const; Status CanExecByMySelf(const redis::CommandAttributes *attributes, const std::vector &cmd_tokens, - redis::Connection *conn); + redis::Connection *conn, lua::ScriptRunCtx *script_run_ctx = nullptr); Status SetMasterSlaveRepl(); Status MigrateSlotRange(const SlotRange &slot_range, const std::string &dst_node_id, SyncMigrateContext *blocking_ctx = nullptr); diff --git a/src/server/worker.cc b/src/server/worker.cc index 22054e1faf8..0b37bcef30d 100644 --- a/src/server/worker.cc +++ b/src/server/worker.cc @@ -72,7 +72,7 @@ Worker::Worker(Server *srv, Config *config) : srv(srv), base_(event_base_new()) LOG(INFO) << "[worker] Listening on: " << bind << ":" << *port; } } - lua_ = lua::CreateState(srv, true); + lua_ = lua::CreateState(srv); } Worker::~Worker() { diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index 987f9f0dfc0..bfe1016ee70 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -57,11 +57,11 @@ enum { namespace lua { -lua_State *CreateState(Server *srv, bool read_only) { +lua_State *CreateState(Server *srv) { lua_State *lua = lua_open(); LoadLibraries(lua); RemoveUnsupportedFunctions(lua); - LoadFuncs(lua, read_only); + LoadFuncs(lua); lua_pushlightuserdata(lua, srv); lua_setglobal(lua, REDIS_LUA_SERVER_PTR); @@ -75,7 +75,7 @@ void DestroyState(lua_State *lua) { lua_close(lua); } -void LoadFuncs(lua_State *lua, bool read_only) { +void LoadFuncs(lua_State *lua) { lua_newtable(lua); /* redis.call */ @@ -127,11 +127,6 @@ void LoadFuncs(lua_State *lua, bool read_only) { lua_pushcfunction(lua, RedisStatusReplyCommand); lua_settable(lua, -3); - /* redis.read_only */ - lua_pushstring(lua, "read_only"); - lua_pushboolean(lua, read_only); - lua_settable(lua, -3); - /* redis.register_function */ lua_pushstring(lua, "register_function"); lua_pushcfunction(lua, RedisRegisterFunction); @@ -226,8 +221,8 @@ int RedisLogCommand(lua_State *lua) { int RedisRegisterFunction(lua_State *lua) { int argc = lua_gettop(lua); - if (argc != 2) { - lua_pushstring(lua, "redis.register_function() requires two arguments."); + if (argc < 2 || argc > 3) { + lua_pushstring(lua, "wrong number of arguments to redis.register_function()."); return lua_error(lua); } @@ -243,6 +238,15 @@ int RedisRegisterFunction(lua_State *lua) { // set this function to global std::string name = lua_tostring(lua, 1); + if (argc == 3) { + auto flags = ExtractFlagsFromRegisterFunction(lua); + if (!flags) { + lua_pushstring(lua, flags.Msg().c_str()); + return lua_error(lua); + } + lua_pushinteger(lua, static_cast(flags.GetValue())); + lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); + } lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str()); // set this function name to REDIS_FUNCTION_LIBRARIES[libname] @@ -274,7 +278,6 @@ int RedisRegisterFunction(lua_State *lua) { lua_pushstring(lua, "redis.register_function() failed to store informantion."); return lua_error(lua); } - return 0; } @@ -288,31 +291,8 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee return {Status::NotOK, "Expect a Shebang statement in the first line"}; } - static constexpr const char *shebang_prefix = "#!lua"; - static constexpr const char *shebang_libname_prefix = "name="; - - auto first_line_split = util::Split(first_line, " \r\t"); - if (first_line_split.empty() || first_line_split[0] != shebang_prefix) { - return {Status::NotOK, "Expect a Shebang statement in the first line, e.g. `#!lua name=mylib`"}; - } - - size_t libname_pos = 1; - for (; libname_pos < first_line_split.size(); ++libname_pos) { - if (util::HasPrefix(first_line_split[libname_pos], shebang_libname_prefix)) { - break; - } - } - - if (libname_pos >= first_line_split.size()) { - return {Status::NotOK, "Expect library name in the Shebang statement, e.g. `#!lua name=mylib`"}; - } + const auto libname = GET_OR_RET(ExtractLibNameFromShebang(first_line)); - auto libname = first_line_split[libname_pos].substr(strlen(shebang_libname_prefix)); - *lib_name = libname; - if (libname.empty() || - std::any_of(libname.begin(), libname.end(), [](char v) { return !std::isalnum(v) && v != '_'; })) { - return {Status::NotOK, "Expect a valid library name in the Shebang statement"}; - } auto srv = conn->GetServer(); auto lua = read_only ? conn->Owner()->Lua() : srv->Lua(); @@ -409,13 +389,24 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: std::string libcode; s = srv->FunctionGetCode(libname, &libcode); if (!s) return s; - s = FunctionLoad(conn, libcode, false, false, &libname, read_only); if (!s) return s; lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str()); } + ScriptRunCtx script_run_ctx; + script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; + lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str()); + if (!lua_isnil(lua, -1)) { + // It should be ensured that the conversion is successful + auto function_flags = lua_tointeger(lua, -1); + script_run_ctx.flags |= function_flags; + } + lua_pop(lua, 1); + + SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx); + PushArray(lua, keys); PushArray(lua, argv); if (lua_pcall(lua, 2, 1, -4)) { @@ -427,6 +418,22 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: lua_pop(lua, 2); } + RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); + + /* Call the Lua garbage collector from time to time to avoid a + * full cycle performed by Lua, which adds too latency. + * + * The call is performed every LUA_GC_CYCLE_PERIOD executed commands + * (and for LUA_GC_CYCLE_PERIOD collection steps) because calling it + * for every command uses too much CPU. */ + constexpr int64_t LUA_GC_CYCLE_PERIOD = 50; + static int64_t gc_count = 0; + + gc_count++; + if (gc_count == LUA_GC_CYCLE_PERIOD) { + lua_gc(lua, LUA_GCSTEP, LUA_GC_CYCLE_PERIOD); + gc_count = 0; + } return Status::OK(); } @@ -572,6 +579,8 @@ Status FunctionDelete(Server *srv, const std::string &name) { std::string func = lua_tostring(lua, -1); lua_pushnil(lua); lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + func).c_str()); + lua_pushnil(lua); + lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + func).c_str()); auto _ = storage->Delete(rocksdb::WriteOptions(), cf, engine::kLuaFuncLibPrefix + func); lua_pop(lua, 1); } @@ -590,7 +599,6 @@ Status FunctionDelete(Server *srv, const std::string &name) { Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sha, const std::vector &keys, const std::vector &argv, bool evalsha, std::string *output, bool read_only) { Server *srv = conn->GetServer(); - // Use the worker's private Lua VM when entering the read-only mode lua_State *lua = read_only ? conn->Owner()->Lua() : srv->Lua(); @@ -635,6 +643,18 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh lua_getglobal(lua, funcname); } + ScriptRunCtx current_script_run_ctx; + current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0; + lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 2).c_str()); + if (!lua_isnil(lua, -1)) { + // It should be ensured that the conversion is successful + auto script_flags = lua_tointeger(lua, -1); + current_script_run_ctx.flags |= script_flags; + } + lua_pop(lua, 1); + + SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, ¤t_script_run_ctx); + // For the Lua script, should be always run with RESP2 protocol, // unless users explicitly set the protocol version in the script via `redis.setresp`. // So we need to save the current protocol version and set it to RESP2, @@ -661,6 +681,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh lua_setglobal(lua, "KEYS"); lua_pushnil(lua); lua_setglobal(lua, "ARGV"); + RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); /* Call the Lua garbage collector from time to time to avoid a * full cycle performed by Lua, which adds too latency. @@ -701,10 +722,8 @@ Server *GetServer(lua_State *lua) { // TODO: we do not want to repeat same logic as Connection::ExecuteCommands, // so the function need to be refactored int RedisGenericCommand(lua_State *lua, int raise_error) { - lua_getglobal(lua, "redis"); - lua_getfield(lua, -1, "read_only"); - int read_only = lua_toboolean(lua, -1); - lua_pop(lua, 2); + auto *script_run_ctx = GetFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME); + CHECK_NOTNULL(script_run_ctx); int argc = lua_gettop(lua); if (argc == 0) { @@ -738,7 +757,7 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { auto attributes = cmd->GetAttributes(); auto cmd_flags = attributes->GenerateFlags(args); - if (read_only && !(cmd_flags & redis::kCmdReadOnly)) { + if ((script_run_ctx->flags & ScriptFlagType::kScriptNoWrites) && !(cmd_flags & redis::kCmdReadOnly)) { PushError(lua, "Write commands are not allowed from read-only scripts"); return raise_error ? RaiseError(lua) : 1; } @@ -760,9 +779,17 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { redis::Connection *conn = srv->GetCurrentConnection(); if (config->cluster_enabled) { - auto s = srv->cluster->CanExecByMySelf(attributes, args, conn); + if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) { + PushError(lua, "Can not run script on cluster, 'no-cluster' flag is set"); + return raise_error ? RaiseError(lua) : 1; + } + auto s = srv->cluster->CanExecByMySelf(attributes, args, conn, script_run_ctx); if (!s.IsOK()) { - PushError(lua, redis::StatusToRedisErrorMsg(s).c_str()); + if (s.Is()) { + PushError(lua, "Script attempted to access a non local key in a cluster node script"); + } else { + PushError(lua, redis::StatusToRedisErrorMsg(s).c_str()); + } return raise_error ? RaiseError(lua) : 1; } } @@ -1463,7 +1490,23 @@ Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lu std::copy(sha->begin(), sha->end(), funcname + 2); } - if (luaL_loadbuffer(lua, body.c_str(), body.size(), "@user_script")) { + std::string_view lua_code(body); + // Cache the flags of the current script + ScriptFlags script_flags = 0; + if (auto pos = body.find('\n'); pos != std::string::npos) { + std::string_view first_line(body.data(), pos); + if (first_line.substr(0, 2) == "#!") { + lua_code = lua_code.substr(pos + 1); + } + script_flags = GET_OR_RET(ExtractFlagsFromShebang(first_line)); + } else { + // scripts without #! can run commands that access keys belonging to different cluster hash slots + script_flags = kScriptAllowCrossSlotKeys; + } + lua_pushinteger(lua, static_cast(script_flags)); + lua_setglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, *sha).c_str()); + + if (luaL_loadbuffer(lua, lua_code.data(), lua_code.size(), "@user_script")) { std::string err_msg = lua_tostring(lua, -1); lua_pop(lua, 1); return {Status::NotOK, "Error while compiling new script: " + err_msg}; @@ -1474,4 +1517,123 @@ Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lu return need_to_store ? srv->ScriptSet(*sha, body) : Status::OK(); } +[[nodiscard]] StatusOr ExtractLibNameFromShebang(std::string_view shebang) { + static constexpr std::string_view lua_shebang_prefix = "#!lua"; + static constexpr std::string_view shebang_libname_prefix = "name="; + + if (shebang.substr(0, 2) != "#!") { + return {Status::NotOK, "Missing library meta"}; + } + + auto shebang_splits = util::Split(shebang, " "); + if (shebang_splits.empty() || shebang_splits[0] != lua_shebang_prefix) { + return {Status::NotOK, "Unexpected engine in script shebang: " + shebang_splits[0]}; + } + + std::string libname; + bool found_libname = false; + for (size_t i = 1; i < shebang_splits.size(); i++) { + std::string_view shebang_split_sv = shebang_splits[i]; + if (shebang_split_sv.substr(0, shebang_libname_prefix.size()) != shebang_libname_prefix) { + return {Status::NotOK, "Unknown lua shebang option: " + shebang_splits[i]}; + } + if (found_libname) { + return {Status::NotOK, "Redundant library name in script shebang"}; + } + + libname = shebang_split_sv.substr(shebang_libname_prefix.size()); + if (libname.empty() || + std::any_of(libname.begin(), libname.end(), [](char v) { return !std::isalnum(v) && v != '_'; })) { + return { + Status::NotOK, + "Library names can only contain letters, numbers, or underscores(_) and must be at least one character long"}; + } + found_libname = true; + } + + if (found_libname) return libname; + return {Status::NotOK, "Library name was not given"}; +} + +[[nodiscard]] StatusOr GetFlagsFromStrings(const std::vector &flags_content) { + ScriptFlags flags = 0; + for (const auto &flag : flags_content) { + if (flag == "no-writes") { + flags |= kScriptNoWrites; + } else if (flag == "allow-oom") { + return {Status::NotSupported, "allow-oom is not supported yet"}; + } else if (flag == "allow-stale") { + return {Status::NotSupported, "allow-stale is not supported yet"}; + } else if (flag == "no-cluster") { + flags |= kScriptNoCluster; + } else if (flag == "allow-cross-slot-keys") { + flags |= kScriptAllowCrossSlotKeys; + } else { + return {Status::NotOK, "Unknown flag given: " + flag}; + } + } + return flags; +} + +[[nodiscard]] StatusOr ExtractFlagsFromShebang(std::string_view shebang) { + static constexpr std::string_view lua_shebang_prefix = "#!lua"; + static constexpr std::string_view shebang_flags_prefix = "flags="; + + ScriptFlags result_flags = 0; + if (shebang.substr(0, 2) == "#!") { + auto shebang_splits = util::Split(shebang, " "); + if (shebang_splits.empty() || shebang_splits[0] != lua_shebang_prefix) { + return {Status::NotOK, "Unexpected engine in script shebang: " + shebang_splits[0]}; + } + bool found_flags = false; + for (size_t i = 1; i < shebang_splits.size(); i++) { + std::string_view shebang_split_sv = shebang_splits[i]; + if (shebang_split_sv.substr(0, shebang_flags_prefix.size()) != shebang_flags_prefix) { + return {Status::NotOK, "Unknown lua shebang option: " + shebang_splits[i]}; + } + if (found_flags) { + return {Status::NotOK, "Redundant flags in script shebang"}; + } + auto flags_content = util::Split(shebang_split_sv.substr(shebang_flags_prefix.size()), ","); + result_flags |= GET_OR_RET(GetFlagsFromStrings(flags_content)); + found_flags = true; + } + } else { + // scripts without #! can run commands that access keys belonging to different cluster hash slots, + // but ones with #! inherit the default flags, so they cannot. + result_flags = kScriptAllowCrossSlotKeys; + } + + return result_flags; +} + +[[nodiscard]] StatusOr ExtractFlagsFromRegisterFunction(lua_State *lua) { + if (!lua_istable(lua, -1)) { + return {Status::NotOK, "Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' }"}; + } + auto flag_count = static_cast(lua_objlen(lua, -1)); + std::vector flags_content; + flags_content.reserve(flag_count); + for (int i = 1; i <= flag_count; ++i) { + lua_pushnumber(lua, i); + lua_gettable(lua, -2); + if (!lua_isstring(lua, -1)) { + return {Status::NotOK, "Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' }"}; + } + flags_content.emplace_back(lua_tostring(lua, -1)); + // pop up the current flag + lua_pop(lua, 1); + } + // pop up the corresponding table of the flags parameter + lua_pop(lua, 1); + + return GetFlagsFromStrings(flags_content); +} + +void RemoveFromRegistry(lua_State *lua, const char *name) { + lua_pushstring(lua, name); + lua_pushnil(lua); + lua_settable(lua, LUA_REGISTRYINDEX); +} + } // namespace lua diff --git a/src/storage/scripting.h b/src/storage/scripting.h index 3b2dd45deef..6cfa31f066b 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -28,19 +28,22 @@ #include "status.h" inline constexpr const char REDIS_LUA_FUNC_SHA_PREFIX[] = "f_"; +inline constexpr const char REDIS_LUA_FUNC_SHA_FLAGS[] = "f_{}_flags_"; inline constexpr const char REDIS_LUA_REGISTER_FUNC_PREFIX[] = "__redis_registered_"; +inline constexpr const char REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = "__redis_registered_flags_"; inline constexpr const char REDIS_LUA_SERVER_PTR[] = "__server_ptr"; inline constexpr const char REDIS_FUNCTION_LIBNAME[] = "REDIS_FUNCTION_LIBNAME"; inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] = "REDIS_FUNCTION_NEEDSTORE"; inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = "REDIS_FUNCTION_LIBRARIES"; +inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = "SCRIPT_RUN_CTX"; namespace lua { -lua_State *CreateState(Server *srv, bool read_only = false); +lua_State *CreateState(Server *srv); void DestroyState(lua_State *lua); Server *GetServer(lua_State *lua); -void LoadFuncs(lua_State *lua, bool read_only = false); +void LoadFuncs(lua_State *lua); void LoadLibraries(lua_State *lua); void RemoveUnsupportedFunctions(lua_State *lua); void EnableGlobalsProtection(lua_State *lua); @@ -101,4 +104,87 @@ void SHA1Hex(char *digest, const char *script, size_t len); int RedisMathRandom(lua_State *l); int RedisMathRandomSeed(lua_State *l); +/// ScriptFlagType turn on/off constraints or indicate properties in Eval scripts and functions +/// +/// Note: The default for Eval scripts are different than the default for functions(default is 0). +/// As soon as Redis sees the #! comment, it'll treat the script as if it declares flags, even if no flags are defined, +/// it still has a different set of defaults compared to a script without a #! line. +/// Another difference is that scripts without #! can run commands that access keys belonging to different cluster hash +/// slots, but ones with #! inherit the default flags, so they cannot. +enum ScriptFlagType : uint64_t { + kScriptNoWrites = 1ULL << 0, // "no-writes" flag + kScriptAllowOom = 1ULL << 1, // "allow-oom" flag + kScriptAllowStale = 1ULL << 2, // "allow-stale" flag + kScriptNoCluster = 1ULL << 3, // "no-cluster" flag + kScriptAllowCrossSlotKeys = 1ULL << 4, // "allow-cross-slot-keys" flag +}; + +/// ScriptFlags is composed of one or more ScriptFlagTypes combined by an OR operation +/// For example, ScriptFlags flags = kScriptNoWrites | kScriptNoCluster +using ScriptFlags = uint64_t; + +[[nodiscard]] StatusOr ExtractLibNameFromShebang(std::string_view shebang); +[[nodiscard]] StatusOr ExtractFlagsFromShebang(std::string_view shebang); + +/// GetFlagsFromStrings gets flags from flags_content and composites them together. +/// Each element in flags_content should correspond to a string form of ScriptFlagType +[[nodiscard]] StatusOr GetFlagsFromStrings(const std::vector &flags_content); + +/// ExtractFlagsFromRegisterFunction extracts the flags from the redis.register_function +/// +/// Note: When using it, you should make sure that +/// the top of the stack of lua is the flags parameter of redis.register_function. +/// The flags parameter in Lua is a table that stores strings. +/// After use, the original flags table on the top of the stack will be popped. +[[nodiscard]] StatusOr ExtractFlagsFromRegisterFunction(lua_State *lua); + +/// ScriptRunCtx is used to record context information during the running of Eval scripts and functions. +struct ScriptRunCtx { + // ScriptFlags + uint64_t flags = 0; + // current_slot tracks the slot currently accessed by the script + // and is used to detect whether there is cross-slot access + // between multiple commands in a script or function. + int current_slot = -1; +}; + +/// SaveOnRegistry saves user-defined data to lua REGISTRY +/// +/// Note: Since lua_pushlightuserdata, you need to manage the life cycle of the data stored in the Registry yourself. +template +void SaveOnRegistry(lua_State *lua, const char *name, T *ptr) { + lua_pushstring(lua, name); + if (ptr) { + lua_pushlightuserdata(lua, ptr); + } else { + lua_pushnil(lua); + } + lua_settable(lua, LUA_REGISTRYINDEX); +} + +template +T *GetFromRegistry(lua_State *lua, const char *name) { + lua_pushstring(lua, name); + lua_gettable(lua, LUA_REGISTRYINDEX); + + if (lua_isnil(lua, -1)) { + // pops the value + lua_pop(lua, 1); + return nullptr; + } + + // must be light user data + CHECK(lua_islightuserdata(lua, -1)); + auto *ptr = static_cast(lua_touserdata(lua, -1)); + + CHECK_NOTNULL(ptr); + + // pops the value + lua_pop(lua, 1); + + return ptr; +} + +void RemoveFromRegistry(lua_State *lua, const char *name); + } // namespace lua diff --git a/tests/gocase/unit/scripting/function_test.go b/tests/gocase/unit/scripting/function_test.go index d0c12ec157a..8d71158498a 100644 --- a/tests/gocase/unit/scripting/function_test.go +++ b/tests/gocase/unit/scripting/function_test.go @@ -22,6 +22,7 @@ package scripting import ( "context" _ "embed" + "fmt" "strings" "testing" @@ -116,13 +117,13 @@ var testFunctions = func(t *testing.T, enabledRESP3 string) { t.Run("FUNCTION LOAD errors", func(t *testing.T) { code := strings.Join(strings.Split(luaMylib1, "\n")[1:], "\n") - util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code).Err(), ".*Shebang statement.*") + require.Error(t, rdb.Do(ctx, "FUNCTION", "LOAD", code).Err(), "ERR Missing library meta") code2 := "#!lua\n" + code - util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), ".*Expect library name.*") + require.Error(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), "ERR Library name was not given") code2 = "#!lua name=$$$\n" + code - util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), ".*valid library name.*") + require.Error(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(), "ERR Library names can only contain letters, numbers, or underscores(_) and must be at least one character long") }) t.Run("FUNCTION LOAD and FCALL mylib1", func(t *testing.T) { @@ -277,3 +278,269 @@ var testFunctions = func(t *testing.T, enabledRESP3 string) { }, decodeListLibResult(t, r)) }) } + +func TestFunctionScriptFlags(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("Function extract-libname-error", func(t *testing.T) { + r := rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mylibname flags=no-writes + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua flags=no-writes name=mylibname + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mylibname name=mylibname2 + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "Redundant library name in script shebang") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!errorenine name=mylibname + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!luaname=mylibname + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua xxxname=mylibname + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mylibname key=value + redis.register_function('extract_libname_error_func', function(keys, args) end)`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + }) + + t.Run("Function extract-flags-error", func(t *testing.T) { + r := rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, { 'invalid-flag' })`) + require.Error(t, r.Err(), "ERR Error while running new function lib: Unknown flag given: invalid-flag") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, { 'no-writes', 'invalid-flag' })`) + require.Error(t, r.Err(), "ERR Error while running new function lib: Unknown flag given: invalid-flag") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, { {} }`) + require.Error(t, r.Err(), "ERR Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' })") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, { 123 }`) + require.Error(t, r.Err(), "ERR Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' })") + + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=myflags + redis.register_function('extract_flags_error_func', function(keys, args) end, 'no-writes'`) + require.Error(t, r.Err(), "ERR Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' })") + }) + + t.Run("no-writes", func(t *testing.T) { + r := rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=nowriteslib + redis.register_function('default_flag_func', function(keys, args) return redis.call("set", keys[1], args[1]) end) + redis.register_function('no_writes_func', function(keys, args) return redis.call("set", keys[1], args[1]) end, { 'no-writes' })`) + require.NoError(t, r.Err()) + + r = rdb.Do(ctx, "FCALL", "default_flag_func", 1, "k1", "v1") + require.NoError(t, r.Err()) + r = rdb.Do(ctx, "FCALL", "no_writes_func", 1, "k2", "v2") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb.Do(ctx, "FCALL_RO", "default_flag_func", 1, "k1", "v1") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + r = rdb.Do(ctx, "FCALL_RO", "no_writes_func", 1, "k2", "v2") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + }) + + srv0 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + rdb0 := srv0.NewClient() + defer func() { require.NoError(t, rdb0.Close()) }() + defer func() { srv0.Close() }() + id0 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODEID", id0).Err()) + + srv1 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + srv1Alive := true + defer func() { + if srv1Alive { + srv1.Close() + } + }() + + rdb1 := srv1.NewClient() + defer func() { require.NoError(t, rdb1.Close()) }() + id1 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODEID", id1).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-10000\n", id0, srv0.Host(), srv0.Port()) + clusterNodes += fmt.Sprintf("%s %s %d master - 10001-16383", id1, srv1.Host(), srv1.Port()) + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + // Node0: bar-slot = 5061, test-slot = 6918 + // Node1: foo-slot = 12182 + // Different slots of different nodes are not affected by allow-cross-slot-keys, + // and different slots of the same node can be allowed + require.NoError(t, rdb0.Set(ctx, "bar", "bar_value", 0).Err()) + require.NoError(t, rdb0.Set(ctx, "test", "test_value", 0).Err()) + require.NoError(t, rdb1.Set(ctx, "foo", "foo_value", 0).Err()) + + t.Run("no-cluster", func(t *testing.T) { + r := rdb0.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=noclusterlib + redis.register_function('default_flag_func', function(keys) return redis.call('get', keys[1]) end) + redis.register_function('no_cluster_func', function(keys) return redis.call('get', keys[1]) end, { 'no-cluster' })`) + require.NoError(t, r.Err()) + + require.NoError(t, rdb0.Do(ctx, "FCALL", "default_flag_func", 1, "bar").Err()) + + r = rdb0.Do(ctx, "FCALL", "no_cluster_func", 1, "bar") + util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") + + // Only valid in cluster mode + require.NoError(t, rdb.Set(ctx, "bar", "rdb_bar_value", 0).Err()) + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=noclusterlib + redis.register_function('no_cluster_func', function(keys) return redis.call('get', keys[1]) end, { 'no-cluster' })`) + require.NoError(t, r.Err()) + require.NoError(t, rdb.Do(ctx, "FCALL", "no_cluster_func", 1, "bar").Err()) + }) + + t.Run("allow-cross-slot-keys", func(t *testing.T) { + r := rdb0.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=crossslotlib + redis.register_function('default_flag_func_1', + function() + redis.call('get', 'bar') + return redis.call('get', 'test') + end + ) + + redis.register_function('default_flag_func_2', + function() + redis.call('get', 'bar') + return redis.call('get', 'foo') + end + ) + + redis.register_function('default_flag_func_3', + function(keys) + redis.call('get', keys[1]) + return redis.call('get', keys[2]) + end + ) + + redis.register_function( + 'allow_cross_slot_keys_func_1', + function() + redis.call('get', 'bar') + return redis.call('get', 'test') + end, + { 'allow-cross-slot-keys' }) + + redis.register_function( + 'allow_cross_slot_keys_func_2', + function() + redis.call('get', 'bar') + return redis.call('get', 'foo') + end, + { 'allow-cross-slot-keys' }) + + redis.register_function( + 'allow_cross_slot_keys_func_3', + function(keys) + redis.call('get', key[1]) + return redis.call('get', key[2]) + end, + { 'allow-cross-slot-keys' }) + + `) + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "FCALL", "default_flag_func_1", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access keys that do not hash to the same slot") + + r = rdb0.Do(ctx, "FCALL", "default_flag_func_2", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + + r = rdb0.Do(ctx, "FCALL", "default_flag_func_3", 2, "bar", "test") + require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot") + + r = rdb0.Do(ctx, "FCALL", "allow_cross_slot_keys_func_1", 0) + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "FCALL", "allow_cross_slot_keys_func_2", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + + // Pre-declared keys are not affected by allow-cross-slot-keys + r = rdb0.Do(ctx, "FCALL", "allow_cross_slot_keys_func_3", 2, "bar", "test") + require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot") + }) + + t.Run("mixed-use", func(t *testing.T) { + r := rdb0.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mixeduselib + redis.register_function('no_write_cluster_func_1', function() redis.call('get', 'bar') end, { 'no-writes', 'no-cluster' }) + + redis.register_function('no_write_allow_cross_func_1', + function() redis.call('get', 'bar'); return redis.call('get', 'test'); end, + { 'no-writes', 'allow-cross-slot-keys' }) + + redis.register_function('no_write_allow_cross_func_2', + function() redis.call('set', 'bar'); return redis.call('set', 'test'); end, + { 'no-writes', 'allow-cross-slot-keys' }) + + redis.register_function('no_write_allow_cross_func_3', + function() redis.call('get', 'bar'); return redis.call('get', 'foo'); end, + { 'no-writes', 'allow-cross-slot-keys' }) + `) + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "FCALL", "no_write_cluster_func_1", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") + + // no-cluster Only valid in cluster mode + r = rdb.Do(ctx, "FUNCTION", "LOAD", + `#!lua name=mixeduselib2 + redis.register_function('no_write_cluster_func_2', + function() return redis.call('set', 'bar', 'bar_value') end, + { 'no-writes', 'no-cluster' } + ) + + redis.register_function('no_write_cluster_func_3', + function() return redis.call('get', 'bar') end, + { 'no-writes', 'no-cluster' } + ) + `) + require.NoError(t, r.Err()) + + r = rdb.Do(ctx, "FCALL", "no_write_cluster_func_2", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + require.NoError(t, rdb.Set(ctx, "bar", "bar_value_rdb", 0).Err()) + require.NoError(t, rdb.Do(ctx, "FCALL", "no_write_cluster_func_3", 0).Err()) + + require.NoError(t, rdb0.Do(ctx, "FCALL", "no_write_allow_cross_func_1", 0).Err()) + r = rdb0.Do(ctx, "FCALL", "no_write_allow_cross_func_2", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + r = rdb0.Do(ctx, "FCALL", "no_write_allow_cross_func_3", 0) + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + }) +} diff --git a/tests/gocase/unit/scripting/scripting_test.go b/tests/gocase/unit/scripting/scripting_test.go index 50680f7d63e..9e4f275239a 100644 --- a/tests/gocase/unit/scripting/scripting_test.go +++ b/tests/gocase/unit/scripting/scripting_test.go @@ -480,6 +480,7 @@ math.randomseed(ARGV[1]); return tostring(math.random()) t.Run("EVALSHA_RO - cannot run write commands", func(t *testing.T) { require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err()) + // sha1 of `redis.call('del', KEYS[1]);` r := rdb.Do(ctx, "EVALSHA_RO", "a1e63e1cd1bd1d5413851949332cfb9da4ee6dc0", "1", "foo") util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") }) @@ -607,3 +608,280 @@ func TestScriptingWithRESP3(t *testing.T) { require.EqualValues(t, []interface{}{"f1", "v1"}, vals) }) } + +func TestEvalScriptFlags(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("Eval extract-flags-error", func(t *testing.T) { + r := rdb.Do(ctx, "EVAL", + `#!lua name=mylib + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes name=mylib + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua erroroption=no-writes + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=invalid-flag + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes,invalid-flag + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes no-cluster + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes flags=no-cluster + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Redundant flags in script shebang") + + r = rdb.Do(ctx, "EVAL", + `#!errorengine flags=no-writes + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "EVAL", + `#!luaflags=no-writes + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "EVAL", + `#!lua xxflags=no-writes + return 'extract-flags'`, "0") + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + }) + + t.Run("SCRIPT LOAD extract-flags-error", func(t *testing.T) { + r := rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua name=mylib + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=no-writes name=mylib + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua erroroption=no-writes + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=invalid-flag + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given::*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=no-writes,invalid-flag + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given::*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=no-writes no-cluster + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua flags=no-writes flags=no-cluster + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Redundant flags in script shebang") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!errorengine flags=no-writes + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!luaflags=no-writes + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*") + + r = rdb.Do(ctx, "SCRIPT", "LOAD", + `#!lua xxflags=no-writes + return 'extract-flags'`) + util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*") + }) + + t.Run("no-writes", func(t *testing.T) { + r := rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes + return redis.call('set', 'k1','v1');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb.Do(ctx, "EVAL", `return redis.call('set', 'k2','v2');`, "0") + require.NoError(t, r.Err()) + + r = rdb.Do(ctx, "EVAL", + `#!lua + return redis.call('set', 'k3','v3');`, "0") + require.NoError(t, r.Err()) + + r = rdb.Do(ctx, "EVAL_RO", + `return redis.call('set', 'k4','v4');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb.Do(ctx, "EVAL_RO", + `#!lua + return redis.call('set', 'k5','v5');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb.Do(ctx, "EVAL_RO", + `#!lua flags=no-writes + return redis.call('set', 'k6','v6');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + }) + + srv0 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + rdb0 := srv0.NewClient() + defer func() { require.NoError(t, rdb0.Close()) }() + defer func() { srv0.Close() }() + id0 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODEID", id0).Err()) + + srv1 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + srv1Alive := true + defer func() { + if srv1Alive { + srv1.Close() + } + }() + + rdb1 := srv1.NewClient() + defer func() { require.NoError(t, rdb1.Close()) }() + id1 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODEID", id1).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-10000\n", id0, srv0.Host(), srv0.Port()) + clusterNodes += fmt.Sprintf("%s %s %d master - 10001-16383", id1, srv1.Host(), srv1.Port()) + require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + t.Run("no-cluster", func(t *testing.T) { + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=no-cluster + return redis.call('set', 'k','v');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") + + // Only valid in cluster mode + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-cluster + return redis.call('set', 'k','v');`, "0") + require.NoError(t, r.Err()) + + // Scripts without #! can run commands that access keys belonging to different cluster hash slots, + // but ones with #! inherit the default flags, so they cannot. + r = rdb0.Do(ctx, "EVAL", `return redis.call('set', 'k','v');`, "0") + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `#!lua + return redis.call('set', 'k','v');`, "0") + require.NoError(t, r.Err()) + }) + + t.Run("allow-cross-slot-keys", func(t *testing.T) { + // Node0: bar-slot = 5061, test-slot = 6918 + // Node1: foo-slot = 12182 + // Different slots of different nodes are not affected by allow-cross-slot-keys, + // and different slots of the same node can be allowed + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=allow-cross-slot-keys + redis.call('set', 'bar','value_bar'); + return redis.call('set', 'test', 'value_test');`, "0") + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=allow-cross-slot-keys + redis.call('set', 'foo','value_foo'); + return redis.call('set', 'bar', 'value_bar');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + + // There is a shebang prefix #!lua but crossslot is not allowed when flags are not set + r = rdb0.Do(ctx, "EVAL", + `#!lua + redis.call('get', 'bar'); + return redis.call('get', 'test');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access keys that do not hash to the same slot") + + r = rdb0.Do(ctx, "EVAL", + `#!lua + redis.call('get', 'foo'); + return redis.call('get', 'bar');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + + // Old style: CrossSlot is allowed when there is neither #!lua nor flags set + r = rdb0.Do(ctx, "EVAL", + `redis.call('get', 'bar'); + return redis.call('get', 'test');`, "0") + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `redis.call('get', 'foo'); + return redis.call('get', 'bar');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + + // Pre-declared keys are not affected by allow-cross-slot-keys + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=allow-cross-slot-keys + local key = redis.call('get', KEY[1]); + return redis.call('get', KEY[2]);`, "2", "bar", "test") + require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot") + }) + + t.Run("mixed use", func(t *testing.T) { + r := rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,no-cluster + return redis.call('get', 'key_a');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set") + + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes,no-cluster + return redis.call('set', 'key_a', 'value_a');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + err := rdb.Set(ctx, "key_a", "value_a", 0).Err() + require.NoError(t, err) + r = rdb.Do(ctx, "EVAL", + `#!lua flags=no-writes,no-cluster + return redis.call('get', 'key_a');`, "0") + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,allow-cross-slot-keys + redis.call('get', 'bar'); + return redis.call('get', 'test');`, "0") + require.NoError(t, r.Err()) + + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,allow-cross-slot-keys + redis.call('set', 'bar'); + return redis.call('set', 'test');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts") + + r = rdb0.Do(ctx, "EVAL", + `#!lua flags=no-writes,allow-cross-slot-keys + redis.call('get', 'bar'); + return redis.call('get', 'foo');`, "0") + util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script") + + }) +}