Skip to content

Commit cf1221a

Browse files
committed
Add the option to set the client name
1 parent cec3cda commit cf1221a

3 files changed

Lines changed: 14 additions & 5 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ vagrant/.vagrant
1414
.vscode/
1515
*.iml
1616
.pytest_cache/
17+
*.so

aredis/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(self, host='localhost', port=6379,
100100
ssl_cert_reqs=None, ssl_ca_certs=None,
101101
max_connections=None, retry_on_timeout=False,
102102
max_idle_time=0, idle_check_interval=1,
103+
client_name=None,
103104
loop=None, **kwargs):
104105
if not connection_pool:
105106
kwargs = {
@@ -113,6 +114,7 @@ def __init__(self, host='localhost', port=6379,
113114
'decode_responses': decode_responses,
114115
'max_idle_time': max_idle_time,
115116
'idle_check_interval': idle_check_interval,
117+
'client_name': client_name,
116118
'loop': loop
117119
}
118120
# based on input, setup appropriate connection args

aredis/connection.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ class BaseConnection:
367367
def __init__(self, retry_on_timeout=False, stream_timeout=None,
368368
parser_class=DefaultParser, reader_read_size=65535,
369369
encoding='utf-8', decode_responses=False,
370-
*, loop=None):
370+
*, loop=None, client_name=None):
371371
self._parser = parser_class(reader_read_size)
372372
self._stream_timeout = stream_timeout
373373
self._reader = None
@@ -381,6 +381,7 @@ def __init__(self, retry_on_timeout=False, stream_timeout=None,
381381
self.encoding = encoding
382382
self.decode_responses = decode_responses
383383
self.loop = loop
384+
self.client_name = client_name
384385
# flag to show if a connection is waiting for response
385386
self.awaiting_response = False
386387
self.last_active_at = time.time()
@@ -442,6 +443,11 @@ async def on_connect(self):
442443
await self.send_command('SELECT', self.db)
443444
if nativestr(await self.read_response()) != 'OK':
444445
raise ConnectionError('Invalid Database')
446+
447+
if self.client_name:
448+
await self.send_command('CLIENT SETNAME', self.client_name)
449+
if nativestr(await self.read_response()) != 'OK':
450+
raise ConnectionError('Failed to set client name: {}'.format(self.client_name))
445451
self.last_active_at = time.time()
446452

447453
async def read_response(self):
@@ -571,11 +577,11 @@ def __init__(self, host='127.0.0.1', port=6379, password=None,
571577
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
572578
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
573579
encoding='utf-8', decode_responses=False, socket_keepalive=None,
574-
socket_keepalive_options=None, *, loop=None):
580+
socket_keepalive_options=None, *, loop=None, client_name=None):
575581
super(Connection, self).__init__(retry_on_timeout, stream_timeout,
576582
parser_class, reader_read_size,
577583
encoding, decode_responses,
578-
loop=loop)
584+
loop=loop, client_name=client_name)
579585
self.host = host
580586
self.port = port
581587
self.password = password
@@ -624,11 +630,11 @@ class UnixDomainSocketConnection(BaseConnection):
624630
def __init__(self, path='', password=None,
625631
db=0, retry_on_timeout=False, stream_timeout=None, connect_timeout=None,
626632
ssl_context=None, parser_class=DefaultParser, reader_read_size=65535,
627-
encoding='utf-8', decode_responses=False, *, loop=None):
633+
encoding='utf-8', decode_responses=False, *, loop=None, client_name=None):
628634
super(UnixDomainSocketConnection, self).__init__(retry_on_timeout, stream_timeout,
629635
parser_class, reader_read_size,
630636
encoding, decode_responses,
631-
loop=loop)
637+
loop=loop, client_name=client_name)
632638
self.path = path
633639
self.db = db
634640
self.password = password

0 commit comments

Comments
 (0)