# Source code for tensorcircuit.cloud.utils
"""
utility functions for cloud connection
"""
from typing import Any, Callable, Optional
from functools import wraps
import inspect
import logging
import os
import sys
import time
import requests
# from simplejson.errors import JSONDecodeError
# Handle to this very module object; ``set_proxy`` stores the global ``proxy``
# attribute on it and ``reconnect`` reads it back.
thismodule = sys.modules[__name__]
# Module-level logger, named after the module.
logger = logging.getLogger(name=__name__)
[docs]
class HttpStatusError(Exception):
    """
    Raised when a response's HTTP status code is anything other than 200.

    It is part of ``connection_errors``, so ``reconnect`` retries on it.
    """
# TODO(@refraction-ray): whether an exception hierarchy for tc is necessary?
# Exception types that ``reconnect`` treats as transient and retries on.
connection_errors = (
    ConnectionResetError,  # builtin socket-level reset
    HttpStatusError,  # raised by ``reconnect`` for a non-200 status code
    requests.exceptions.RequestException,  # base class of requests' errors
    requests.exceptions.ConnectionError,  # a RequestException subclass
    requests.exceptions.SSLError,  # a subclass of requests' ConnectionError
    # presumably meant to cover JSON decoding failures from ``.json()``
    # (cf. the commented-out JSONDecodeError import at the top of the file)
    ValueError,
)
[docs]
def set_proxy(proxy: Optional[str] = None) -> None:
    """
    Turn the HTTP(S) proxy used for cloud requests on or off.

    :param proxy: proxy URL formatted as "http://user:passwd@host:port"; the
        ``user:passwd@`` part can be omitted if not needed. ``None`` (or any
        falsy value) turns the proxy off.
    :return: None. As side effects, the ``http_proxy`` and ``https_proxy``
        environment variables are set (to ``""`` when turning the proxy off)
        and the module attribute ``proxy`` is updated (``None`` when off).
    """
    env_value = proxy if proxy else ""
    for env_var in ("http_proxy", "https_proxy"):
        os.environ[env_var] = env_value
    setattr(thismodule, "proxy", proxy if proxy else None)
[docs]
def reconnect(tries: int = 5, timeout: int = 12) -> Callable[..., Any]:
    """
    Build a decorator that retries an HTTP-calling function on transient errors.

    The decorated function (e.g. ``requests.get``) is called with:

    * ``timeout`` -- the decorator default, unless the caller passed its own;
    * ``headers`` -- a copy of the caller's headers, with ``user-agent`` set to
      ``"Mozilla/5.0"`` when no non-empty user agent is present;
    * ``proxies`` -- the module-level ``proxy`` (set by ``set_proxy``), if any.

    Any exception in ``connection_errors`` triggers a retry after sleeping
    ``0.5 * attempt_index`` seconds. This includes ``HttpStatusError``, which is
    raised when the result has a ``status_code`` other than 200. The final
    failure is logged and re-raised unchanged.

    :param tries: maximum number of attempts; must be at least 1.
    :param timeout: default value passed as the ``timeout`` keyword argument.
    :return: a decorator that adds the retry behaviour to a callable.
    :raises ValueError: if ``tries`` is less than 1.
    """
    # wrapper originally designed in xalpha by @refraction-ray
    # https://github.com/refraction-ray/xalpha
    if tries < 1:
        # otherwise the retry loop never runs and the wrapper returns None
        raise ValueError("tries must be at least 1, got %s" % tries)

    def robustify(f: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(f)
        def wrapper(*args: Any, **kws: Any) -> Any:
            proxy = getattr(thismodule, "proxy", None)
            if proxy:
                # the global proxy from ``set_proxy`` overrides per-call proxies
                kws["proxies"] = {"http": proxy, "https": proxy}
                logger.debug("Using proxy %s", proxy)
            # an explicit per-call timeout wins over the decorator default
            kws.setdefault("timeout", timeout)
            url = args[0] if args else kws.get("url", "")
            # copy so the caller's dict is never mutated; tolerate headers=None
            headers = dict(kws.get("headers") or {})
            has_agent = any(
                str(key).lower() == "user-agent" and value
                for key, value in headers.items()
            )
            if not has_agent:
                headers["user-agent"] = "Mozilla/5.0"
            kws["headers"] = headers

            # Name of the function that called ``wrapper``. It is only looked
            # up when debug logging is enabled, since frame inspection has a
            # cost; ``currentframe`` may return None on non-CPython runtimes.
            caller = "<unknown>"
            if logger.isEnabledFor(logging.DEBUG):
                frame = inspect.currentframe()
                try:
                    if frame is not None and frame.f_back is not None:
                        caller = frame.f_back.f_code.co_name
                finally:
                    del frame  # do not keep a frame reference alive

            for count in range(tries):
                try:
                    logger.debug(
                        "Fetching url: %s . Inside function `%s`", url, caller
                    )
                    r = f(*args, **kws)
                    # results without ``status_code`` (e.g. a decoded JSON
                    # dict) are treated as successful
                    status = getattr(r, "status_code", 200)
                    if status != 200:
                        raise HttpStatusError(
                            "HTTP status %s when fetching %s" % (status, url)
                        )
                    return r
                except connection_errors as e:
                    logger.warning("Fails at fetching url: %s. Try again.", url)
                    if count == tries - 1:
                        logger.error(
                            "Still wrong at fetching url: %s. after %s tries.",
                            url,
                            tries,
                        )
                        # ``e.args`` may be empty; indexing it blindly would
                        # replace ``e`` with an IndexError
                        logger.error(
                            "Fails due to %s", e.args[0] if e.args else repr(e)
                        )
                        raise
                    time.sleep(0.5 * count)

        return wrapper

    return robustify
# Retrying drop-in replacements for ``requests.get`` / ``requests.post``; each
# one gets its own ``reconnect()`` decorator with the default settings.
rget, rpost = reconnect()(requests.get), reconnect()(requests.post)
[docs]
@reconnect()
def rget_json(*args: Any, **kws: Any) -> Any:
    """
    Issue a GET request with the default ``reconnect`` retry policy and return the decoded JSON body.

    Arguments go to ``requests.get`` after ``reconnect`` has injected
    ``timeout``, ``headers`` and (if configured) ``proxies``. Because the
    returned value is the decoded body rather than a response object, the
    wrapper's 200-status check does not apply to it.
    """
    return requests.get(*args, **kws).json()
[docs]
@reconnect()
def rpost_json(*args: Any, **kws: Any) -> Any:
    """
    Issue a POST request with the default ``reconnect`` retry policy and return the decoded JSON body.

    Arguments go to ``requests.post`` after ``reconnect`` has injected
    ``timeout``, ``headers`` and (if configured) ``proxies``. Because the
    returned value is the decoded body rather than a response object, the
    wrapper's 200-status check does not apply to it.
    """
    return requests.post(*args, **kws).json()