You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
3.5 KiB
108 lines
3.5 KiB
import time, atexit, uuid
|
|
from enum import Enum
|
|
|
|
from tinygrad.device import Device
|
|
from tinygrad.helpers import DEBUG, ContextVar, getenv, GlobalCounters
|
|
|
|
# Measurement name used for InfluxDB uploads (passed to Point(...) below);
# when falsy -- the default is "" -- benchmark logging is disabled.
BENCHMARK_LOG = ContextVar("BENCHMARK_LOG", "")

if BENCHMARK_LOG:
  # The InfluxDB client is imported only when logging is enabled, so the package
  # is not required for ordinary runs.
  from influxdb_client_3 import (
    InfluxDBClient3,
    Point,
    WriteOptions,
    write_client_options,
  )
  from influxdb_client_3.write_client.client.write_api import WriteType
|
|
|
|
class BenchEvent(Enum):
  """Benchmark events that are timed over a span.

  Each member owns two sample lists in the event registry: "wall" (appended to by
  WallTimeEvent) and "kernel" (appended to by KernelTimeEvent). On upload the member's
  value prefixes the field name, e.g. "step_wall" / "step_kernel".
  """
  LOAD_WEIGHTS = "load_weights"
  STEP = "step"
  FULL = "full"
|
|
class InstantBenchEvent(Enum):
  """Benchmark values recorded at a single instant rather than timed over a span.

  Each member owns one flat sample list in the event registry, appended to by
  log_event_instant; on upload the member's value is used directly as the field name.
  """
  GFLOPS = "gflops"
|
|
|
|
# Registry of recorded samples, keyed by enum member:
#   BenchEvent member        -> {"wall": [...], "kernel": [...]}
#   InstantBenchEvent member -> [...]
_events = {}

def clear_events():
  """Reset the registry so every known event maps to fresh, empty sample list(s)."""
  # comprehensions build a new list per key, so no two events share a list
  _events.update({span_event: {"wall": [], "kernel": []} for span_event in BenchEvent})
  _events.update({instant_event: [] for instant_event in InstantBenchEvent})

# populate the registry at import time so the context managers can append immediately
clear_events()
|
|
|
|
class WallTimeEvent:
  """Context manager that records the wall-clock duration of its ``with`` body.

  On exit, the elapsed seconds are appended to ``_events[event]["wall"]``.
  Usage: ``with WallTimeEvent(BenchEvent.STEP): ...``
  """
  def __init__(self, event:BenchEvent):
    self.event = event
  def __enter__(self):
    # perf_counter is the clock documented for measuring short durations; time.monotonic
    # can be much coarser (about 15.6 ms on Windows before Python 3.13)
    self.start = time.perf_counter()
    return self
  def __exit__(self, *_):
    # recorded even if the body raised; returning False lets any exception propagate
    _events[self.event]["wall"].append(time.perf_counter() - self.start)
    return False
|
|
|
|
class KernelTimeEvent:
  """Context manager that records kernel time accumulated during its ``with`` body.

  The sample is the increase of ``GlobalCounters.time_sum_s`` between enter and exit,
  appended to ``_events[event]["kernel"]``.

  Raises:
    RuntimeError: if constructed while DEBUG < 2 (a subclass of Exception, so existing
      ``except Exception`` handlers still catch it).
  """
  def __init__(self, event:BenchEvent):
    # NOTE(review): presumably GlobalCounters.time_sum_s is only accumulated at DEBUG >= 2,
    # so recording at lower levels would produce meaningless samples -- confirm
    if DEBUG < 2:
      raise RuntimeError("KernelTimeEvent should only be used in DEBUG >= 2")
    self.event = event
  def __enter__(self):
    self.start = GlobalCounters.time_sum_s
    return self
  def __exit__(self, *_):
    # recorded even if the body raised; returning False lets any exception propagate
    _events[self.event]["kernel"].append(GlobalCounters.time_sum_s - self.start)
    return False
|
|
|
|
def log_event_instant(event:InstantBenchEvent, value:float):
  """Append one point-in-time sample ``value`` to the list kept for ``event``."""
  samples = _events[event]
  samples.append(value)
|
|
|
|
if BENCHMARK_LOG:
  # InfluxDB connection settings; only read when benchmark logging is enabled.
  INFLUXDB_HOST = getenv("INFLUXDB_HOST", "")
  INFLUXDB_ORG = getenv("INFLUXDB_ORG", "tiny")
  INFLUXDB_TOKEN = getenv("INFLUXDB_TOKEN", "")

  def _create_point(run_id, i, attempt, ref, commit, name, value, run):
    """Build one InfluxDB point in the BENCHMARK_LOG measurement.

    Tags: id=run_id, index=i, device=Device.DEFAULT, attempt, ref, commit.
    Fields: ``name`` -> value, and ``x`` -> run.
    """
    point = Point(BENCHMARK_LOG.value).tag("id", run_id).tag("index", i)
    point = point.tag("device", Device.DEFAULT)
    point = point.tag("attempt", attempt).tag("ref", ref).tag("commit", commit)
    return point.field(name, value).field("x", run)

  def _has_events():
    """Return True if at least one sample has been recorded for any event."""
    for event, recorded in _events.items():
      # BenchEvent entries hold a dict of sample lists; InstantBenchEvent entries hold one list
      sample_lists = recorded.values() if isinstance(event, BenchEvent) else (recorded,)
      if any(len(samples) > 0 for samples in sample_lists): return True
    return False

  def _collect_points(attempt, ref, commit, run):
    """Turn every recorded sample into an InfluxDB point, in registry order."""
    points = []
    for event, recorded in _events.items():
      # NOTE(review): a new id is generated per event, not once per process -- confirm this grouping is intended
      run_id = str(uuid.uuid4())
      if isinstance(event, BenchEvent):
        # one named series per sample kind, e.g. "step_wall" and "step_kernel"
        named_series = [(f"{event.value}_{event_type}", values) for event_type, values in recorded.items()]
      else:
        named_series = [(event.value, recorded)]
      for name, values in named_series:
        points.extend(_create_point(run_id, i, attempt, ref, commit, name, value, run) for i, value in enumerate(values))
    return points

  @atexit.register
  def write_events():
    """At interpreter exit, upload every recorded sample to the "benchmarks" InfluxDB database.

    Returns immediately when nothing was recorded. CI metadata is read from the
    GITHUB_* environment variables, falling back to "" when unset.
    """
    if not _has_events(): return

    # pull from github envvars
    ref = getenv("GITHUB_REF_NAME", "")
    commit = getenv("GITHUB_SHA", "")
    run = getenv("GITHUB_RUN_NUMBER", "")
    attempt = getenv("GITHUB_RUN_ATTEMPT", "")

    points = _collect_points(attempt, ref, commit, run)

    # synchronous writes with up to 5 retries and exponential backoff (intervals presumably in ms)
    write_options = WriteOptions(write_type=WriteType.synchronous, retry_interval=5000, max_retries=5, max_retry_delay=30000, exponential_base=2)
    wco = write_client_options(write_options=write_options)
    with InfluxDBClient3(
      host=INFLUXDB_HOST,
      org=INFLUXDB_ORG,
      token=INFLUXDB_TOKEN,
      auth_scheme="Basic",
      database="benchmarks",
      write_client_options=wco) as client:
      client.write(points)
|