You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							50 lines
						
					
					
						
							1.3 KiB
						
					
					
				
			
		
		
	
	
							50 lines
						
					
					
						
							1.3 KiB
						
					
					
				#!/usr/bin/env python3
 | 
						|
# TODO: why are the keras models saved with python 2?
 | 
						|
from __future__ import print_function
 | 
						|
 | 
						|
import tensorflow as tf  # pylint: disable=import-error
 | 
						|
import os
 | 
						|
import sys
 | 
						|
import numpy as np
 | 
						|
from tensorflow.keras.models import load_model  # pylint: disable=import-error
 | 
						|
 | 
						|
def read(sz, fd=0):
  """Read exactly ``sz`` float32 values from a file descriptor.

  Blocks until ``sz * 4`` bytes have arrived, accumulating partial reads.

  Args:
    sz: number of float32 values to read.
    fd: file descriptor to read from (defaults to 0, stdin).

  Returns:
    A 1-D ``np.float32`` array of length ``sz`` backed by the received
    bytes (read-only, since it is a view over an immutable ``bytes``).

  Raises:
    EOFError: if the descriptor hits end-of-file before ``sz`` values
      have been received.
  """
  want = sz * 4  # float32 is 4 bytes wide
  chunks = []
  got = 0
  while got < want:
    chunk = os.read(fd, want - got)
    # os.read returns b'' at EOF. An explicit raise (instead of assert)
    # still fires under `python -O`, where the assert would be stripped
    # and this loop would spin forever.
    if not chunk:
      raise EOFError("expected %d bytes, got %d before EOF" % (want, got))
    chunks.append(chunk)
    got += len(chunk)
  return np.frombuffer(b''.join(chunks), dtype=np.float32)


def write(d, fd=1):
  """Write the raw bytes of array ``d`` to a file descriptor.

  Args:
    d: a NumPy array; its C-order bytes (``d.tobytes()``) are written
      as-is, with no dtype conversion.
    fd: file descriptor to write to (defaults to 1, stdout).
  """
  buf = memoryview(d.tobytes())
  while buf:
    # os.write may accept fewer bytes than offered (a partial write), so
    # keep sending the remainder; otherwise the reader would see a
    # truncated record.
    n = os.write(fd, buf)
    buf = buf[n:]


def run_loop(m):
  """Serve Keras model ``m`` over stdin/stdout until input fails.

  Each iteration reads one float32 tensor per entry of ``m.inputs`` (in
  that order, each with batch dimension 1) via ``read``, runs
  ``m.predict_on_batch`` and writes every output's raw bytes via ``write``.
  The loop only ends when ``read`` (or the model) raises.

  Args:
    m: a loaded Keras model. Assumes every non-batch input dimension is
      static (not None) -- TODO confirm for all models fed to this runner.
  """
  # Per-input shape with the batch dimension pinned to 1. Unpacking into a
  # plain list works whether ``ii.shape`` is a TensorShape or a tuple
  # (Keras 3); ``[1] + tuple`` would raise TypeError.
  ishapes = [[1, *ii.shape[1:]] for ii in m.inputs]
  # Element counts never change, so compute them once. np.prod is used
  # because np.product was removed in NumPy 2.0.
  sizes = [int(np.prod(shp)) for shp in ishapes]
  print("ready to run keras model", ishapes, file=sys.stderr)

  while True:
    inputs = [read(ts).reshape(shp) for ts, shp in zip(sizes, ishapes)]
    ret = m.predict_on_batch(inputs)
    # predict_on_batch returns one array for a single-output model and a
    # list for several outputs. Normalize to a list so each output is
    # written whole, rather than iterating over batch rows.
    # NOTE(review): outputs are written in their native dtype; confirm the
    # consumer expects float32 (the same dtype as the inputs).
    outputs = ret if isinstance(ret, (list, tuple)) else [ret]
    for r in outputs:
      write(r)


if __name__ == "__main__":
  print(tf.__version__, file=sys.stderr)

  # Fail fast with a usage message instead of an IndexError traceback.
  if len(sys.argv) < 2:
    sys.exit("usage: %s <keras_model_path>" % sys.argv[0])

  # Cap TensorFlow's GPU RAM use on the first GPU at 2048 MB (memory_limit
  # is in MB), instead of TF's default of reserving most of the card.
  # This must happen before the GPU is initialized, i.e. before the model
  # is loaded.
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpus:
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)])

  model = load_model(sys.argv[1])
  print(model, file=sys.stderr)

  run_loop(model)