neondaniel commited on
Commit
640efa7
1 Parent(s): f1a3e74

Update to print disallowed endpoints in-place in the model list

Browse files

Update configuration handling to put all clients in `clients` with backwards-compat. parsing

Troubleshoot radio button rendering

Refactor permissions configuration to support other oauth methods

Files changed (2) hide show
  1. app.py +74 -34
  2. shared.py +16 -1
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
  import json
 
 
3
  import gradio as gr
4
 
5
  import uvicorn
@@ -11,7 +13,7 @@ from starlette.responses import RedirectResponse
11
  from authlib.integrations.starlette_client import OAuth, OAuthError
12
  from fastapi import FastAPI, Request
13
 
14
- from shared import Client
15
 
16
  app = FastAPI()
17
  config = {}
@@ -51,41 +53,44 @@ def init_config():
51
  global clients
52
  global llm_host_names
53
  config = json.loads(os.environ['CONFIG'])
54
- reserved_keys = ("huggingface_text", "allowed_domains_override")
55
- for name in config:
56
- if name in reserved_keys:
57
- continue
58
- model_personas = config[name].get("personas", {})
59
  client = Client(
60
- api_url=os.environ.get(config[name]['api_url'],
61
- config[name]['api_url']),
62
- api_key=os.environ.get(config[name]['api_key'],
63
- config[name]['api_key']),
64
  personas=model_personas
65
  )
66
  clients[name] = client
67
- llm_host_names = list(config.keys())
68
 
69
 
70
- def get_allowed_models(user_domain: str) -> List[str]:
71
  """
72
  Get a list of allowed endpoints for a specified user domain. Allowed domains
73
  are configured in each model's configuration and may optionally be overridden
74
  in the Gradio demo configuration.
75
- :param user_domain: User domain (i.e. neon.ai, google.com, guest)
76
- :return: List of allowed endpoints from configuration
 
77
  """
78
- overrides = config.get("allowed_domains_override", {})
79
  allowed_endpoints = []
80
  for client in clients:
81
- allowed_domains = overrides.get(client,
82
- clients[client].config.inference.allowed_domains)
83
- if allowed_domains is None:
84
- # Allowed domains not specified; model is public
85
  allowed_endpoints.append(client)
86
- elif user_domain in allowed_domains:
87
- # User domain is in the allowed domain list
 
88
  allowed_endpoints.append(client)
 
 
 
89
  return allowed_endpoints
90
 
91
 
@@ -107,7 +112,7 @@ def get_login_button(request: gr.Request) -> gr.Button:
107
  :param request: Gradio request to evaluate
108
  :return: Button for either login or logout action
109
  """
110
- user = get_user(request)
111
  print(f"Getting login button for {user}")
112
 
113
  if user == "guest":
@@ -116,15 +121,39 @@ def get_login_button(request: gr.Request) -> gr.Button:
116
  return gr.Button(f"Logout {user}", link="/logout")
117
 
118
 
119
- def get_user(request: Request) -> str:
120
  """
121
  Get a unique user email address for the specified request
122
  :param request: FastAPI Request object with user session data
123
  :return: String user email address or "guest"
124
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  if not request:
126
- return "guest"
127
- user = request.session.get('user', {}).get('email') or "guest"
 
 
 
 
 
 
 
 
 
128
  return user
129
 
130
 
@@ -232,25 +261,25 @@ def get_model_options(request: gr.Request) -> List[gr.Radio]:
232
  # `user` is a valid Google email address or 'guest'
233
  user = get_user(request.request)
234
  else:
235
- user = "guest"
236
- print(f"Getting models for {user}")
237
 
238
- domain = "guest" if user == "guest" else user.split('@')[1]
239
- allowed_llm_host_names = get_allowed_models(domain)
240
 
241
  radio_infos = [f"{name} ({clients[name].vllm_model_name})"
 
242
  for name in allowed_llm_host_names]
243
  # Components
244
- radios = [gr.Radio(choices=clients[name].personas.keys(),
245
  value=None, label=info) for name, info
246
  in zip(allowed_llm_host_names, radio_infos)]
247
 
248
  # Select the first available option by default
249
  radios[0].value = list(clients[allowed_llm_host_names[0]].personas.keys())[0]
250
  print(f"Set default persona to {radios[0].value} for {allowed_llm_host_names[0]}")
