Title: Disparate Performance between Python and Java · Issue #602 · tensorflow/java · GitHub
Opengraph URL: https://github.com/tensorflow/java/issues/602
*Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template*

**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04 x86_64): Ubuntu 24.04 x86_64
- TensorFlow installed from (source or binary): binary
- TensorFlow version: 1.0.0
- Java version (i.e., the output of `java -version`): openjdk version "21.0.6" 2025-01-21
- Java command line flags (e.g., GC parameters):
- Python version (if transferring a model trained in Python): 3.12.8
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version: 12.8.61/8905
- GPU model and memory: V100 (32GB)

**Describe the current behavior**

Executing the exported model with TensorFlow in Python takes significantly less time than calling the same function from TensorFlow Java.
I suspect that I am just not using the Java API correctly, because a small change to the Python code can lead to comparably poor performance in Python.

**Describe the expected behavior**

The function calls should take a comparable amount of time.

**Code to reproduce the issue**

I have the following Python function:

```python
import tensorflow as tf

@tf.function(
    input_signature=[
        tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="data"),   # [k, n, n]
        tf.TensorSpec(shape=[1, 2048, 2048], dtype=tf.float32, name="image"),   # [1, n, n]
        tf.TensorSpec(shape=[41, 2048, 2048], dtype=tf.float32, name="psf"),    # [k, n, n]
    ],
    jit_compile=True,
)
def rl_step(
    data: tf.Tensor,   # [k, n, n]
    image: tf.Tensor,  # [1, n, n]
    psf: tf.Tensor,    # [k, n, n]
) -> tf.Tensor:        # [k, n, n]
    psf_fft = tf.signal.rfft2d(psf)
    psft_fft = tf.signal.rfft2d(tf.reverse(psf, axis=(-2, -1)))
    denom = tf.reduce_sum(
        tf.signal.irfft2d(psf_fft * tf.signal.rfft2d(data)),
        axis=0,
        keepdims=True,
    )
    img_err = image / denom
    return data * tf.signal.irfft2d(tf.signal.rfft2d(img_err) * psft_fft)
```

In Python, I apply this function repeatedly to the same input tensors:

```python
from time import time

image_tensor = tf.constant(image)                # [1, n, n]
measured_psf_tensor = tf.constant(measured_psf)  # [k, n, n]
data_tensor = tf.constant(data)                  # [k, n, n]

for i in range(10):
    start = time()
    # Every iteration is applied to the same input, data_tensor.
    data = rl_step(data_tensor, image_tensor, measured_psf_tensor)
    print(f"Iter {i}:", time() - start, "seconds.")
```

Here `image`, `measured_psf`, and `data` are all 3D float32 arrays with `n=2048` and `k=41`.

This prints timings similar to the following:

```text
Iter 0: 0.2061774730682373 seconds.
Iter 1: 0.004193544387817383 seconds.
Iter 2: 0.0007469654083251953 seconds.
Iter 3: 0.000415802001953125 seconds.
Iter 4: 0.0004220008850097656 seconds.
Iter 5: 0.0004246234893798828 seconds.
Iter 6: 0.0004112720489501953 seconds.
Iter 7: 0.00042128562927246094 seconds.
Iter 8: 0.0004055500030517578 seconds.
Iter 9: 0.00040721893310546875 seconds.
```

I exported the model by adding the following after the timing code:

```python
mod = tf.Module()
mod.f = rl_step
tf.saved_model.save(mod, "pure_tf_export")
```

I then loaded this exported model with the Java API:

```java
String modelLocation = "./pure_tf_export";
// SavedModelBundle loads its own Graph and Session, so no separate ones are created here.
try (SavedModelBundle model = SavedModelBundle.loader(modelLocation).load();
     Tensor imageTensor = TFloat32.tensorOf(image);
     Tensor psfTensor = TFloat32.tensorOf(psf);
     Tensor dataTensor = TFloat32.tensorOf(data)) {
    Map<String, Tensor> inputs = new HashMap<>();
    inputs.put("data", dataTensor);
    inputs.put("image", imageTensor);
    inputs.put("psf", psfTensor);

    for (int i = 0; i < 10; i++) {
        Instant start = Instant.now();

        // Run the exported signature and feed its output back in as the next "data" input.
        // Each call returns a new Result holding native tensors; none of them are closed in this loop.
        Result result = model.function("serving_default").call(inputs);
        inputs.replace("data", result.get("output_0").get());

        System.out.println("Iter " + i + " " + (Duration.between(start, Instant.now()).toMillis() / 1000f) + " seconds");
    }
}
```

and I get the following timings:

```text
Iter 0 0.701 seconds
Iter 1 0.528 seconds
Iter 2 0.874 seconds
Iter 3 0.224 seconds
Iter 4 0.254 seconds
Iter 5 1.622 seconds
Iter 6 0.241 seconds
Iter 7 0.224 seconds
Iter 8 0.231 seconds
Iter 9 0.228 seconds
```

I am pretty sure I am making a simple mistake somewhere, probably in how I am instantiating the tensors: in Python, I know the timings go up a lot if you don't use `tf.constant`.

Any help would be very much appreciated.
I tried looking through the documentation and the TensorFlow java-examples repository, but couldn't spot what I am doing wrong.

Thanks again!

Posted by ryanhausen on 2025-02-11 (7 comments).
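A possible reason the two timing tables are hard to compare: TensorFlow typically launches GPU kernels asynchronously, so a Python loop that never reads the result of `rl_step` may be timing only the dispatch of the work, and a steady-state figure of about 0.0004 s looks very short for several 2-D FFTs over 41 x 2048 x 2048 values. The first Python iteration (about 0.2 s) presumably also includes tracing and XLA compilation, since the function uses `jit_compile=True`. The sketch below is a minimal, framework-agnostic timing helper under those assumptions: it discards warm-up calls and stops the clock only after a caller-supplied `wait` function has consumed the result. `time_steps` is a hypothetical helper, not part of TensorFlow, and the TensorFlow `wait` function suggested in the comments is an assumption.

```python
import statistics
import time
from typing import Any, Callable, List


def time_steps(
    step: Callable[[], Any],
    wait: Callable[[Any], Any],
    iters: int = 10,
    warmup: int = 1,
) -> List[float]:
    """Time `iters` calls of `step`, stopping the clock only after `wait(result)` returns.

    The first `warmup` calls run but are not recorded, so one-off costs such as
    tracing or compilation are excluded from the returned timings.
    """
    for _ in range(warmup):
        wait(step())
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        result = step()
        wait(result)  # block until the result has actually been produced
        timings.append(time.perf_counter() - start)
    return timings


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without TensorFlow. With TensorFlow one
    # might (hypothetically) use:
    #   step = lambda: rl_step(data_tensor, image_tensor, measured_psf_tensor)
    #   wait = lambda out: tf.reduce_sum(out).numpy()
    timings = time_steps(lambda: sum(range(100_000)), wait=lambda out: out)
    print(f"median {statistics.median(timings):.6f} s over {len(timings)} runs")
```

Reporting the median of several post-warm-up runs, rather than individual iterations, also damps one-off outliers such as the 1.622 s Java iteration. Reading back a scalar reduction is meant to force the computation without copying the full 41 x 2048 x 2048 result to the host; how much overhead the extra reduction adds in this setup is untested.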
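The suspicion in the issue that tensor instantiation is involved can be put into rough numbers. As a back-of-the-envelope sketch, assume (this is not established in the issue) that every Java `call` copies the `data`, `image`, and `psf` inputs from host memory to the GPU and copies `output_0` back, and assume a hypothetical effective host-to-GPU bandwidth of 12 GB/s. The script below computes the resulting per-call traffic; `tensor_bytes` and the bandwidth constant are illustrative names, not part of any API.

```python
# Rough host<->GPU traffic per rl_step call, ASSUMING every input is copied
# to the GPU and the output copied back on each call. The bandwidth below is
# a hypothetical figure, not a measurement from the issue.
K, N = 41, 2048
BYTES_PER_FLOAT32 = 4
ASSUMED_BANDWIDTH_BYTES_PER_S = 12e9  # hypothetical effective PCIe 3.0 x16 rate


def tensor_bytes(*shape: int) -> int:
    """Size in bytes of a dense float32 tensor with the given shape."""
    size = BYTES_PER_FLOAT32
    for dim in shape:
        size *= dim
    return size


# data [k, n, n] + image [1, n, n] + psf [k, n, n] go to the GPU
to_device = tensor_bytes(K, N, N) + tensor_bytes(1, N, N) + tensor_bytes(K, N, N)
# output_0 [k, n, n] comes back to the host
to_host = tensor_bytes(K, N, N)
total = to_device + to_host
est_seconds = total / ASSUMED_BANDWIDTH_BYTES_PER_S

print(f"{total / 2**30:.2f} GiB per call, about {est_seconds:.3f} s of copying")
```

Under these assumptions this comes to 2,080,374,784 bytes (about 1.94 GiB) per call, or roughly 0.17 s of copying, which is the same order of magnitude as the 0.22 to 0.25 s steady-state Java iterations. That is consistent with, but does not prove, per-call data movement dominating the Java timings.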