diff --git a/airsim/__init__.py b/airsim/__init__.py new file mode 100644 index 0000000..17f3aca --- /dev/null +++ b/airsim/__init__.py @@ -0,0 +1,5 @@ +from .client import * +from .utils import * +from .types import * + +__version__ = "1.8.1" diff --git a/airsim/client.py b/airsim/client.py new file mode 100644 index 0000000..f9cfd22 --- /dev/null +++ b/airsim/client.py @@ -0,0 +1,1631 @@ +from __future__ import print_function + +from .utils import * +from .types import * + +import msgpackrpc #install as admin: pip install msgpack-rpc-python +import numpy as np #pip install numpy +import msgpack +import time +import math +import logging + +class VehicleClient: + def __init__(self, ip = "", port = 41451, timeout_value = 3600): + if (ip == ""): + ip = "127.0.0.1" + self.client = msgpackrpc.Client(msgpackrpc.Address(ip, port), timeout = timeout_value, pack_encoding = 'utf-8', unpack_encoding = 'utf-8') + +#----------------------------------- Common vehicle APIs --------------------------------------------- + def reset(self): + """ + Reset the vehicle to its original starting state + + Note that you must call `enableApiControl` and `armDisarm` again after the call to reset + """ + self.client.call('reset') + + def ping(self): + """ + If connection is established then this call will return true otherwise it will be blocked until timeout + + Returns: + bool: + """ + return self.client.call('ping') + + def getClientVersion(self): + return 1 # sync with C++ client + + def getServerVersion(self): + return self.client.call('getServerVersion') + + def getMinRequiredServerVersion(self): + return 1 # sync with C++ client + + def getMinRequiredClientVersion(self): + return self.client.call('getMinRequiredClientVersion') + +#basic flight control + def enableApiControl(self, is_enabled, vehicle_name = ''): + """ + Enables or disables API control for vehicle corresponding to vehicle_name + + Args: + is_enabled (bool): True to enable, False to disable API control + vehicle_name (str, optional): Name of the vehicle to send this command to + """ + self.client.call('enableApiControl', is_enabled, vehicle_name) + + def isApiControlEnabled(self, vehicle_name = ''): + """ + Returns true if API control is established. + + If false (which is default) then API calls would be ignored. After a successful call to `enableApiControl`, `isApiControlEnabled` should return true. + + Args: + vehicle_name (str, optional): Name of the vehicle + + Returns: + bool: If API control is enabled + """ + return self.client.call('isApiControlEnabled', vehicle_name) + + def armDisarm(self, arm, vehicle_name = ''): + """ + Arms or disarms vehicle + + Args: + arm (bool): True to arm, False to disarm the vehicle + vehicle_name (str, optional): Name of the vehicle to send this command to + + Returns: + bool: Success + """ + return self.client.call('armDisarm', arm, vehicle_name) + + def simPause(self, is_paused): + """ + Pauses simulation + + Args: + is_paused (bool): True to pause the simulation, False to release + """ + self.client.call('simPause', is_paused) + + def simIsPause(self): + """ + Returns true if the simulation is paused + + Returns: + bool: If the simulation is paused + """ + return self.client.call("simIsPaused") + + def simContinueForTime(self, seconds): + """ + Continue the simulation for the specified number of seconds + + Args: + seconds (float): Time to run the simulation for + """ + self.client.call('simContinueForTime', seconds) + + def simContinueForFrames(self, frames): + """ + Continue (or resume if paused) the simulation for the specified number of frames, after which the simulation will be paused. + + Args: + frames (int): Frames to run the simulation for + """ + self.client.call('simContinueForFrames', frames) + + def getHomeGeoPoint(self, vehicle_name = ''): + """ + Get the Home location of the vehicle + + Args: + vehicle_name (str, optional): Name of vehicle to get home location of + + Returns: + GeoPoint: Home location of the vehicle + """ + return GeoPoint.from_msgpack(self.client.call('getHomeGeoPoint', vehicle_name)) + + def confirmConnection(self): + """ + Checks state of connection every 1 sec and reports it in Console so user can see the progress for connection. + """ + if self.ping(): + print("Connected!") + else: + print("Ping returned false!") + server_ver = self.getServerVersion() + client_ver = self.getClientVersion() + server_min_ver = self.getMinRequiredServerVersion() + client_min_ver = self.getMinRequiredClientVersion() + + ver_info = "Client Ver:" + str(client_ver) + " (Min Req: " + str(client_min_ver) + \ + "), Server Ver:" + str(server_ver) + " (Min Req: " + str(server_min_ver) + ")" + + if server_ver < server_min_ver: + print(ver_info, file=sys.stderr) + print("AirSim server is of older version and not supported by this client. Please upgrade!") + elif client_ver < client_min_ver: + print(ver_info, file=sys.stderr) + print("AirSim client is of older version and not supported by this server. Please upgrade!") + else: + print(ver_info) + print('') + + def simSetLightIntensity(self, light_name, intensity): + """ + Change intensity of named light + + Args: + light_name (str): Name of light to change + intensity (float): New intensity value + + Returns: + bool: True if successful, otherwise False + """ + return self.client.call("simSetLightIntensity", light_name, intensity) + + def simSwapTextures(self, tags, tex_id = 0, component_id = 0, material_id = 0): + """ + Runtime Swap Texture API + + See https://microsoft.github.io/AirSim/retexturing/ for details + + Args: + tags (str): string of "," or ", " delimited tags to identify on which actors to perform the swap + tex_id (int, optional): indexes the array of textures assigned to each actor undergoing a swap + + If out-of-bounds for some object's texture set, it will be taken modulo the number of textures that were available + component_id (int, optional): + material_id (int, optional): + + Returns: + list[str]: List of objects which matched the provided tags and had the texture swap perfomed + """ + return self.client.call("simSwapTextures", tags, tex_id, component_id, material_id) + + def simSetObjectMaterial(self, object_name, material_name, component_id = 0): + """ + Runtime Swap Texture API + See https://microsoft.github.io/AirSim/retexturing/ for details + Args: + object_name (str): name of object to set material for + material_name (str): name of material to set for object + component_id (int, optional) : index of material elements + + Returns: + bool: True if material was set + """ + return self.client.call("simSetObjectMaterial", object_name, material_name, component_id) + + def simSetObjectMaterialFromTexture(self, object_name, texture_path, component_id = 0): + """ + Runtime Swap Texture API + See https://microsoft.github.io/AirSim/retexturing/ for details + Args: + object_name (str): name of object to set material for + texture_path (str): path to texture to set for object + component_id (int, optional) : index of material elements + + Returns: + bool: True if material was set + """ + return self.client.call("simSetObjectMaterialFromTexture", object_name, texture_path, component_id) + + + # time-of-day control +#time - of - day control + def simSetTimeOfDay(self, is_enabled, start_datetime = "", is_start_datetime_dst = False, celestial_clock_speed = 1, update_interval_secs = 60, move_sun = True): + """ + Control the position of Sun in the environment + + Sun's position is computed using the coordinates specified in `OriginGeopoint` in settings for the date-time specified in the argument, + else if the string is empty, current date & time is used + + Args: + is_enabled (bool): True to enable time-of-day effect, False to reset the position to original + start_datetime (str, optional): Date & Time in %Y-%m-%d %H:%M:%S format, e.g. `2018-02-12 15:20:00` + is_start_datetime_dst (bool, optional): True to adjust for Daylight Savings Time + celestial_clock_speed (float, optional): Run celestial clock faster or slower than simulation clock + E.g. Value 100 means for every 1 second of simulation clock, Sun's position is advanced by 100 seconds + so Sun will move in sky much faster + update_interval_secs (float, optional): Interval to update the Sun's position + move_sun (bool, optional): Whether or not to move the Sun + """ + self.client.call('simSetTimeOfDay', is_enabled, start_datetime, is_start_datetime_dst, celestial_clock_speed, update_interval_secs, move_sun) + +#weather + def simEnableWeather(self, enable): + """ + Enable Weather effects. Needs to be called before using `simSetWeatherParameter` API + + Args: + enable (bool): True to enable, False to disable + """ + self.client.call('simEnableWeather', enable) + + def simSetWeatherParameter(self, param, val): + """ + Enable various weather effects + + Args: + param (WeatherParameter): Weather effect to be enabled + val (float): Intensity of the effect, Range 0-1 + """ + self.client.call('simSetWeatherParameter', param, val) + +#camera control +#simGetImage returns compressed png in array of bytes +#image_type uses one of the ImageType members + def simGetImage(self, camera_name, image_type, vehicle_name = '', external = False): + """ + Get a single image + + Returns bytes of png format image which can be dumped into abinary file to create .png image + `string_to_uint8_array()` can be used to convert into Numpy unit8 array + See https://microsoft.github.io/AirSim/image_apis/ for details + + Args: + camera_name (str): Name of the camera, for backwards compatibility, ID numbers such as 0,1,etc. can also be used + image_type (ImageType): Type of image required + vehicle_name (str, optional): Name of the vehicle with the camera + external (bool, optional): Whether the camera is an External Camera + + Returns: + Binary string literal of compressed png image + """ +#todo : in future remove below, it's only for compatibility to pre v1.2 + camera_name = str(camera_name) + +#because this method returns std::vector < uint8>, msgpack decides to encode it as a string unfortunately. + result = self.client.call('simGetImage', camera_name, image_type, vehicle_name, external) + if (result == "" or result == "\0"): + return None + return result + +#camera control +#simGetImage returns compressed png in array of bytes +#image_type uses one of the ImageType members + def simGetImages(self, requests, vehicle_name = '', external = False): + """ + Get multiple images + + See https://microsoft.github.io/AirSim/image_apis/ for details and examples + + Args: + requests (list[ImageRequest]): Images required + vehicle_name (str, optional): Name of vehicle associated with the camera + external (bool, optional): Whether the camera is an External Camera + + Returns: + list[ImageResponse]: + """ + responses_raw = self.client.call('simGetImages', requests, vehicle_name, external) + return [ImageResponse.from_msgpack(response_raw) for response_raw in responses_raw] + + + +#CinemAirSim + def simGetPresetLensSettings(self, camera_name, vehicle_name = '', external = False): + result = self.client.call('simGetPresetLensSettings', camera_name, vehicle_name, external) + if (result == "" or result == "\0"): + return None + return result + + def simGetLensSettings(self, camera_name, vehicle_name = '', external = False): + result = self.client.call('simGetLensSettings', camera_name, vehicle_name, external) + if (result == "" or result == "\0"): + return None + return result + + def simSetPresetLensSettings(self, preset_lens_settings, camera_name, vehicle_name = '', external = False): + self.client.call("simSetPresetLensSettings", preset_lens_settings, camera_name, vehicle_name, external) + + def simGetPresetFilmbackSettings(self, camera_name, vehicle_name = '', external = False): + result = self.client.call('simGetPresetFilmbackSettings', camera_name, vehicle_name, external) + if (result == "" or result == "\0"): + return None + return result + + def simSetPresetFilmbackSettings(self, preset_filmback_settings, camera_name, vehicle_name = '', external = False): + self.client.call("simSetPresetFilmbackSettings", preset_filmback_settings, camera_name, vehicle_name, external) + + def simGetFilmbackSettings(self, camera_name, vehicle_name = '', external = False): + result = self.client.call('simGetFilmbackSettings', camera_name, vehicle_name, external) + if (result == "" or result == "\0"): + return None + return result + + def simSetFilmbackSettings(self, sensor_width, sensor_height, camera_name, vehicle_name = '', external = False): + return self.client.call("simSetFilmbackSettings", sensor_width, sensor_height, camera_name, vehicle_name, external) + + def simGetFocalLength(self, camera_name, vehicle_name = '', external = False): + return self.client.call("simGetFocalLength", camera_name, vehicle_name, external) + + def simSetFocalLength(self, focal_length, camera_name, vehicle_name = '', external = False): + self.client.call("simSetFocalLength", focal_length, camera_name, vehicle_name, external) + + def simEnableManualFocus(self, enable, camera_name, vehicle_name = '', external = False): + self.client.call("simEnableManualFocus", enable, camera_name, vehicle_name, external) + + def simGetFocusDistance(self, camera_name, vehicle_name = '', external = False): + return self.client.call("simGetFocusDistance", camera_name, vehicle_name, external) + + def simSetFocusDistance(self, focus_distance, camera_name, vehicle_name = '', external = False): + self.client.call("simSetFocusDistance", focus_distance, camera_name, vehicle_name, external) + + def simGetFocusAperture(self, camera_name, vehicle_name = '', external = False): + return self.client.call("simGetFocusAperture", camera_name, vehicle_name, external) + + def simSetFocusAperture(self, focus_aperture, camera_name, vehicle_name = '', external = False): + self.client.call("simSetFocusAperture", focus_aperture, camera_name, vehicle_name, external) + + def simEnableFocusPlane(self, enable, camera_name, vehicle_name = '', external = False): + self.client.call("simEnableFocusPlane", enable, camera_name, vehicle_name, external) + + def simGetCurrentFieldOfView(self, camera_name, vehicle_name = '', external = False): + return self.client.call("simGetCurrentFieldOfView", camera_name, vehicle_name, external) + +#End CinemAirSim + def simTestLineOfSightToPoint(self, point, vehicle_name = ''): + """ + Returns whether the target point is visible from the perspective of the inputted vehicle + + Args: + point (GeoPoint): target point + vehicle_name (str, optional): Name of vehicle + + Returns: + [bool]: Success + """ + return self.client.call('simTestLineOfSightToPoint', point, vehicle_name) + + def simTestLineOfSightBetweenPoints(self, point1, point2): + """ + Returns whether the target point is visible from the perspective of the source point + + Args: + point1 (GeoPoint): source point + point2 (GeoPoint): target point + + Returns: + [bool]: Success + """ + return self.client.call('simTestLineOfSightBetweenPoints', point1, point2) + + def simGetWorldExtents(self): + """ + Returns a list of GeoPoints representing the minimum and maximum extents of the world + + Returns: + list[GeoPoint] + """ + responses_raw = self.client.call('simGetWorldExtents') + return [GeoPoint.from_msgpack(response_raw) for response_raw in responses_raw] + + def simRunConsoleCommand(self, command): + """ + Allows the client to execute a command in Unreal's native console, via an API. + Affords access to the countless built-in commands such as "stat unit", "stat fps", "open [map]", adjust any config settings, etc. etc. + Allows the user to create bespoke APIs very easily, by adding a custom event to the level blueprint, and then calling the console command "ce MyEventName [args]". No recompilation of AirSim needed! + + Args: + command ([string]): Desired Unreal Engine Console command to run + + Returns: + [bool]: Success + """ + return self.client.call('simRunConsoleCommand', command) + +#gets the static meshes in the unreal scene + def simGetMeshPositionVertexBuffers(self): + """ + Returns the static meshes that make up the scene + + See https://microsoft.github.io/AirSim/meshes/ for details and how to use this + + Returns: + list[MeshPositionVertexBuffersResponse]: + """ + responses_raw = self.client.call('simGetMeshPositionVertexBuffers') + return [MeshPositionVertexBuffersResponse.from_msgpack(response_raw) for response_raw in responses_raw] + + def simGetCollisionInfo(self, vehicle_name = ''): + """ + Args: + vehicle_name (str, optional): Name of the Vehicle to get the info of + + Returns: + CollisionInfo: + """ + return CollisionInfo.from_msgpack(self.client.call('simGetCollisionInfo', vehicle_name)) + + def simSetVehiclePose(self, pose, ignore_collision, vehicle_name = ''): + """ + Set the pose of the vehicle + + If you don't want to change position (or orientation) then just set components of position (or orientation) to floating point nan values + + Args: + pose (Pose): Desired Pose pf the vehicle + ignore_collision (bool): Whether to ignore any collision or not + vehicle_name (str, optional): Name of the vehicle to move + """ + self.client.call('simSetVehiclePose', pose, ignore_collision, vehicle_name) + + def simGetVehiclePose(self, vehicle_name = ''): + """ + The position inside the returned Pose is in the frame of the vehicle's starting point + + Args: + vehicle_name (str, optional): Name of the vehicle to get the Pose of + + Returns: + Pose: + """ + pose = self.client.call('simGetVehiclePose', vehicle_name) + return Pose.from_msgpack(pose) + + def simSetTraceLine(self, color_rgba, thickness=1.0, vehicle_name = ''): + """ + Modify the color and thickness of the line when Tracing is enabled + + Tracing can be enabled by pressing T in the Editor or setting `EnableTrace` to `True` in the Vehicle Settings + + Args: + color_rgba (list): desired RGBA values from 0.0 to 1.0 + thickness (float, optional): Thickness of the line + vehicle_name (string, optional): Name of the vehicle to set Trace line values for + """ + self.client.call('simSetTraceLine', color_rgba, thickness, vehicle_name) + + def simGetObjectPose(self, object_name): + """ + The position inside the returned Pose is in the world frame + + Args: + object_name (str): Object to get the Pose of + + Returns: + Pose: + """ + pose = self.client.call('simGetObjectPose', object_name) + return Pose.from_msgpack(pose) + + def simSetObjectPose(self, object_name, pose, teleport = True): + """ + Set the pose of the object(actor) in the environment + + The specified actor must have Mobility set to movable, otherwise there will be undefined behaviour. + See https://www.unrealengine.com/en-US/blog/moving-physical-objects for details on how to set Mobility and the effect of Teleport parameter + + Args: + object_name (str): Name of the object(actor) to move + pose (Pose): Desired Pose of the object + teleport (bool, optional): Whether to move the object immediately without affecting their velocity + + Returns: + bool: If the move was successful + """ + return self.client.call('simSetObjectPose', object_name, pose, teleport) + + def simGetObjectScale(self, object_name): + """ + Gets scale of an object in the world + + Args: + object_name (str): Object to get the scale of + + Returns: + airsim.Vector3r: Scale + """ + scale = self.client.call('simGetObjectScale', object_name) + return Vector3r.from_msgpack(scale) + + def simSetObjectScale(self, object_name, scale_vector): + """ + Sets scale of an object in the world + + Args: + object_name (str): Object to set the scale of + scale_vector (airsim.Vector3r): Desired scale of object + + Returns: + bool: True if scale change was successful + """ + return self.client.call('simSetObjectScale', object_name, scale_vector) + + def simListSceneObjects(self, name_regex = '.*'): + """ + Lists the objects present in the environment + + Default behaviour is to list all objects, regex can be used to return smaller list of matching objects or actors + + Args: + name_regex (str, optional): String to match actor names against, e.g. "Cylinder.*" + + Returns: + list[str]: List containing all the names + """ + return self.client.call('simListSceneObjects', name_regex) + + def simLoadLevel(self, level_name): + """ + Loads a level specified by its name + + Args: + level_name (str): Name of the level to load + + Returns: + bool: True if the level was successfully loaded + """ + return self.client.call('simLoadLevel', level_name) + + def simListAssets(self): + """ + Lists all the assets present in the Asset Registry + + Returns: + list[str]: Names of all the assets + """ + return self.client.call('simListAssets') + + def simSpawnObject(self, object_name, asset_name, pose, scale, physics_enabled=False, is_blueprint=False): + """Spawned selected object in the world + + Args: + object_name (str): Desired name of new object + asset_name (str): Name of asset(mesh) in the project database + pose (airsim.Pose): Desired pose of object + scale (airsim.Vector3r): Desired scale of object + physics_enabled (bool, optional): Whether to enable physics for the object + is_blueprint (bool, optional): Whether to spawn a blueprint or an actor + + Returns: + str: Name of spawned object, in case it had to be modified + """ + return self.client.call('simSpawnObject', object_name, asset_name, pose, scale, physics_enabled, is_blueprint) + + def simDestroyObject(self, object_name): + """Removes selected object from the world + + Args: + object_name (str): Name of object to be removed + + Returns: + bool: True if object is queued up for removal + """ + return self.client.call('simDestroyObject', object_name) + + def simSetSegmentationObjectID(self, mesh_name, object_id, is_name_regex = False): + """ + Set segmentation ID for specific objects + + See https://microsoft.github.io/AirSim/image_apis/#segmentation for details + + Args: + mesh_name (str): Name of the mesh to set the ID of (supports regex) + object_id (int): Object ID to be set, range 0-255 + + RBG values for IDs can be seen at https://microsoft.github.io/AirSim/seg_rgbs.txt + is_name_regex (bool, optional): Whether the mesh name is a regex + + Returns: + bool: If the mesh was found + """ + return self.client.call('simSetSegmentationObjectID', mesh_name, object_id, is_name_regex) + + def simGetSegmentationObjectID(self, mesh_name): + """ + Returns Object ID for the given mesh name + + Mapping of Object IDs to RGB values can be seen at https://microsoft.github.io/AirSim/seg_rgbs.txt + + Args: + mesh_name (str): Name of the mesh to get the ID of + """ + return self.client.call('simGetSegmentationObjectID', mesh_name) + + def simAddDetectionFilterMeshName(self, camera_name, image_type, mesh_name, vehicle_name = '', external = False): + """ + Add mesh name to detect in wild card format + + For example: simAddDetectionFilterMeshName("Car_*") will detect all instance named "Car_*" + + Args: + camera_name (str): Name of the camera, for backwards compatibility, ID numbers such as 0,1,etc. can also be used + image_type (ImageType): Type of image required + mesh_name (str): mesh name in wild card format + vehicle_name (str, optional): Vehicle which the camera is associated with + external (bool, optional): Whether the camera is an External Camera + + """ + self.client.call('simAddDetectionFilterMeshName', camera_name, image_type, mesh_name, vehicle_name, external) + + def simSetDetectionFilterRadius(self, camera_name, image_type, radius_cm, vehicle_name = '', external = False): + """ + Set detection radius for all cameras + + Args: + camera_name (str): Name of the camera, for backwards compatibility, ID numbers such as 0,1,etc. can also be used + image_type (ImageType): Type of image required + radius_cm (int): Radius in [cm] + vehicle_name (str, optional): Vehicle which the camera is associated with + external (bool, optional): Whether the camera is an External Camera + """ + self.client.call('simSetDetectionFilterRadius', camera_name, image_type, radius_cm, vehicle_name, external) + + def simClearDetectionMeshNames(self, camera_name, image_type, vehicle_name = '', external = False): + """ + Clear all mesh names from detection filter + + Args: + camera_name (str): Name of the camera, for backwards compatibility, ID numbers such as 0,1,etc. can also be used + image_type (ImageType): Type of image required + vehicle_name (str, optional): Vehicle which the camera is associated with + external (bool, optional): Whether the camera is an External Camera + + """ + self.client.call('simClearDetectionMeshNames', camera_name, image_type, vehicle_name, external) + + def simGetDetections(self, camera_name, image_type, vehicle_name = '', external = False): + """ + Get current detections + + Args: + camera_name (str): Name of the camera, for backwards compatibility, ID numbers such as 0,1,etc. can also be used + image_type (ImageType): Type of image required + vehicle_name (str, optional): Vehicle which the camera is associated with + external (bool, optional): Whether the camera is an External Camera + + Returns: + DetectionInfo array + """ + responses_raw = self.client.call('simGetDetections', camera_name, image_type, vehicle_name, external) + return [DetectionInfo.from_msgpack(response_raw) for response_raw in responses_raw] + + def simPrintLogMessage(self, message, message_param = "", severity = 0): + """ + Prints the specified message in the simulator's window. + + If message_param is supplied, then it's printed next to the message and in that case if this API is called with same message value + but different message_param again then previous line is overwritten with new line (instead of API creating new line on display). + + For example, `simPrintLogMessage("Iteration: ", to_string(i))` keeps updating same line on display when API is called with different values of i. + The valid values of severity parameter is 0 to 3 inclusive that corresponds to different colors. + + Args: + message (str): Message to be printed + message_param (str, optional): Parameter to be printed next to the message + severity (int, optional): Range 0-3, inclusive, corresponding to the severity of the message + """ + self.client.call('simPrintLogMessage', message, message_param, severity) + + def simGetCameraInfo(self, camera_name, vehicle_name = '', external=False): + """ + Get details about the camera + + Args: + camera_name (str): Name of the camera, for backwards compatibility, ID numbers such as 0,1,etc. can also be used + vehicle_name (str, optional): Vehicle which the camera is associated with + external (bool, optional): Whether the camera is an External Camera + + Returns: + CameraInfo: + """ +#TODO : below str() conversion is only needed for legacy reason and should be removed in future + return CameraInfo.from_msgpack(self.client.call('simGetCameraInfo', str(camera_name), vehicle_name, external)) + + def simGetDistortionParams(self, camera_name, vehicle_name = '', external = False): + """ + Get camera distortion parameters + + Args: + camera_name (str): Name of the camera, for backwards compatibility, ID numbers such as 0,1,etc. can also be used + vehicle_name (str, optional): Vehicle which the camera is associated with + external (bool, optional): Whether the camera is an External Camera + + Returns: + List (float): List of distortion parameter values corresponding to K1, K2, K3, P1, P2 respectively. + """ + + return self.client.call('simGetDistortionParams', str(camera_name), vehicle_name, external) + + def simSetDistortionParams(self, camera_name, distortion_params, vehicle_name = '', external = False): + """ + Set camera distortion parameters + + Args: + camera_name (str): Name of the camera, for backwards compatibility, ID numbers such as 0,1,etc. can also be used + distortion_params (dict): Dictionary of distortion param names and corresponding values + {"K1": 0.0, "K2": 0.0, "K3": 0.0, "P1": 0.0, "P2": 0.0} + vehicle_name (str, optional): Vehicle which the camera is associated with + external (bool, optional): Whether the camera is an External Camera + """ + + for param_name, value in distortion_params.items(): + self.simSetDistortionParam(camera_name, param_name, value, vehicle_name, external) + + def simSetDistortionParam(self, camera_name, param_name, value, vehicle_name = '', external = False): + """ + Set single camera distortion parameter + + Args: + camera_name (str): Name of the camera, for backwards compatibility, ID numbers such as 0,1,etc. can also be used + param_name (str): Name of distortion parameter + value (float): Value of distortion parameter + vehicle_name (str, optional): Vehicle which the camera is associated with + external (bool, optional): Whether the camera is an External Camera + """ + self.client.call('simSetDistortionParam', str(camera_name), param_name, value, vehicle_name, external) + + def simSetCameraPose(self, camera_name, pose, vehicle_name = '', external = False): + """ + - Control the pose of a selected camera + + Args: + camera_name (str): Name of the camera to be controlled + pose (Pose): Pose representing the desired position and orientation of the camera + vehicle_name (str, optional): Name of vehicle which the camera corresponds to + external (bool, optional): Whether the camera is an External Camera + """ +#TODO : below str() conversion is only needed for legacy reason and should be removed in future + self.client.call('simSetCameraPose', str(camera_name), pose, vehicle_name, external) + + def simSetCameraFov(self, camera_name, fov_degrees, vehicle_name = '', external = False): + """ + - Control the field of view of a selected camera + + Args: + camera_name (str): Name of the camera to be controlled + fov_degrees (float): Value of field of view in degrees + vehicle_name (str, optional): Name of vehicle which the camera corresponds to + external (bool, optional): Whether the camera is an External Camera + """ +#TODO : below str() conversion is only needed for legacy reason and should be removed in future + self.client.call('simSetCameraFov', str(camera_name), fov_degrees, vehicle_name, external) + + def simGetGroundTruthKinematics(self, vehicle_name = ''): + """ + Get Ground truth kinematics of the vehicle + + The position inside the returned KinematicsState is in the frame of the vehicle's starting point + + Args: + vehicle_name (str, optional): Name of the vehicle + + Returns: + KinematicsState: Ground truth of the vehicle + """ + kinematics_state = self.client.call('simGetGroundTruthKinematics', vehicle_name) + return KinematicsState.from_msgpack(kinematics_state) + simGetGroundTruthKinematics.__annotations__ = {'return': KinematicsState} + + def simSetKinematics(self, state, ignore_collision, vehicle_name = ''): + """ + Set the kinematics state of the vehicle + + If you don't want to change position (or orientation) then just set components of position (or orientation) to floating point nan values + + Args: + state (KinematicsState): Desired Pose pf the vehicle + ignore_collision (bool): Whether to ignore any collision or not + vehicle_name (str, optional): Name of the vehicle to move + """ + self.client.call('simSetKinematics', state, ignore_collision, vehicle_name) + + def simGetGroundTruthEnvironment(self, vehicle_name = ''): + """ + Get ground truth environment state + + The position inside the returned EnvironmentState is in the frame of the vehicle's starting point + + Args: + vehicle_name (str, optional): Name of the vehicle + + Returns: + EnvironmentState: Ground truth environment state + """ + env_state = self.client.call('simGetGroundTruthEnvironment', vehicle_name) + return EnvironmentState.from_msgpack(env_state) + simGetGroundTruthEnvironment.__annotations__ = {'return': EnvironmentState} + + +#sensor APIs + def getImuData(self, imu_name = '', vehicle_name = ''): + """ + Args: + imu_name (str, optional): Name of IMU to get data from, specified in settings.json + vehicle_name (str, optional): Name of vehicle to which the sensor corresponds to + + Returns: + ImuData: + """ + return ImuData.from_msgpack(self.client.call('getImuData', imu_name, vehicle_name)) + + def getBarometerData(self, barometer_name = '', vehicle_name = ''): + """ + Args: + barometer_name (str, optional): Name of Barometer to get data from, specified in settings.json + vehicle_name (str, optional): Name of vehicle to which the sensor corresponds to + + Returns: + BarometerData: + """ + return BarometerData.from_msgpack(self.client.call('getBarometerData', barometer_name, vehicle_name)) + + def getMagnetometerData(self, magnetometer_name = '', vehicle_name = ''): + """ + Args: + magnetometer_name (str, optional): Name of Magnetometer to get data from, specified in settings.json + vehicle_name (str, optional): Name of vehicle to which the sensor corresponds to + + Returns: + MagnetometerData: + """ + return MagnetometerData.from_msgpack(self.client.call('getMagnetometerData', magnetometer_name, vehicle_name)) + + def getGpsData(self, gps_name = '', vehicle_name = ''): + """ + Args: + gps_name (str, optional): Name of GPS to get data from, specified in settings.json + vehicle_name (str, optional): Name of vehicle to which the sensor corresponds to + + Returns: + GpsData: + """ + return GpsData.from_msgpack(self.client.call('getGpsData', gps_name, vehicle_name)) + + def getDistanceSensorData(self, distance_sensor_name = '', vehicle_name = ''): + """ + Args: + distance_sensor_name (str, optional): Name of Distance Sensor to get data from, specified in settings.json + vehicle_name (str, optional): Name of vehicle to which the sensor corresponds to + + Returns: + DistanceSensorData: + """ + return DistanceSensorData.from_msgpack(self.client.call('getDistanceSensorData', distance_sensor_name, vehicle_name)) + + def getLidarData(self, lidar_name = '', vehicle_name = ''): + """ + Args: + lidar_name (str, optional): Name of Lidar to get data from, specified in settings.json + vehicle_name (str, optional): Name of vehicle to which the sensor corresponds to + + Returns: + LidarData: + """ + return LidarData.from_msgpack(self.client.call('getLidarData', lidar_name, vehicle_name)) + + def simGetLidarSegmentation(self, lidar_name = '', vehicle_name = ''): + """ + NOTE: Deprecated API, use `getLidarData()` API instead + Returns Segmentation ID of each point's collided object in the last Lidar update + + Args: + lidar_name (str, optional): Name of Lidar sensor + vehicle_name (str, optional): Name of the vehicle wth the sensor + + Returns: + list[int]: Segmentation IDs of the objects + """ + logging.warning("simGetLidarSegmentation API is deprecated, use getLidarData() API instead") + return self.getLidarData(lidar_name, vehicle_name).segmentation + +#Plotting APIs + def simFlushPersistentMarkers(self): + """ + Clear any persistent markers - those plotted with setting `is_persistent=True` in the APIs below + """ + self.client.call('simFlushPersistentMarkers') + + def simPlotPoints(self, points, color_rgba=[1.0, 0.0, 0.0, 1.0], size = 10.0, duration = -1.0, is_persistent = False): + """ + Plot a list of 3D points in World NED frame + + Args: + points (list[Vector3r]): List of Vector3r objects + color_rgba (list, optional): desired RGBA values from 0.0 to 1.0 + size (float, optional): Size of plotted point + duration (float, optional): Duration (seconds) to plot for + is_persistent (bool, optional): If set to True, the desired object will be plotted for infinite time. + """ + self.client.call('simPlotPoints', points, color_rgba, size, duration, is_persistent) + + def simPlotLineStrip(self, points, color_rgba=[1.0, 0.0, 0.0, 1.0], thickness = 5.0, duration = -1.0, is_persistent = False): + """ + Plots a line strip in World NED frame, defined from points[0] to points[1], points[1] to points[2], ... , points[n-2] to points[n-1] + + Args: + points (list[Vector3r]): List of 3D locations of line start and end points, specified as Vector3r objects + color_rgba (list, optional): desired RGBA values from 0.0 to 1.0 + thickness (float, optional): Thickness of line + duration (float, optional): Duration (seconds) to plot for + is_persistent (bool, optional): If set to True, the desired object will be plotted for infinite time. + """ + self.client.call('simPlotLineStrip', points, color_rgba, thickness, duration, is_persistent) + + def simPlotLineList(self, points, color_rgba=[1.0, 0.0, 0.0, 1.0], thickness = 5.0, duration = -1.0, is_persistent = False): + """ + Plots a line strip in World NED frame, defined from points[0] to points[1], points[2] to points[3], ... , points[n-2] to points[n-1] + + Args: + points (list[Vector3r]): List of 3D locations of line start and end points, specified as Vector3r objects. Must be even + color_rgba (list, optional): desired RGBA values from 0.0 to 1.0 + thickness (float, optional): Thickness of line + duration (float, optional): Duration (seconds) to plot for + is_persistent (bool, optional): If set to True, the desired object will be plotted for infinite time. + """ + self.client.call('simPlotLineList', points, color_rgba, thickness, duration, is_persistent) + + def simPlotArrows(self, points_start, points_end, color_rgba=[1.0, 0.0, 0.0, 1.0], thickness = 5.0, arrow_size = 2.0, duration = -1.0, is_persistent = False): + """ + Plots a list of arrows in World NED frame, defined from points_start[0] to points_end[0], points_start[1] to points_end[1], ... , points_start[n-1] to points_end[n-1] + + Args: + points_start (list[Vector3r]): List of 3D start positions of arrow start positions, specified as Vector3r objects + points_end (list[Vector3r]): List of 3D end positions of arrow start positions, specified as Vector3r objects + color_rgba (list, optional): desired RGBA values from 0.0 to 1.0 + thickness (float, optional): Thickness of line + arrow_size (float, optional): Size of arrow head + duration (float, optional): Duration (seconds) to plot for + is_persistent (bool, optional): If set to True, the desired object will be plotted for infinite time. + """ + self.client.call('simPlotArrows', points_start, points_end, color_rgba, thickness, arrow_size, duration, is_persistent) + + + def simPlotStrings(self, strings, positions, scale = 5, color_rgba=[1.0, 0.0, 0.0, 1.0], duration = -1.0): + """ + Plots a list of strings at desired positions in World NED frame. + + Args: + strings (list[String], optional): List of strings to plot + positions (list[Vector3r]): List of positions where the strings should be plotted. Should be in one-to-one correspondence with the strings' list + scale (float, optional): Font scale of transform name + color_rgba (list, optional): desired RGBA values from 0.0 to 1.0 + duration (float, optional): Duration (seconds) to plot for + """ + self.client.call('simPlotStrings', strings, positions, scale, color_rgba, duration) + + def simPlotTransforms(self, poses, scale = 5.0, thickness = 5.0, duration = -1.0, is_persistent = False): + """ + Plots a list of transforms in World NED frame. + + Args: + poses (list[Pose]): List of Pose objects representing the transforms to plot + scale (float, optional): Length of transforms' axes + thickness (float, optional): Thickness of transforms' axes + duration (float, optional): Duration (seconds) to plot for + is_persistent (bool, optional): If set to True, the desired object will be plotted for infinite time. + """ + self.client.call('simPlotTransforms', poses, scale, thickness, duration, is_persistent) + + def simPlotTransformsWithNames(self, poses, names, tf_scale = 5.0, tf_thickness = 5.0, text_scale = 10.0, text_color_rgba = [1.0, 0.0, 0.0, 1.0], duration = -1.0): + """ + Plots a list of transforms with their names in World NED frame. + + Args: + poses (list[Pose]): List of Pose objects representing the transforms to plot + names (list[string]): List of strings with one-to-one correspondence to list of poses + tf_scale (float, optional): Length of transforms' axes + tf_thickness (float, optional): Thickness of transforms' axes + text_scale (float, optional): Font scale of transform name + text_color_rgba (list, optional): desired RGBA values from 0.0 to 1.0 for the transform name + duration (float, optional): Duration (seconds) to plot for + """ + self.client.call('simPlotTransformsWithNames', poses, names, tf_scale, tf_thickness, text_scale, text_color_rgba, duration) + + def cancelLastTask(self, vehicle_name = ''): + """ + Cancel previous Async task + + Args: + vehicle_name (str, optional): Name of the vehicle + """ + self.client.call('cancelLastTask', vehicle_name) + +#Recording APIs + def startRecording(self): + """ + Start Recording + + Recording will be done according to the settings + """ + self.client.call('startRecording') + + def stopRecording(self): + """ + Stop Recording + """ + self.client.call('stopRecording') + + def isRecording(self): + """ + Whether Recording is running or not + + Returns: + bool: True if Recording, else False + """ + return self.client.call('isRecording') + + def simSetWind(self, wind): + """ + Set simulated wind, in World frame, NED direction, m/s + + Args: + wind (Vector3r): Wind, in World frame, NED direction, in m/s + """ + self.client.call('simSetWind', wind) + + def simCreateVoxelGrid(self, position, x, y, z, res, of): + """ + Construct and save a binvox-formatted voxel grid of environment + + Args: + position (Vector3r): Position around which voxel grid is centered in m + x, y, z (int): Size of each voxel grid dimension in m + res (float): Resolution of voxel grid in m + of (str): Name of output file to save voxel grid as + + Returns: + bool: True if output written to file successfully, else False + """ + return self.client.call('simCreateVoxelGrid', position, x, y, z, res, of) + +#Add new vehicle via RPC + def simAddVehicle(self, vehicle_name, vehicle_type, pose, pawn_path = ""): + """ + Create vehicle at runtime + + Args: + vehicle_name (str): Name of the vehicle being created + vehicle_type (str): Type of vehicle, e.g. "simpleflight" + pose (Pose): Initial pose of the vehicle + pawn_path (str, optional): Vehicle blueprint path, default empty wbich uses the default blueprint for the vehicle type + + Returns: + bool: Whether vehicle was created + """ + return self.client.call('simAddVehicle', vehicle_name, vehicle_type, pose, pawn_path) + + def listVehicles(self): + """ + Lists the names of current vehicles + + Returns: + list[str]: List containing names of all vehicles + """ + return self.client.call('listVehicles') + + def getSettingsString(self): + """ + Fetch the settings text being used by AirSim + + Returns: + str: Settings text in JSON format + """ + return self.client.call('getSettingsString') + + def simSetExtForce(self, ext_force): + """ + Set arbitrary external forces, in World frame, NED direction. Can be used + for implementing simple payloads. + + Args: + ext_force (Vector3r): Force, in World frame, NED direction, in N + """ + self.client.call('simSetExtForce', ext_force) + +# ----------------------------------- Multirotor APIs --------------------------------------------- +class MultirotorClient(VehicleClient, object): + def __init__(self, ip = "", port = 41451, timeout_value = 3600): + super(MultirotorClient, self).__init__(ip, port, timeout_value) + + def takeoffAsync(self, timeout_sec = 20, vehicle_name = ''): + """ + Takeoff vehicle to 3m above ground. Vehicle should not be moving when this API is used + + Args: + timeout_sec (int, optional): Timeout for the vehicle to reach desired altitude + vehicle_name (str, optional): Name of the vehicle to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('takeoff', timeout_sec, vehicle_name) + + def landAsync(self, timeout_sec = 60, vehicle_name = ''): + """ + Land the vehicle + + Args: + timeout_sec (int, optional): Timeout for the vehicle to land + vehicle_name (str, optional): Name of the vehicle to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('land', timeout_sec, vehicle_name) + + def goHomeAsync(self, timeout_sec = 3e+38, vehicle_name = ''): + """ + Return vehicle to Home i.e. Launch location + + Args: + timeout_sec (int, optional): Timeout for the vehicle to reach desired altitude + vehicle_name (str, optional): Name of the vehicle to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('goHome', timeout_sec, vehicle_name) + +#APIs for control + def moveByVelocityBodyFrameAsync(self, vx, vy, vz, duration, drivetrain = DrivetrainType.MaxDegreeOfFreedom, yaw_mode = YawMode(), vehicle_name = ''): + """ + Args: + vx (float): desired velocity in the X axis of the vehicle's local NED frame. + vy (float): desired velocity in the Y axis of the vehicle's local NED frame. + vz (float): desired velocity in the Z axis of the vehicle's local NED frame. + duration (float): Desired amount of time (seconds), to send this command for + drivetrain (DrivetrainType, optional): + yaw_mode (YawMode, optional): + vehicle_name (str, optional): Name of the multirotor to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByVelocityBodyFrame', vx, vy, vz, duration, drivetrain, yaw_mode, vehicle_name) + + def moveByVelocityZBodyFrameAsync(self, vx, vy, z, duration, drivetrain = DrivetrainType.MaxDegreeOfFreedom, yaw_mode = YawMode(), vehicle_name = ''): + """ + Args: + vx (float): desired velocity in the X axis of the vehicle's local NED frame + vy (float): desired velocity in the Y axis of the vehicle's local NED frame + z (float): desired Z value (in local NED frame of the vehicle) + duration (float): Desired amount of time (seconds), to send this command for + drivetrain (DrivetrainType, optional): + yaw_mode (YawMode, optional): + vehicle_name (str, optional): Name of the multirotor to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + + return self.client.call_async('moveByVelocityZBodyFrame', vx, vy, z, duration, drivetrain, yaw_mode, vehicle_name) + + def moveByAngleZAsync(self, pitch, roll, z, yaw, duration, vehicle_name = ''): + logging.warning("moveByAngleZAsync API is deprecated, use moveByRollPitchYawZAsync() API instead") + return self.client.call_async('moveByRollPitchYawZ', roll, -pitch, -yaw, z, duration, vehicle_name) + + def moveByAngleThrottleAsync(self, pitch, roll, throttle, yaw_rate, duration, vehicle_name = ''): + logging.warning("moveByAngleThrottleAsync API is deprecated, use moveByRollPitchYawrateThrottleAsync() API instead") + return self.client.call_async('moveByRollPitchYawrateThrottle', roll, -pitch, -yaw_rate, throttle, duration, vehicle_name) + + def moveByVelocityAsync(self, vx, vy, vz, duration, drivetrain = DrivetrainType.MaxDegreeOfFreedom, yaw_mode = YawMode(), vehicle_name = ''): + """ + Args: + vx (float): desired velocity in world (NED) X axis + vy (float): desired velocity in world (NED) Y axis + vz (float): desired velocity in world (NED) Z axis + duration (float): Desired amount of time (seconds), to send this command for + drivetrain (DrivetrainType, optional): + yaw_mode (YawMode, optional): + vehicle_name (str, optional): Name of the multirotor to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByVelocity', vx, vy, vz, duration, drivetrain, yaw_mode, vehicle_name) + + def moveByVelocityZAsync(self, vx, vy, z, duration, drivetrain = DrivetrainType.MaxDegreeOfFreedom, yaw_mode = YawMode(), vehicle_name = ''): + return self.client.call_async('moveByVelocityZ', vx, vy, z, duration, drivetrain, yaw_mode, vehicle_name) + + def moveOnPathAsync(self, path, velocity, timeout_sec = 3e+38, drivetrain = DrivetrainType.MaxDegreeOfFreedom, yaw_mode = YawMode(), + lookahead = -1, adaptive_lookahead = 1, vehicle_name = ''): + return self.client.call_async('moveOnPath', path, velocity, timeout_sec, drivetrain, yaw_mode, lookahead, adaptive_lookahead, vehicle_name) + + def moveToPositionAsync(self, x, y, z, velocity, timeout_sec = 3e+38, drivetrain = DrivetrainType.MaxDegreeOfFreedom, yaw_mode = YawMode(), + lookahead = -1, adaptive_lookahead = 1, vehicle_name = ''): + return self.client.call_async('moveToPosition', x, y, z, velocity, timeout_sec, drivetrain, yaw_mode, lookahead, adaptive_lookahead, vehicle_name) + + def moveToGPSAsync(self, latitude, longitude, altitude, velocity, timeout_sec = 3e+38, drivetrain = DrivetrainType.MaxDegreeOfFreedom, yaw_mode = YawMode(), + lookahead = -1, adaptive_lookahead = 1, vehicle_name = ''): + return self.client.call_async('moveToGPS', latitude, longitude, altitude, velocity, timeout_sec, drivetrain, yaw_mode, lookahead, adaptive_lookahead, vehicle_name) + + def moveToZAsync(self, z, velocity, timeout_sec = 3e+38, yaw_mode = YawMode(), lookahead = -1, adaptive_lookahead = 1, vehicle_name = ''): + return self.client.call_async('moveToZ', z, velocity, timeout_sec, yaw_mode, lookahead, adaptive_lookahead, vehicle_name) + + def moveByManualAsync(self, vx_max, vy_max, z_min, duration, drivetrain = DrivetrainType.MaxDegreeOfFreedom, yaw_mode = YawMode(), vehicle_name = ''): + """ + - Read current RC state and use it to control the vehicles. + + Parameters sets up the constraints on velocity and minimum altitude while flying. If RC state is detected to violate these constraints + then that RC state would be ignored. + + Args: + vx_max (float): max velocity allowed in x direction + vy_max (float): max velocity allowed in y direction + vz_max (float): max velocity allowed in z direction + z_min (float): min z allowed for vehicle position + duration (float): after this duration vehicle would switch back to non-manual mode + drivetrain (DrivetrainType): when ForwardOnly, vehicle rotates itself so that its front is always facing the direction of travel. If MaxDegreeOfFreedom then it doesn't do that (crab-like movement) + yaw_mode (YawMode): Specifies if vehicle should face at given angle (is_rate=False) or should be rotating around its axis at given rate (is_rate=True) + vehicle_name (str, optional): Name of the multirotor to send this command to + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByManual', vx_max, vy_max, z_min, duration, drivetrain, yaw_mode, vehicle_name) + + def rotateToYawAsync(self, yaw, timeout_sec = 3e+38, margin = 5, vehicle_name = ''): + return self.client.call_async('rotateToYaw', yaw, timeout_sec, margin, vehicle_name) + + def rotateByYawRateAsync(self, yaw_rate, duration, vehicle_name = ''): + return self.client.call_async('rotateByYawRate', yaw_rate, duration, vehicle_name) + + def hoverAsync(self, vehicle_name = ''): + return self.client.call_async('hover', vehicle_name) + + def moveByRC(self, rcdata = RCData(), vehicle_name = ''): + return self.client.call('moveByRC', rcdata, vehicle_name) + +#low - level control API + def moveByMotorPWMsAsync(self, front_right_pwm, rear_left_pwm, front_left_pwm, rear_right_pwm, duration, vehicle_name = ''): + """ + - Directly control the motors using PWM values + + Args: + front_right_pwm (float): PWM value for the front right motor (between 0.0 to 1.0) + rear_left_pwm (float): PWM value for the rear left motor (between 0.0 to 1.0) + front_left_pwm (float): PWM value for the front left motor (between 0.0 to 1.0) + rear_right_pwm (float): PWM value for the rear right motor (between 0.0 to 1.0) + duration (float): Desired amount of time (seconds), to send this command for + vehicle_name (str, optional): Name of the multirotor to send this command to + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByMotorPWMs', front_right_pwm, rear_left_pwm, front_left_pwm, rear_right_pwm, duration, vehicle_name) + + def moveByRollPitchYawZAsync(self, roll, pitch, yaw, z, duration, vehicle_name = ''): + """ + - z is given in local NED frame of the vehicle. + - Roll angle, pitch angle, and yaw angle set points are given in **radians**, in the body frame. + - The body frame follows the Front Left Up (FLU) convention, and right-handedness. + + - Frame Convention: + - X axis is along the **Front** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **roll** angle. + | Hence, rolling with a positive angle is equivalent to translating in the **right** direction, w.r.t. our FLU body frame. + + - Y axis is along the **Left** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **pitch** angle. + | Hence, pitching with a positive angle is equivalent to translating in the **front** direction, w.r.t. our FLU body frame. + + - Z axis is along the **Up** direction. + + | Clockwise rotation about this axis defines a positive **yaw** angle. + | Hence, yawing with a positive angle is equivalent to rotated towards the **left** direction wrt our FLU body frame. Or in an anticlockwise fashion in the body XY / FL plane. + + Args: + roll (float): Desired roll angle, in radians. + pitch (float): Desired pitch angle, in radians. + yaw (float): Desired yaw angle, in radians. + z (float): Desired Z value (in local NED frame of the vehicle) + duration (float): Desired amount of time (seconds), to send this command for + vehicle_name (str, optional): Name of the multirotor to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByRollPitchYawZ', roll, -pitch, -yaw, z, duration, vehicle_name) + + def moveByRollPitchYawThrottleAsync(self, roll, pitch, yaw, throttle, duration, vehicle_name = ''): + """ + - Desired throttle is between 0.0 to 1.0 + - Roll angle, pitch angle, and yaw angle are given in **degrees** when using PX4 and in **radians** when using SimpleFlight, in the body frame. + - The body frame follows the Front Left Up (FLU) convention, and right-handedness. + + - Frame Convention: + - X axis is along the **Front** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **roll** angle. + | Hence, rolling with a positive angle is equivalent to translating in the **right** direction, w.r.t. our FLU body frame. + + - Y axis is along the **Left** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **pitch** angle. + | Hence, pitching with a positive angle is equivalent to translating in the **front** direction, w.r.t. our FLU body frame. + + - Z axis is along the **Up** direction. + + | Clockwise rotation about this axis defines a positive **yaw** angle. + | Hence, yawing with a positive angle is equivalent to rotated towards the **left** direction wrt our FLU body frame. Or in an anticlockwise fashion in the body XY / FL plane. + + Args: + roll (float): Desired roll angle. + pitch (float): Desired pitch angle. + yaw (float): Desired yaw angle. + throttle (float): Desired throttle (between 0.0 to 1.0) + duration (float): Desired amount of time (seconds), to send this command for + vehicle_name (str, optional): Name of the multirotor to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByRollPitchYawThrottle', roll, -pitch, -yaw, throttle, duration, vehicle_name) + + def moveByRollPitchYawrateThrottleAsync(self, roll, pitch, yaw_rate, throttle, duration, vehicle_name = ''): + """ + - Desired throttle is between 0.0 to 1.0 + - Roll angle, pitch angle, and yaw rate set points are given in **radians**, in the body frame. + - The body frame follows the Front Left Up (FLU) convention, and right-handedness. + + - Frame Convention: + - X axis is along the **Front** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **roll** angle. + | Hence, rolling with a positive angle is equivalent to translating in the **right** direction, w.r.t. our FLU body frame. + + - Y axis is along the **Left** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **pitch** angle. + | Hence, pitching with a positive angle is equivalent to translating in the **front** direction, w.r.t. our FLU body frame. + + - Z axis is along the **Up** direction. + + | Clockwise rotation about this axis defines a positive **yaw** angle. + | Hence, yawing with a positive angle is equivalent to rotated towards the **left** direction wrt our FLU body frame. Or in an anticlockwise fashion in the body XY / FL plane. + + Args: + roll (float): Desired roll angle, in radians. + pitch (float): Desired pitch angle, in radians. + yaw_rate (float): Desired yaw rate, in radian per second. + throttle (float): Desired throttle (between 0.0 to 1.0) + duration (float): Desired amount of time (seconds), to send this command for + vehicle_name (str, optional): Name of the multirotor to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByRollPitchYawrateThrottle', roll, -pitch, -yaw_rate, throttle, duration, vehicle_name) + + def moveByRollPitchYawrateZAsync(self, roll, pitch, yaw_rate, z, duration, vehicle_name = ''): + """ + - z is given in local NED frame of the vehicle. + - Roll angle, pitch angle, and yaw rate set points are given in **radians**, in the body frame. + - The body frame follows the Front Left Up (FLU) convention, and right-handedness. + + - Frame Convention: + - X axis is along the **Front** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **roll** angle. + | Hence, rolling with a positive angle is equivalent to translating in the **right** direction, w.r.t. our FLU body frame. + + - Y axis is along the **Left** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **pitch** angle. + | Hence, pitching with a positive angle is equivalent to translating in the **front** direction, w.r.t. our FLU body frame. + + - Z axis is along the **Up** direction. + + | Clockwise rotation about this axis defines a positive **yaw** angle. + | Hence, yawing with a positive angle is equivalent to rotated towards the **left** direction wrt our FLU body frame. Or in an anticlockwise fashion in the body XY / FL plane. + + Args: + roll (float): Desired roll angle, in radians. + pitch (float): Desired pitch angle, in radians. + yaw_rate (float): Desired yaw rate, in radian per second. + z (float): Desired Z value (in local NED frame of the vehicle) + duration (float): Desired amount of time (seconds), to send this command for + vehicle_name (str, optional): Name of the multirotor to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByRollPitchYawrateZ', roll, -pitch, -yaw_rate, z, duration, vehicle_name) + + def moveByAngleRatesZAsync(self, roll_rate, pitch_rate, yaw_rate, z, duration, vehicle_name = ''): + """ + - z is given in local NED frame of the vehicle. + - Roll rate, pitch rate, and yaw rate set points are given in **radians**, in the body frame. + - The body frame follows the Front Left Up (FLU) convention, and right-handedness. + + - Frame Convention: + - X axis is along the **Front** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **roll** angle. + | Hence, rolling with a positive angle is equivalent to translating in the **right** direction, w.r.t. our FLU body frame. + + - Y axis is along the **Left** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **pitch** angle. + | Hence, pitching with a positive angle is equivalent to translating in the **front** direction, w.r.t. our FLU body frame. + + - Z axis is along the **Up** direction. + + | Clockwise rotation about this axis defines a positive **yaw** angle. + | Hence, yawing with a positive angle is equivalent to rotated towards the **left** direction wrt our FLU body frame. Or in an anticlockwise fashion in the body XY / FL plane. + + Args: + roll_rate (float): Desired roll rate, in radians / second + pitch_rate (float): Desired pitch rate, in radians / second + yaw_rate (float): Desired yaw rate, in radians / second + z (float): Desired Z value (in local NED frame of the vehicle) + duration (float): Desired amount of time (seconds), to send this command for + vehicle_name (str, optional): Name of the multirotor to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByAngleRatesZ', roll_rate, -pitch_rate, -yaw_rate, z, duration, vehicle_name) + + def moveByAngleRatesThrottleAsync(self, roll_rate, pitch_rate, yaw_rate, throttle, duration, vehicle_name = ''): + """ + - Desired throttle is between 0.0 to 1.0 + - Roll rate, pitch rate, and yaw rate set points are given in **radians**, in the body frame. + - The body frame follows the Front Left Up (FLU) convention, and right-handedness. + + - Frame Convention: + - X axis is along the **Front** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **roll** angle. + | Hence, rolling with a positive angle is equivalent to translating in the **right** direction, w.r.t. our FLU body frame. + + - Y axis is along the **Left** direction of the quadrotor. + + | Clockwise rotation about this axis defines a positive **pitch** angle. + | Hence, pitching with a positive angle is equivalent to translating in the **front** direction, w.r.t. our FLU body frame. + + - Z axis is along the **Up** direction. + + | Clockwise rotation about this axis defines a positive **yaw** angle. + | Hence, yawing with a positive angle is equivalent to rotated towards the **left** direction wrt our FLU body frame. Or in an anticlockwise fashion in the body XY / FL plane. + + Args: + roll_rate (float): Desired roll rate, in radians / second + pitch_rate (float): Desired pitch rate, in radians / second + yaw_rate (float): Desired yaw rate, in radians / second + throttle (float): Desired throttle (between 0.0 to 1.0) + duration (float): Desired amount of time (seconds), to send this command for + vehicle_name (str, optional): Name of the multirotor to send this command to + + Returns: + msgpackrpc.future.Future: future. call .join() to wait for method to finish. Example: client.METHOD().join() + """ + return self.client.call_async('moveByAngleRatesThrottle', roll_rate, -pitch_rate, -yaw_rate, throttle, duration, vehicle_name) + + def setAngleRateControllerGains(self, angle_rate_gains=AngleRateControllerGains(), vehicle_name = ''): + """ + - Modifying these gains will have an affect on *ALL* move*() APIs. + This is because any velocity setpoint is converted to an angle level setpoint which is tracked with an angle level controllers. + That angle level setpoint is itself tracked with and angle rate controller. + - This function should only be called if the default angle rate control PID gains need to be modified. + + Args: + angle_rate_gains (AngleRateControllerGains): + - Correspond to the roll, pitch, yaw axes, defined in the body frame. + - Pass AngleRateControllerGains() to reset gains to default recommended values. + vehicle_name (str, optional): Name of the multirotor to send this command to + """ + self.client.call('setAngleRateControllerGains', *(angle_rate_gains.to_lists()+(vehicle_name,))) + + def setAngleLevelControllerGains(self, angle_level_gains=AngleLevelControllerGains(), vehicle_name = ''): + """ + - Sets angle level controller gains (used by any API setting angle references - for ex: moveByRollPitchYawZAsync(), moveByRollPitchYawThrottleAsync(), etc) + - Modifying these gains will also affect the behaviour of moveByVelocityAsync() API. + This is because the AirSim flight controller will track velocity setpoints by converting them to angle set points. + - This function should only be called if the default angle level control PID gains need to be modified. + - Passing AngleLevelControllerGains() sets gains to default airsim values. + + Args: + angle_level_gains (AngleLevelControllerGains): + - Correspond to the roll, pitch, yaw axes, defined in the body frame. + - Pass AngleLevelControllerGains() to reset gains to default recommended values. + vehicle_name (str, optional): Name of the multirotor to send this command to + """ + self.client.call('setAngleLevelControllerGains', *(angle_level_gains.to_lists()+(vehicle_name,))) + + def setVelocityControllerGains(self, velocity_gains=VelocityControllerGains(), vehicle_name = ''): + """ + - Sets velocity controller gains for moveByVelocityAsync(). + - This function should only be called if the default velocity control PID gains need to be modified. + - Passing VelocityControllerGains() sets gains to default airsim values. + + Args: + velocity_gains (VelocityControllerGains): + - Correspond to the world X, Y, Z axes. + - Pass VelocityControllerGains() to reset gains to default recommended values. + - Modifying velocity controller gains will have an affect on the behaviour of moveOnSplineAsync() and moveOnSplineVelConstraintsAsync(), as they both use velocity control to track the trajectory. + vehicle_name (str, optional): Name of the multirotor to send this command to + """ + self.client.call('setVelocityControllerGains', *(velocity_gains.to_lists()+(vehicle_name,))) + + + def setPositionControllerGains(self, position_gains=PositionControllerGains(), vehicle_name = ''): + """ + Sets position controller gains for moveByPositionAsync. + This function should only be called if the default position control PID gains need to be modified. + + Args: + position_gains (PositionControllerGains): + - Correspond to the X, Y, Z axes. + - Pass PositionControllerGains() to reset gains to default recommended values. + vehicle_name (str, optional): Name of the multirotor to send this command to + """ + self.client.call('setPositionControllerGains', *(position_gains.to_lists()+(vehicle_name,))) + +#query vehicle state + def getMultirotorState(self, vehicle_name = ''): + """ + The position inside the returned MultirotorState is in the frame of the vehicle's starting point + + Args: + vehicle_name (str, optional): Vehicle to get the state of + + Returns: + MultirotorState: + """ + return MultirotorState.from_msgpack(self.client.call('getMultirotorState', vehicle_name)) + getMultirotorState.__annotations__ = {'return': MultirotorState} +#query rotor states + def getRotorStates(self, vehicle_name = ''): + """ + Used to obtain the current state of all a multirotor's rotors. The state includes the speeds, + thrusts and torques for all rotors. + + Args: + vehicle_name (str, optional): Vehicle to get the rotor state of + + Returns: + RotorStates: Containing a timestamp and the speed, thrust and torque of all rotors. + """ + return RotorStates.from_msgpack(self.client.call('getRotorStates', vehicle_name)) + getRotorStates.__annotations__ = {'return': RotorStates} + +#----------------------------------- Car APIs --------------------------------------------- +class CarClient(VehicleClient, object): + def __init__(self, ip = "", port = 41451, timeout_value = 3600): + super(CarClient, self).__init__(ip, port, timeout_value) + + def setCarControls(self, controls, vehicle_name = ''): + """ + Control the car using throttle, steering, brake, etc. + + Args: + controls (CarControls): Struct containing control values + vehicle_name (str, optional): Name of vehicle to be controlled + """ + self.client.call('setCarControls', controls, vehicle_name) + + def getCarState(self, vehicle_name = ''): + """ + The position inside the returned CarState is in the frame of the vehicle's starting point + + Args: + vehicle_name (str, optional): Name of vehicle + + Returns: + CarState: + """ + state_raw = self.client.call('getCarState', vehicle_name) + return CarState.from_msgpack(state_raw) + + def getCarControls(self, vehicle_name=''): + """ + Args: + vehicle_name (str, optional): Name of vehicle + + Returns: + CarControls: + """ + controls_raw = self.client.call('getCarControls', vehicle_name) + return CarControls.from_msgpack(controls_raw) diff --git a/airsim/pfm.py b/airsim/pfm.py new file mode 100644 index 0000000..6f9f963 --- /dev/null +++ b/airsim/pfm.py @@ -0,0 +1,85 @@ +import numpy as np +import matplotlib.pyplot as plt +import re +import sys +import pdb + + +def read_pfm(file): + """ Read a pfm file """ + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + header = str(bytes.decode(header, encoding='utf-8')) + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + pattern = r'^(\d+)\s(\d+)\s$' + temp_str = str(bytes.decode(file.readline(), encoding='utf-8')) + dim_match = re.match(pattern, temp_str) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + temp_str += str(bytes.decode(file.readline(), encoding='utf-8')) + dim_match = re.match(pattern, temp_str) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header: width, height cannot be found') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + # DEY: I don't know why this was there. + file.close() + + return data, scale + + +def write_pfm(file, image, scale=1): + """ Write a pfm file """ + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write(bytes('PF\n', 'UTF-8') if color else bytes('Pf\n', 'UTF-8')) + temp_str = '%d %d\n' % (image.shape[1], image.shape[0]) + file.write(bytes(temp_str, 'UTF-8')) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + temp_str = '%f\n' % scale + file.write(bytes(temp_str, 'UTF-8')) + + image.tofile(file) diff --git a/airsim/types.py b/airsim/types.py new file mode 100644 index 0000000..7aef005 --- /dev/null +++ b/airsim/types.py @@ -0,0 +1,580 @@ +from __future__ import print_function +import msgpackrpc #install as admin: pip install msgpack-rpc-python +import numpy as np #pip install numpy +import math + +class MsgpackMixin: + def __repr__(self): + from pprint import pformat + return "<" + type(self).__name__ + "> " + pformat(vars(self), indent=4, width=1) + + def to_msgpack(self, *args, **kwargs): + return self.__dict__ + + @classmethod + def from_msgpack(cls, encoded): + obj = cls() + #obj.__dict__ = {k.decode('utf-8'): (from_msgpack(v.__class__, v) if hasattr(v, "__dict__") else v) for k, v in encoded.items()} + obj.__dict__ = { k : (v if not isinstance(v, dict) else getattr(getattr(obj, k).__class__, "from_msgpack")(v)) for k, v in encoded.items()} + #return cls(**msgpack.unpack(encoded)) + return obj + +class _ImageType(type): + @property + def Scene(cls): + return 0 + def DepthPlanar(cls): + return 1 + def DepthPerspective(cls): + return 2 + def DepthVis(cls): + return 3 + def DisparityNormalized(cls): + return 4 + def Segmentation(cls): + return 5 + def SurfaceNormals(cls): + return 6 + def Infrared(cls): + return 7 + def OpticalFlow(cls): + return 8 + def OpticalFlowVis(cls): + return 9 + + def __getattr__(self, key): + if key == 'DepthPlanner': + print('\033[31m'+"DepthPlanner has been (correctly) renamed to DepthPlanar. Please use ImageType.DepthPlanar instead."+'\033[0m') + raise AttributeError + +class ImageType(metaclass=_ImageType): + Scene = 0 + DepthPlanar = 1 + DepthPerspective = 2 + DepthVis = 3 + DisparityNormalized = 4 + Segmentation = 5 + SurfaceNormals = 6 + Infrared = 7 + OpticalFlow = 8 + OpticalFlowVis = 9 + +class DrivetrainType: + MaxDegreeOfFreedom = 0 + ForwardOnly = 1 + +class LandedState: + Landed = 0 + Flying = 1 + +class WeatherParameter: + Rain = 0 + Roadwetness = 1 + Snow = 2 + RoadSnow = 3 + MapleLeaf = 4 + RoadLeaf = 5 + Dust = 6 + Fog = 7 + Enabled = 8 + +class Vector2r(MsgpackMixin): + x_val = 0.0 + y_val = 0.0 + + def __init__(self, x_val = 0.0, y_val = 0.0): + self.x_val = x_val + self.y_val = y_val + +class Vector3r(MsgpackMixin): + x_val = 0.0 + y_val = 0.0 + z_val = 0.0 + + def __init__(self, x_val = 0.0, y_val = 0.0, z_val = 0.0): + self.x_val = x_val + self.y_val = y_val + self.z_val = z_val + + @staticmethod + def nanVector3r(): + return Vector3r(np.nan, np.nan, np.nan) + + def containsNan(self): + return (math.isnan(self.x_val) or math.isnan(self.y_val) or math.isnan(self.z_val)) + + def __add__(self, other): + return Vector3r(self.x_val + other.x_val, self.y_val + other.y_val, self.z_val + other.z_val) + + def __sub__(self, other): + return Vector3r(self.x_val - other.x_val, self.y_val - other.y_val, self.z_val - other.z_val) + + def __truediv__(self, other): + if type(other) in [int, float] + np.sctypes['int'] + np.sctypes['uint'] + np.sctypes['float']: + return Vector3r( self.x_val / other, self.y_val / other, self.z_val / other) + else: + raise TypeError('unsupported operand type(s) for /: %s and %s' % ( str(type(self)), str(type(other))) ) + + def __mul__(self, other): + if type(other) in [int, float] + np.sctypes['int'] + np.sctypes['uint'] + np.sctypes['float']: + return Vector3r(self.x_val*other, self.y_val*other, self.z_val*other) + else: + raise TypeError('unsupported operand type(s) for *: %s and %s' % ( str(type(self)), str(type(other))) ) + + def dot(self, other): + if type(self) == type(other): + return self.x_val*other.x_val + self.y_val*other.y_val + self.z_val*other.z_val + else: + raise TypeError('unsupported operand type(s) for \'dot\': %s and %s' % ( str(type(self)), str(type(other))) ) + + def cross(self, other): + if type(self) == type(other): + cross_product = np.cross(self.to_numpy_array(), other.to_numpy_array()) + return Vector3r(cross_product[0], cross_product[1], cross_product[2]) + else: + raise TypeError('unsupported operand type(s) for \'cross\': %s and %s' % ( str(type(self)), str(type(other))) ) + + def get_length(self): + return ( self.x_val**2 + self.y_val**2 + self.z_val**2 )**0.5 + + def distance_to(self, other): + return ( (self.x_val-other.x_val)**2 + (self.y_val-other.y_val)**2 + (self.z_val-other.z_val)**2 )**0.5 + + def to_Quaternionr(self): + return Quaternionr(self.x_val, self.y_val, self.z_val, 0) + + def to_numpy_array(self): + return np.array([self.x_val, self.y_val, self.z_val], dtype=np.float32) + + def __iter__(self): + return iter((self.x_val, self.y_val, self.z_val)) + +class Quaternionr(MsgpackMixin): + w_val = 0.0 + x_val = 0.0 + y_val = 0.0 + z_val = 0.0 + + def __init__(self, x_val = 0.0, y_val = 0.0, z_val = 0.0, w_val = 1.0): + self.x_val = x_val + self.y_val = y_val + self.z_val = z_val + self.w_val = w_val + + @staticmethod + def nanQuaternionr(): + return Quaternionr(np.nan, np.nan, np.nan, np.nan) + + def containsNan(self): + return (math.isnan(self.w_val) or math.isnan(self.x_val) or math.isnan(self.y_val) or math.isnan(self.z_val)) + + def __add__(self, other): + if type(self) == type(other): + return Quaternionr( self.x_val+other.x_val, self.y_val+other.y_val, self.z_val+other.z_val, self.w_val+other.w_val ) + else: + raise TypeError('unsupported operand type(s) for +: %s and %s' % ( str(type(self)), str(type(other))) ) + + def __mul__(self, other): + if type(self) == type(other): + t, x, y, z = self.w_val, self.x_val, self.y_val, self.z_val + a, b, c, d = other.w_val, other.x_val, other.y_val, other.z_val + return Quaternionr( w_val = a*t - b*x - c*y - d*z, + x_val = b*t + a*x + d*y - c*z, + y_val = c*t + a*y + b*z - d*x, + z_val = d*t + z*a + c*x - b*y) + else: + raise TypeError('unsupported operand type(s) for *: %s and %s' % ( str(type(self)), str(type(other))) ) + + def __truediv__(self, other): + if type(other) == type(self): + return self * other.inverse() + elif type(other) in [int, float] + np.sctypes['int'] + np.sctypes['uint'] + np.sctypes['float']: + return Quaternionr( self.x_val / other, self.y_val / other, self.z_val / other, self.w_val / other) + else: + raise TypeError('unsupported operand type(s) for /: %s and %s' % ( str(type(self)), str(type(other))) ) + + def dot(self, other): + if type(self) == type(other): + return self.x_val*other.x_val + self.y_val*other.y_val + self.z_val*other.z_val + self.w_val*other.w_val + else: + raise TypeError('unsupported operand type(s) for \'dot\': %s and %s' % ( str(type(self)), str(type(other))) ) + + def cross(self, other): + if type(self) == type(other): + return (self * other - other * self) / 2 + else: + raise TypeError('unsupported operand type(s) for \'cross\': %s and %s' % ( str(type(self)), str(type(other))) ) + + def outer_product(self, other): + if type(self) == type(other): + return ( self.inverse()*other - other.inverse()*self ) / 2 + else: + raise TypeError('unsupported operand type(s) for \'outer_product\': %s and %s' % ( str(type(self)), str(type(other))) ) + + def rotate(self, other): + if type(self) == type(other): + if other.get_length() == 1: + return other * self * other.inverse() + else: + raise ValueError('length of the other Quaternionr must be 1') + else: + raise TypeError('unsupported operand type(s) for \'rotate\': %s and %s' % ( str(type(self)), str(type(other))) ) + + def conjugate(self): + return Quaternionr(-self.x_val, -self.y_val, -self.z_val, self.w_val) + + def star(self): + return self.conjugate() + + def inverse(self): + return self.star() / self.dot(self) + + def sgn(self): + return self/self.get_length() + + def get_length(self): + return ( self.x_val**2 + self.y_val**2 + self.z_val**2 + self.w_val**2 )**0.5 + + def to_numpy_array(self): + return np.array([self.x_val, self.y_val, self.z_val, self.w_val], dtype=np.float32) + + def __iter__(self): + return iter((self.x_val, self.y_val, self.z_val, self.w_val)) + +class Pose(MsgpackMixin): + position = Vector3r() + orientation = Quaternionr() + + def __init__(self, position_val = None, orientation_val = None): + position_val = position_val if position_val is not None else Vector3r() + orientation_val = orientation_val if orientation_val is not None else Quaternionr() + self.position = position_val + self.orientation = orientation_val + + @staticmethod + def nanPose(): + return Pose(Vector3r.nanVector3r(), Quaternionr.nanQuaternionr()) + + def containsNan(self): + return (self.position.containsNan() or self.orientation.containsNan()) + + def __iter__(self): + return iter((self.position, self.orientation)) + +class CollisionInfo(MsgpackMixin): + has_collided = False + normal = Vector3r() + impact_point = Vector3r() + position = Vector3r() + penetration_depth = 0.0 + time_stamp = 0.0 + object_name = "" + object_id = -1 + +class GeoPoint(MsgpackMixin): + latitude = 0.0 + longitude = 0.0 + altitude = 0.0 + +class YawMode(MsgpackMixin): + is_rate = True + yaw_or_rate = 0.0 + def __init__(self, is_rate = True, yaw_or_rate = 0.0): + self.is_rate = is_rate + self.yaw_or_rate = yaw_or_rate + +class RCData(MsgpackMixin): + timestamp = 0 + pitch, roll, throttle, yaw = (0.0,)*4 #init 4 variable to 0.0 + switch1, switch2, switch3, switch4 = (0,)*4 + switch5, switch6, switch7, switch8 = (0,)*4 + is_initialized = False + is_valid = False + def __init__(self, timestamp = 0, pitch = 0.0, roll = 0.0, throttle = 0.0, yaw = 0.0, switch1 = 0, + switch2 = 0, switch3 = 0, switch4 = 0, switch5 = 0, switch6 = 0, switch7 = 0, switch8 = 0, is_initialized = False, is_valid = False): + self.timestamp = timestamp + self.pitch = pitch + self.roll = roll + self.throttle = throttle + self.yaw = yaw + self.switch1 = switch1 + self.switch2 = switch2 + self.switch3 = switch3 + self.switch4 = switch4 + self.switch5 = switch5 + self.switch6 = switch6 + self.switch7 = switch7 + self.switch8 = switch8 + self.is_initialized = is_initialized + self.is_valid = is_valid + +class ImageRequest(MsgpackMixin): + camera_name = '0' + image_type = ImageType.Scene + pixels_as_float = False + compress = False + + def __init__(self, camera_name, image_type, pixels_as_float = False, compress = True): + # todo: in future remove str(), it's only for compatibility to pre v1.2 + self.camera_name = str(camera_name) + self.image_type = image_type + self.pixels_as_float = pixels_as_float + self.compress = compress + + +class ImageResponse(MsgpackMixin): + image_data_uint8 = np.uint8(0) + image_data_float = 0.0 + camera_position = Vector3r() + camera_orientation = Quaternionr() + time_stamp = np.uint64(0) + message = '' + pixels_as_float = 0.0 + compress = True + width = 0 + height = 0 + image_type = ImageType.Scene + +class CarControls(MsgpackMixin): + throttle = 0.0 + steering = 0.0 + brake = 0.0 + handbrake = False + is_manual_gear = False + manual_gear = 0 + gear_immediate = True + + def __init__(self, throttle = 0, steering = 0, brake = 0, + handbrake = False, is_manual_gear = False, manual_gear = 0, gear_immediate = True): + self.throttle = throttle + self.steering = steering + self.brake = brake + self.handbrake = handbrake + self.is_manual_gear = is_manual_gear + self.manual_gear = manual_gear + self.gear_immediate = gear_immediate + + + def set_throttle(self, throttle_val, forward): + if (forward): + self.is_manual_gear = False + self.manual_gear = 0 + self.throttle = abs(throttle_val) + else: + self.is_manual_gear = False + self.manual_gear = -1 + self.throttle = - abs(throttle_val) + +class KinematicsState(MsgpackMixin): + position = Vector3r() + orientation = Quaternionr() + linear_velocity = Vector3r() + angular_velocity = Vector3r() + linear_acceleration = Vector3r() + angular_acceleration = Vector3r() + +class EnvironmentState(MsgpackMixin): + position = Vector3r() + geo_point = GeoPoint() + gravity = Vector3r() + air_pressure = 0.0 + temperature = 0.0 + air_density = 0.0 + +class CarState(MsgpackMixin): + speed = 0.0 + gear = 0 + rpm = 0.0 + maxrpm = 0.0 + handbrake = False + collision = CollisionInfo() + kinematics_estimated = KinematicsState() + timestamp = np.uint64(0) + +class MultirotorState(MsgpackMixin): + collision = CollisionInfo() + kinematics_estimated = KinematicsState() + gps_location = GeoPoint() + timestamp = np.uint64(0) + landed_state = LandedState.Landed + rc_data = RCData() + ready = False + ready_message = "" + can_arm = False + +class RotorStates(MsgpackMixin): + timestamp = np.uint64(0) + rotors = [] + +class ProjectionMatrix(MsgpackMixin): + matrix = [] + +class CameraInfo(MsgpackMixin): + pose = Pose() + fov = -1 + proj_mat = ProjectionMatrix() + +class LidarData(MsgpackMixin): + point_cloud = 0.0 + time_stamp = np.uint64(0) + pose = Pose() + segmentation = 0 + +class ImuData(MsgpackMixin): + time_stamp = np.uint64(0) + orientation = Quaternionr() + angular_velocity = Vector3r() + linear_acceleration = Vector3r() + +class BarometerData(MsgpackMixin): + time_stamp = np.uint64(0) + altitude = Quaternionr() + pressure = Vector3r() + qnh = Vector3r() + +class MagnetometerData(MsgpackMixin): + time_stamp = np.uint64(0) + magnetic_field_body = Vector3r() + magnetic_field_covariance = 0.0 + +class GnssFixType(MsgpackMixin): + GNSS_FIX_NO_FIX = 0 + GNSS_FIX_TIME_ONLY = 1 + GNSS_FIX_2D_FIX = 2 + GNSS_FIX_3D_FIX = 3 + +class GnssReport(MsgpackMixin): + geo_point = GeoPoint() + eph = 0.0 + epv = 0.0 + velocity = Vector3r() + fix_type = GnssFixType() + time_utc = np.uint64(0) + +class GpsData(MsgpackMixin): + time_stamp = np.uint64(0) + gnss = GnssReport() + is_valid = False + +class DistanceSensorData(MsgpackMixin): + time_stamp = np.uint64(0) + distance = 0.0 + min_distance = 0.0 + max_distance = 0.0 + relative_pose = Pose() + +class Box2D(MsgpackMixin): + min = Vector2r() + max = Vector2r() + +class Box3D(MsgpackMixin): + min = Vector3r() + max = Vector3r() + +class DetectionInfo(MsgpackMixin): + name = '' + geo_point = GeoPoint() + box2D = Box2D() + box3D = Box3D() + relative_pose = Pose() + +class PIDGains(): + """ + Struct to store values of PID gains. Used to transmit controller gain values while instantiating + AngleLevel/AngleRate/Velocity/PositionControllerGains objects. + + Attributes: + kP (float): Proportional gain + kI (float): Integrator gain + kD (float): Derivative gain + """ + def __init__(self, kp, ki, kd): + self.kp = kp + self.ki = ki + self.kd = kd + + def to_list(self): + return [self.kp, self.ki, self.kd] + +class AngleRateControllerGains(): + """ + Struct to contain controller gains used by angle level PID controller + + Attributes: + roll_gains (PIDGains): kP, kI, kD for roll axis + pitch_gains (PIDGains): kP, kI, kD for pitch axis + yaw_gains (PIDGains): kP, kI, kD for yaw axis + """ + def __init__(self, roll_gains = PIDGains(0.25, 0, 0), + pitch_gains = PIDGains(0.25, 0, 0), + yaw_gains = PIDGains(0.25, 0, 0)): + self.roll_gains = roll_gains + self.pitch_gains = pitch_gains + self.yaw_gains = yaw_gains + + def to_lists(self): + return [self.roll_gains.kp, self.pitch_gains.kp, self.yaw_gains.kp], [self.roll_gains.ki, self.pitch_gains.ki, self.yaw_gains.ki], [self.roll_gains.kd, self.pitch_gains.kd, self.yaw_gains.kd] + +class AngleLevelControllerGains(): + """ + Struct to contain controller gains used by angle rate PID controller + + Attributes: + roll_gains (PIDGains): kP, kI, kD for roll axis + pitch_gains (PIDGains): kP, kI, kD for pitch axis + yaw_gains (PIDGains): kP, kI, kD for yaw axis + """ + def __init__(self, roll_gains = PIDGains(2.5, 0, 0), + pitch_gains = PIDGains(2.5, 0, 0), + yaw_gains = PIDGains(2.5, 0, 0)): + self.roll_gains = roll_gains + self.pitch_gains = pitch_gains + self.yaw_gains = yaw_gains + + def to_lists(self): + return [self.roll_gains.kp, self.pitch_gains.kp, self.yaw_gains.kp], [self.roll_gains.ki, self.pitch_gains.ki, self.yaw_gains.ki], [self.roll_gains.kd, self.pitch_gains.kd, self.yaw_gains.kd] + +class VelocityControllerGains(): + """ + Struct to contain controller gains used by velocity PID controller + + Attributes: + x_gains (PIDGains): kP, kI, kD for X axis + y_gains (PIDGains): kP, kI, kD for Y axis + z_gains (PIDGains): kP, kI, kD for Z axis + """ + def __init__(self, x_gains = PIDGains(0.2, 0, 0), + y_gains = PIDGains(0.2, 0, 0), + z_gains = PIDGains(2.0, 2.0, 0)): + self.x_gains = x_gains + self.y_gains = y_gains + self.z_gains = z_gains + + def to_lists(self): + return [self.x_gains.kp, self.y_gains.kp, self.z_gains.kp], [self.x_gains.ki, self.y_gains.ki, self.z_gains.ki], [self.x_gains.kd, self.y_gains.kd, self.z_gains.kd] + +class PositionControllerGains(): + """ + Struct to contain controller gains used by position PID controller + + Attributes: + x_gains (PIDGains): kP, kI, kD for X axis + y_gains (PIDGains): kP, kI, kD for Y axis + z_gains (PIDGains): kP, kI, kD for Z axis + """ + def __init__(self, x_gains = PIDGains(0.25, 0, 0), + y_gains = PIDGains(0.25, 0, 0), + z_gains = PIDGains(0.25, 0, 0)): + self.x_gains = x_gains + self.y_gains = y_gains + self.z_gains = z_gains + + def to_lists(self): + return [self.x_gains.kp, self.y_gains.kp, self.z_gains.kp], [self.x_gains.ki, self.y_gains.ki, self.z_gains.ki], [self.x_gains.kd, self.y_gains.kd, self.z_gains.kd] + +class MeshPositionVertexBuffersResponse(MsgpackMixin): + position = Vector3r() + orientation = Quaternionr() + vertices = 0.0 + indices = 0.0 + name = '' diff --git a/airsim/utils.py b/airsim/utils.py new file mode 100644 index 0000000..7f866d7 --- /dev/null +++ b/airsim/utils.py @@ -0,0 +1,208 @@ +import numpy as np #pip install numpy +import math +import time +import sys +import os +import inspect +import types +import re +import logging + +from .types import * + + +def string_to_uint8_array(bstr): + return np.fromstring(bstr, np.uint8) + +def string_to_float_array(bstr): + return np.fromstring(bstr, np.float32) + +def list_to_2d_float_array(flst, width, height): + return np.reshape(np.asarray(flst, np.float32), (height, width)) + +def get_pfm_array(response): + return list_to_2d_float_array(response.image_data_float, response.width, response.height) + + +def get_public_fields(obj): + return [attr for attr in dir(obj) + if not (attr.startswith("_") + or inspect.isbuiltin(attr) + or inspect.isfunction(attr) + or inspect.ismethod(attr))] + + + +def to_dict(obj): + return dict([attr, getattr(obj, attr)] for attr in get_public_fields(obj)) + + +def to_str(obj): + return str(to_dict(obj)) + + +def write_file(filename, bstr): + """ + Write binary data to file. + Used for writing compressed PNG images + """ + with open(filename, 'wb') as afile: + afile.write(bstr) + +# helper method for converting getOrientation to roll/pitch/yaw +# https:#en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles + +def to_eularian_angles(q): + z = q.z_val + y = q.y_val + x = q.x_val + w = q.w_val + ysqr = y * y + + # roll (x-axis rotation) + t0 = +2.0 * (w*x + y*z) + t1 = +1.0 - 2.0*(x*x + ysqr) + roll = math.atan2(t0, t1) + + # pitch (y-axis rotation) + t2 = +2.0 * (w*y - z*x) + if (t2 > 1.0): + t2 = 1 + if (t2 < -1.0): + t2 = -1.0 + pitch = math.asin(t2) + + # yaw (z-axis rotation) + t3 = +2.0 * (w*z + x*y) + t4 = +1.0 - 2.0 * (ysqr + z*z) + yaw = math.atan2(t3, t4) + + return (pitch, roll, yaw) + + +def to_quaternion(pitch, roll, yaw): + t0 = math.cos(yaw * 0.5) + t1 = math.sin(yaw * 0.5) + t2 = math.cos(roll * 0.5) + t3 = math.sin(roll * 0.5) + t4 = math.cos(pitch * 0.5) + t5 = math.sin(pitch * 0.5) + + q = Quaternionr() + q.w_val = t0 * t2 * t4 + t1 * t3 * t5 #w + q.x_val = t0 * t3 * t4 - t1 * t2 * t5 #x + q.y_val = t0 * t2 * t5 + t1 * t3 * t4 #y + q.z_val = t1 * t2 * t4 - t0 * t3 * t5 #z + return q + + +def wait_key(message = ''): + ''' Wait for a key press on the console and return it. ''' + if message != '': + print (message) + + result = None + if os.name == 'nt': + import msvcrt + result = msvcrt.getch() + else: + import termios + fd = sys.stdin.fileno() + + oldterm = termios.tcgetattr(fd) + newattr = termios.tcgetattr(fd) + newattr[3] = newattr[3] & ~termios.ICANON & ~termios.ECHO + termios.tcsetattr(fd, termios.TCSANOW, newattr) + + try: + result = sys.stdin.read(1) + except IOError: + pass + finally: + termios.tcsetattr(fd, termios.TCSAFLUSH, oldterm) + + return result + + +def read_pfm(file): + """ Read a pfm file """ + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + header = str(bytes.decode(header, encoding='utf-8')) + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + temp_str = str(bytes.decode(file.readline(), encoding='utf-8')) + dim_match = re.match(r'^(\d+)\s(\d+)\s$', temp_str) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + # DEY: I don't know why this was there. + file.close() + + return data, scale + + +def write_pfm(file, image, scale=1): + """ Write a pfm file """ + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # grayscale + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) + temp_str = '%d %d\n' % (image.shape[1], image.shape[0]) + file.write(temp_str.encode('utf-8')) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + temp_str = '%f\n' % scale + file.write(temp_str.encode('utf-8')) + + image.tofile(file) + + +def write_png(filename, image): + """ image must be numpy array H X W X channels + """ + import cv2 # pip install opencv-python + + ret = cv2.imwrite(filename, image) + if not ret: + logging.error(f"Writing PNG file {filename} failed") diff --git a/data/coco.yaml b/data/coco.yaml index 2ccc647..b4bbbae 100644 --- a/data/coco.yaml +++ b/data/coco.yaml @@ -1,30 +1,11 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license -# COCO 2017 dataset http://cocodataset.org -# Example usage: python train.py --data coco.yaml -# parent -# ├── yolov5 -# └── datasets -# └── coco ← downloads here - - -# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] path: ../datasets/coco # dataset root dir train: train2017.txt # train images (relative to 'path') 118287 images val: val2017.txt # train images (relative to 'path') 5000 images test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 # Classes -nc: 80 # number of classes -names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', - 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', - 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', - 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', - 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', - 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', - 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', - 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', - 'hair drier', 'toothbrush'] # class names - +nc: 10 # number of classes +names: ['pedestrian', 'people', 'bicicle', 'car', 'van', 'truck', 'tricicle', 'awning-tricycle', 'bus', 'motor'] # Download script/URL (optional) download: | diff --git a/data/images/zidane.jpg b/data/images/zidane.jpg index 92d72ea..94e7db7 100644 Binary files a/data/images/zidane.jpg and b/data/images/zidane.jpg differ diff --git a/detect.py b/detect.py index 4549095..665e49b 100644 --- a/detect.py +++ b/detect.py @@ -12,18 +12,22 @@ import cv2 import yaml -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) from edgetpumodel import EdgeTPUModel from utils import resize_and_pad, get_image_tensor, save_one_json, coco80_to_coco91_class +#from pycoral.pybind._pywrap_coral import SetVerbosity as set_verbosity +#set_verbosity(10) + if __name__ == "__main__": parser = argparse.ArgumentParser("EdgeTPU test runner") parser.add_argument("--model", "-m", help="weights file", required=True) parser.add_argument("--bench_speed", action='store_true', help="run speed test on dummy data") parser.add_argument("--bench_image", action='store_true', help="run detection test") + parser.add_argument("--bench_airsim", action='store_true', help="run detection on airsim") parser.add_argument("--conf_thresh", type=float, default=0.25, help="model confidence threshold") parser.add_argument("--iou_thresh", type=float, default=0.45, help="NMS IOU threshold") parser.add_argument("--names", type=str, default='data/coco.yaml', help="Names file") @@ -47,7 +51,7 @@ model = EdgeTPUModel(args.model, args.names, conf_thresh=args.conf_thresh, iou_thresh=args.iou_thresh) input_size = model.get_image_size() - x = (255*np.random.random((3,*input_size))).astype(np.uint8) + x = (255*np.random.random((3,*input_size))).astype(np.int8) model.forward(x) conf_thresh = 0.25 @@ -88,6 +92,11 @@ logger.info("Testing on Zidane image") model.predict("./data/images/zidane.jpg") + elif args.bench_image: + logger.info("Testing on Zidane image") + model.predict("./data/images/zidane.jpg") + + elif args.bench_coco: logger.info("Testing on COCO dataset") diff --git a/detect_airsim.py b/detect_airsim.py new file mode 100644 index 0000000..196ac17 --- /dev/null +++ b/detect_airsim.py @@ -0,0 +1,188 @@ + +# requires Python 3.5.3 :: Anaconda 4.4.0 +# pip install opencv-python + +import os +import sys +import argparse +import logging +import time +from pathlib import Path +import glob +import json + +import numpy as np +from tqdm import tqdm +import cv2 +import yaml +import airsim + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +from edgetpumodel import EdgeTPUModel +from utils import resize_and_pad, get_image_tensor, save_one_json, coco80_to_coco91_class + +# from pycoral.pybind._pywrap_coral import SetVerbosity as set_verbosity +# set_verbosity(10) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser("EdgeTPU test runner") + parser.add_argument("--model", "-m", help="weights file", required=True) + parser.add_argument("--bench_speed", action='store_true', help="run speed test on dummy data") + parser.add_argument("--bench_image", action='store_true', help="run detection test") + parser.add_argument("--bench_airsim", action='store_true', help="run detection on airsim") + parser.add_argument("--conf_thresh", type=float, default=0.25, help="model confidence threshold") + parser.add_argument("--iou_thresh", type=float, default=0.45, help="NMS IOU threshold") + parser.add_argument("--names", type=str, default='data/coco.yaml', help="Names file") + parser.add_argument("--image", "-i", type=str, help="Image file to run detection on") + parser.add_argument("--device", type=int, default=0, help="Image capture device to run live detection") + parser.add_argument("--stream", action='store_true', help="Process a stream") + parser.add_argument("--bench_coco", action='store_true', help="Process a stream") + parser.add_argument("--coco_path", type=str, help="Path to COCO 2017 Val folder") + parser.add_argument("--quiet", "-q", action='store_true', help="Disable logging (except errors)") + + args = parser.parse_args() + + if args.quiet: + logging.disable(logging.CRITICAL) + logger.disabled = True + + if args.stream and args.image: + logger.error("Please select either an input image or a stream") + exit(1) + + model = EdgeTPUModel(args.model, args.names, conf_thresh=args.conf_thresh, iou_thresh=args.iou_thresh) + input_size = model.get_image_size() + + x = (255 * np.random.random((3, *input_size))).astype(np.int8) + model.forward(x) + + conf_thresh = 0.25 + iou_thresh = 0.45 + classes = None + agnostic_nms = False + max_det = 1000 + + if args.bench_speed: + logger.info("Performing test run") + n_runs = 100 + + inference_times = [] + nms_times = [] + total_times = [] + + for i in tqdm(range(n_runs)): + x = (255 * np.random.random((3, *input_size))).astype(np.float32) + + pred = model.forward(x) + tinference, tnms = model.get_last_inference_time() + + inference_times.append(tinference) + nms_times.append(tnms) + total_times.append(tinference + tnms) + + inference_times = np.array(inference_times) + nms_times = np.array(nms_times) + total_times = np.array(total_times) + + logger.info("Inference time (EdgeTPU): {:1.2f} +- {:1.2f} ms".format(inference_times.mean() / 1e-3, + inference_times.std() / 1e-3)) + logger.info("NMS time (CPU): {:1.2f} +- {:1.2f} ms".format(nms_times.mean() / 1e-3, nms_times.std() / 1e-3)) + fps = 1.0 / total_times.mean() + logger.info("Mean FPS: {:1.2f}".format(fps)) + + elif args.bench_image: + logger.info("Testing on Zidane image") + model.predict("./data/images/zidane.jpg") + + elif args.bench_airsim: + logger.info("Testing on Zidane image") + client = airsim.MultirotorClient() + + # because this method returns std::vector, msgpack decides to encode it as a string unfortunately. + while(True): + rawImage = client.simGetImage("3", airsim.ImageType.Scene) + if (rawImage == None): + print("Camera is not returning image, please check airsim for error messages") + sys.exit(0) + else: + png = cv2.imdecode(airsim.string_to_uint8_array(rawImage), cv2.IMREAD_UNCHANGED) + cv2.imwrite("./data/images/zidane.jpg", png) + model.predict("./data/images/zidane.jpg") + + key = cv2.waitKey(1) & 0xFF + if (key == 27 or key == ord('q') or key == ord('x')): + break + + + + + + elif args.bench_coco: + logger.info("Testing on COCO dataset") + + model.conf_thresh = 0.001 + model.iou_thresh = 0.65 + + coco_glob = os.path.join(args.coco_path, "*.jpg") + images = glob.glob(coco_glob) + + logger.info("Looking for: {}".format(coco_glob)) + ids = [int(os.path.basename(i).split('.')[0]) for i in images] + + out_path = "./coco_eval" + os.makedirs("./coco_eval", exist_ok=True) + + logger.info("Found {} images".format(len(images))) + + class_map = coco80_to_coco91_class() + + predictions = [] + + for image in tqdm(images): + res = model.predict(image, save_img=False, save_txt=False) + save_one_json(res, predictions, Path(image), class_map) + + pred_json = os.path.join(out_path, + "{}_predictions.json".format(os.path.basename(args.model))) + + with open(pred_json, 'w') as f: + json.dump(predictions, f, indent=1) + + elif args.image is not None: + logger.info("Testing on user image: {}".format(args.image)) + model.predict(args.image) + + elif args.stream: + logger.info("Opening stream on device: {}".format(args.device)) + + cam = cv2.VideoCapture(args.device) + + while True: + try: + res, image = cam.read() + + if res is False: + logger.error("Empty image received") + break + else: + full_image, net_image, pad = get_image_tensor(image, input_size[0]) + pred = model.forward(net_image) + + model.process_predictions(pred[0], full_image, pad) + + tinference, tnms = model.get_last_inference_time() + logger.info("Frame done in {}".format(tinference + tnms)) + except KeyboardInterrupt: + break + + cam.release() + + + + + + + diff --git a/edgetpumodel.py b/edgetpumodel.py index 07f0425..92ac29f 100644 --- a/edgetpumodel.py +++ b/edgetpumodel.py @@ -155,7 +155,7 @@ def forward(self, x:np.ndarray, with_nms=True) -> np.ndarray: # Scale input, conversion is: real = (int_8 - zero)*scale x = (x/self.input_scale) + self.input_zero - x = x[np.newaxis].astype(np.uint8) + x = x[np.newaxis].astype(np.int8) self.interpreter.set_tensor(self.input_details[0]['index'], x) self.interpreter.invoke() @@ -261,10 +261,12 @@ def process_predictions(self, det, output_image, pad, output_path="detection.jpg output[base]['cls_name'] = self.names[c] if save_txt: - output_txt = base+"txt" + output_txt = base+".txt" with open(output_txt, 'w') as f: json.dump(output, f, indent=1) if save_img: cv2.imwrite(output_path, output_image) + cv2.imshow("Gooo!", output_image) + # cv2.waitKey(0) return det \ No newline at end of file diff --git a/nms.py b/nms.py index fb94eaf..16ca9dc 100644 --- a/nms.py +++ b/nms.py @@ -52,8 +52,21 @@ def nms(dets, scores, thresh): def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, labels=(), max_det=300): - nc = prediction.shape[2] - 5 # number of classes - xc = prediction[..., 4] > conf_thres # candidates + nc = prediction.shape[1] - 4 # number of classes + bs = prediction.shape[0] # batch size + nm = prediction.shape[1] - nc - 4 + mi = 4 + nc # mask start index + + print(f'mi : {[mi]}') + + xc = np.amax(prediction[:, 4:mi], 1) > conf_thres # candidates + print(f'xc SHAPE: {[xc.shape]}') + # xc = prediction[..., 4] > conf_thres # candidates + + print(f'prediction SHAPE: {[prediction.shape]}') + + prediction = prediction.transpose(0,2,1) # shape(1,84,6300) to shape(1,6300,84) + print(f'prediction SHAPE: {[prediction.shape]}') # Checks assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0' @@ -77,10 +90,10 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non # Cat apriori labels if autolabelling if labels and len(labels[xi]): l = labels[xi] - v = np.zeros((len(l), nc + 5)) + v = np.zeros((len(l), nc + 4)) v[:, :4] = l[:, 1:5] # box v[:, 4] = 1.0 # conf - v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls + v[range(len(l)), l[:, 0].long() + 4] = 1.0 # cls x = np.concatenate((x, v), 0) # If none remain process next image @@ -88,7 +101,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non continue # Compute conf - x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + # x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf # Box (center x, center y, width, height) to (x1, y1, x2, y2) box = xywh2xyxy(x[:, :4]) @@ -98,8 +111,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T x = np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(float)), axis=1) else: # best class only - conf = np.amax(x[:, 5:], axis=1, keepdims=True) - j = np.argmax(x[:, 5:], axis=1).reshape(conf.shape) + conf = np.amax(x[:, 4:], axis=1, keepdims=True) + j = np.argmax(x[:, 4:], axis=1).reshape(conf.shape) x = np.concatenate((box, conf, j.astype(float)), axis=1)[conf.flatten() > conf_thres] # Filter by class diff --git a/setup_path.py b/setup_path.py new file mode 100644 index 0000000..a33b947 --- /dev/null +++ b/setup_path.py @@ -0,0 +1,52 @@ +# Import this module to automatically setup path to local airsim module +# This module first tries to see if airsim module is installed via pip +# If it does then we don't do anything else +# Else we look up grand-parent folder to see if it has airsim folder +# and if it does then we add that in sys.path + +import os,sys,logging + +#this class simply tries to see if airsim +class SetupPath: + @staticmethod + def getDirLevels(path): + path_norm = os.path.normpath(path) + return len(path_norm.split(os.sep)) + + @staticmethod + def getCurrentPath(): + cur_filepath = __file__ + return os.path.dirname(cur_filepath) + + @staticmethod + def getGrandParentDir(): + cur_path = SetupPath.getCurrentPath() + if SetupPath.getDirLevels(cur_path) >= 2: + return os.path.dirname(os.path.dirname(cur_path)) + return '' + + @staticmethod + def getParentDir(): + cur_path = SetupPath.getCurrentPath() + if SetupPath.getDirLevels(cur_path) >= 1: + return os.path.dirname(cur_path) + return '' + + @staticmethod + def addAirSimModulePath(): + # if airsim module is installed then don't do anything else + #import pkgutil + #airsim_loader = pkgutil.find_loader('airsim') + #if airsim_loader is not None: + # return + + parent = SetupPath.getParentDir() + if parent != '': + airsim_path = os.path.join(parent, 'airsim') + client_path = os.path.join(airsim_path, 'client.py') + if os.path.exists(client_path): + sys.path.insert(0, parent) + else: + logging.warning("airsim module not found in parent folder. Using installed package (pip install airsim).") + +SetupPath.addAirSimModulePath()