251
- # Ensure we always have the same number of rows
252
- while len(radios) < len(llm_host_names):
253
- radios.append(gr.Radio(choices=[], value=None, label="Not Authorized"))
254
  return radios
255
 
256
 
@@ -271,6 +300,17 @@ def init_gradio() -> gr.Blocks:
271
  @gr.on(triggers=[blocks.load, *[radio.input for radio in radios]],
272
  inputs=[radio_state, *radios], outputs=[radio_state, *radios])
273
  def radio_click(state, *new_state):
 
 
 
 
 
 
 
 
 
 
 
274
  try:
275
  changed_index = next(i for i in range(len(state))
276
  if state[i] != new_state[i])
@@ -326,5 +366,5 @@ if __name__ == "__main__":
326
  init_config()
327
  init_oauth()
328
  blocks = init_gradio()
329
- app = gr.mount_gradio_app(app, blocks, '/', auth_dependency=get_user)
330
  uvicorn.run(app, host='0.0.0.0', port=7860)
 
1
  import os
2
  import json
3
+ from time import sleep
4
+
5
  import gradio as gr
6
 
7
  import uvicorn
 
13
  from authlib.integrations.starlette_client import OAuth, OAuthError
14
  from fastapi import FastAPI, Request
15
 
16
+ from shared import Client, User, OAuthProvider
17
 
18
  app = FastAPI()
19
  config = {}
 
53
  global clients
54
  global llm_host_names
55
  config = json.loads(os.environ['CONFIG'])
56
+ client_config = config.get("clients") or config
57
+ for name in client_config:
58
+ model_personas = client_config[name].get("personas", {})
 
 
59
  client = Client(
60
+ api_url=os.environ.get(client_config[name]['api_url'],
61
+ client_config[name]['api_url']),
62
+ api_key=os.environ.get(client_config[name]['api_key'],
63
+ client_config[name]['api_key']),
64
  personas=model_personas
65
  )
66
  clients[name] = client
67
+ llm_host_names = list(client_config.keys())
68
 
69
 
70
+ def get_allowed_models(user: User) -> List[str]:
71
  """
72
  Get a list of allowed endpoints for a specified user domain. Allowed domains
73
  are configured in each model's configuration and may optionally be overridden
74
  in the Gradio demo configuration.
75
+ :param user: User to get permissions for
76
+ :return: List of allowed endpoints from configuration (including empty
77
+ strings for disallowed endpoints)
78
  """
79
+ overrides = config.get("permissions_override", {})
80
  allowed_endpoints = []
81
  for client in clients:
82
+ permission = overrides.get(client,
83
+ clients[client].config.inference.permissions)
84
+ if not permission:
85
+ # Permissions not specified (None or empty dict); model is public
86
  allowed_endpoints.append(client)
87
+ elif user.oauth == OAuthProvider.GOOGLE and user.permissions_id in \
88
+ permission.get("google_domains", []):
89
+ # Google oauth domain is in the allowed domain list
90
  allowed_endpoints.append(client)
91
+ else:
92
+ allowed_endpoints.append("")
93
+ print(f"No permission to access {client}")
94
  return allowed_endpoints
95
 
96
 
 
112
  :param request: Gradio request to evaluate
113
  :return: Button for either login or logout action
114
  """
115
+ user = get_user(request).username
116
  print(f"Getting login button for {user}")
117
 
118
  if user == "guest":
 
121
  return gr.Button(f"Logout {user}", link="/logout")
122
 
123
 
124
+ def get_user(request: Request) -> User:
125
  """
126
  Get a unique user email address for the specified request
127
  :param request: FastAPI Request object with user session data
128
  :return: String user email address or "guest"
129
  """
130
+ # {'iss': 'https://accounts.google.com',
131
+ # 'azp': '***.apps.googleusercontent.com',
132
+ # 'aud': '***.apps.googleusercontent.com',
133
+ # 'sub': '###',
134
+ # 'hd': 'neon.ai',
135
+ # 'email': '[email protected]',
136
+ # 'email_verified': True,
137
+ # 'at_hash': '***',
138
+ # 'nonce': '***',
139
+ # 'name': 'Daniel McKnight',
140
+ # 'picture': 'https://lh3.googleusercontent.com/a/***',
141
+ # 'given_name': '***',
142
+ # 'family_name': '***',
143
+ # 'iat': ###,
144
+ # 'exp': ###}
145
  if not request:
146
+ return User(OAuthProvider.NONE, "guest", "")
147
+
148
+ user_dict = request.session.get("user", {})
149
+ if user_dict.get("iss") == "https://accounts.google.com":
150
+ user = User(OAuthProvider.GOOGLE, user_dict["email"], user_dict["hd"])
151
+ elif user_dict:
152
+ print(f"Unknown user session data: {user_dict}")
153
+ user = User(OAuthProvider.NONE, "guest", "")
154
+ else:
155
+ user = User(OAuthProvider.NONE, "guest", "")
156
+ print(user)
157
  return user
158
 
159
 
 
261
  # `user` is a valid Google email address or 'guest'
262
  user = get_user(request.request)
263
  else:
264
+ user = User(OAuthProvider.NONE, "guest", "")
265
+ print(f"Getting models for {user.username}")
266
 
267
+ allowed_llm_host_names = get_allowed_models(user)
 
268
 
269
  radio_infos = [f"{name} ({clients[name].vllm_model_name})"
270
+ if name in clients else "Not Authorized"
271
  for name in allowed_llm_host_names]
272
  # Components
273
+ radios = [gr.Radio(choices=clients[name].personas.keys() if name in clients else [],
274
  value=None, label=info) for name, info
275
  in zip(allowed_llm_host_names, radio_infos)]
276
 
277
  # Select the first available option by default
278
  radios[0].value = list(clients[allowed_llm_host_names[0]].personas.keys())[0]
279
  print(f"Set default persona to {radios[0].value} for {allowed_llm_host_names[0]}")
280
+ # # Ensure we always have the same number of rows
281
+ # while len(radios) < len(llm_host_names):
282
+ # radios.append(gr.Radio(choices=[], value=None, label="Not Authorized"))
283
  return radios
284
 
285
 
 
300
  @gr.on(triggers=[blocks.load, *[radio.input for radio in radios]],
301
  inputs=[radio_state, *radios], outputs=[radio_state, *radios])
302
  def radio_click(state, *new_state):
303
+ """
304
+ Handle any state changes that require re-rendering radio buttons
305
+ :param state: Previous radio state representation (before selection)
306
+ :param new_state: Current radio state (including selection)
307
+ :return: Desired new state (current option selected, previous option
308
+ deselected)
309
+ """
310
+ # Login and model options are triggered on load. This sleep is just
311
+ # a hack to make sure those events run before this logic to select
312
+ # the default model
313
+ sleep(0.1)
314
  try:
315
  changed_index = next(i for i in range(len(state))
316
  if state[i] != new_state[i])
 
366
  init_config()
367
  init_oauth()
368
  blocks = init_gradio()
369
+ app = gr.mount_gradio_app(app, blocks, '/')
370
  uvicorn.run(app, host='0.0.0.0', port=7860)
shared.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import yaml
2
 
3
  from typing import Dict, Optional, List
@@ -8,6 +11,18 @@ from huggingface_hub.utils import EntryNotFoundError
8
  from openai import OpenAI
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class PileConfig(BaseModel):
12
  file2persona: Dict[str, str]
13
  file2prefix: Dict[str, str]
@@ -17,7 +32,7 @@ class PileConfig(BaseModel):
17
 
18
  class InferenceConfig(BaseModel):
19
  chat_template: str
20
- allowed_domains: Optional[List[str]] = None
21
 
22
 
23
  class RepoConfig(BaseModel):
 
1
+ from dataclasses import dataclass
2
+ from enum import IntEnum
3
+
4
  import yaml
5
 
6
  from typing import Dict, Optional, List
 
11
  from openai import OpenAI
12
 
13
 
14
+ class OAuthProvider(IntEnum):
15
+ NONE = 0
16
+ GOOGLE = 1
17
+
18
+
19
+ @dataclass
20
+ class User:
21
+ oauth: OAuthProvider
22
+ username: str
23
+ permissions_id: str
24
+
25
+
26
  class PileConfig(BaseModel):
27
  file2persona: Dict[str, str]
28
  file2prefix: Dict[str, str]
 
32
 
33
  class InferenceConfig(BaseModel):
34
  chat_template: str
35
+ permissions: Dict[str, list] = {}
36
 
37
 
38
  class RepoConfig(BaseModel):